diff --git a/tensorflow_probability/python/internal/custom_gradient.py b/tensorflow_probability/python/internal/custom_gradient.py index 5cbf4896af..4931b41b5f 100644 --- a/tensorflow_probability/python/internal/custom_gradient.py +++ b/tensorflow_probability/python/internal/custom_gradient.py @@ -88,7 +88,7 @@ def f_wrapped(*args, **kwargs): args = args[1:] val, aux = vjp_fwd(*reconstruct_args, **kwargs) - def vjp_bwd_wrapped(*g, **kwargs): + def vjp_bwd_wrapped(*g): # We don't want to use an explicit `variables` arg, because TF will # complain if the wrapped function doesn't actually have variables # in it. TF will only specify this arg if there are variables.