Computes f(*args) and its gradients with respect to *args.
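The idea of returning a function's value together with its gradients with respect to every positional argument can be sketched in plain JAX. This sketch uses jax.value_and_grad with argnums=(0, 1) as a stand-in. It is not the TFP function itself, and the example function f and its inputs are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    # Hypothetical scalar-valued function of two arguments.
    return jnp.sum(x ** 2 + 3.0 * y)


x = jnp.array([1.0, 2.0])
y = jnp.array([0.5, -1.0])

# argnums=(0, 1) differentiates with respect to both positional arguments,
# analogous to taking gradients with respect to all of *args.
value, (dx, dy) = jax.value_and_grad(f, argnums=(0, 1))(x, y)
# value is 1 + 4 + 3 * (0.5 - 1.0) = 3.5; dx is 2 * x; dy is 3 everywhere.
```

JAX returns the gradients as a tuple that matches argnums. The container type used by value_and_gradient for multiple arguments is not specified on this page.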
tfp.substrates.jax.math.value_and_gradient(
f, *args, output_gradients=None, use_gradient_tape=False,
auto_unpack_single_arg=True, has_aux=False, name=None, **kwargs
)
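The output_gradients parameter appears to follow the convention of tf.GradientTape.gradient, where upstream gradients weight a non-scalar output. That reading is an assumption. Under it, the computation is a vector-Jacobian product, which this sketch shows with core JAX's jax.vjp rather than the TFP API. The function f and the upstream vector are hypothetical.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical vector-valued function: one output per input element.
    return jnp.sin(x) * x


x = jnp.array([0.0, 1.0, 2.0])
# Hypothetical upstream gradient, standing in for output_gradients (assumption).
upstream = jnp.array([1.0, 0.5, 2.0])

# jax.vjp returns f(x) and a function that maps a cotangent to
# upstream^T @ J, where J is the Jacobian of f at x.
value, vjp_fn = jax.vjp(f, x)
(dx,) = vjp_fn(upstream)
# f is elementwise, so dx equals upstream * (cos(x) * x + sin(x)).
```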