Class RegisterKL
Decorator to register a KL divergence implementation function.
Usage:
@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_mvn(norm_a, norm_b):
  # Return KL(norm_a || norm_b)
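The snippet above only shows the decorator form. Below is a minimal runnable sketch, assuming TensorFlow 1.x and its (deprecated) tf.distributions module, where kl_divergence dispatches on the registered class pair. MyNormal and _kl_my_normal are hypothetical names; the subclass exists only so the registration does not collide with the KL(Normal || Normal) rule TensorFlow already provides.

import tensorflow as tf

tfd = tf.distributions


class MyNormal(tfd.Normal):
  # Hypothetical Normal subclass standing in for a user-defined distribution.
  pass


@tfd.RegisterKL(MyNormal, MyNormal)
def _kl_my_normal(dist_a, dist_b, name=None):
  # Closed-form KL(N(m_a, s_a) || N(m_b, s_b)).
  with tf.name_scope(name, "kl_my_normal", [dist_a.loc, dist_b.loc]):
    var_a = tf.square(dist_a.scale)
    var_b = tf.square(dist_b.scale)
    return (tf.log(dist_b.scale) - tf.log(dist_a.scale)
            + (var_a + tf.square(dist_a.loc - dist_b.loc)) / (2. * var_b)
            - 0.5)


# kl_divergence looks up the (MyNormal, MyNormal) pair and calls _kl_my_normal.
a = MyNormal(loc=0., scale=1.)
b = MyNormal(loc=1., scale=2.)
kl = tfd.kl_divergence(a, b)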
__init__
__init__(
dist_cls_a,
dist_cls_b
)
Initialize the KL registrar. (deprecated)
Args:
dist_cls_a: the class of the first argument of the KL divergence.
dist_cls_b: the class of the second argument of the KL divergence.
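A minimal sketch, assuming TensorFlow 1.x: decorating with RegisterKL is the same as constructing the registrar with the two argument classes (__init__) and then calling it on the KL function (__call__). MyLaplace and _kl_my_laplace are hypothetical names used only for illustration, and the function body is a placeholder rather than a real KL.

import tensorflow as tf

tfd = tf.distributions


class MyLaplace(tfd.Laplace):
  # Hypothetical subclass used only for illustration.
  pass


def _kl_my_laplace(dist_a, dist_b, name=None):
  del name  # unused in this sketch
  # Placeholder body; a real implementation would return KL(dist_a || dist_b).
  return tf.zeros_like(dist_a.loc)


registrar = tfd.RegisterKL(MyLaplace, MyLaplace)  # __init__(dist_cls_a, dist_cls_b)
_kl_my_laplace = registrar(_kl_my_laplace)        # __call__(kl_fn) registers and returns kl_fn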
Methods
tf.distributions.RegisterKL.__call__
__call__(kl_fn)
Perform the KL registration.
Args:
kl_fn: The function to use for the KL divergence.
Returns:
kl_fn
Raises:
TypeError: if kl_fn is not a callable.
ValueError: if a KL divergence function has already been registered for the given argument classes.
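A minimal sketch of the two failure modes listed above, assuming TensorFlow 1.x with the (deprecated) tf.distributions module. MyExp and _kl_my_exp are hypothetical names; the subclass exists only so the registrations below do not collide with pairs TensorFlow registers itself.

import tensorflow as tf

tfd = tf.distributions


class MyExp(tfd.Exponential):
  # Hypothetical subclass used only for illustration.
  pass


# TypeError: the registrar only accepts callables.
try:
  tfd.RegisterKL(MyExp, MyExp)("not a function")
except TypeError:
  pass

# The first registration for the (MyExp, MyExp) pair succeeds.
@tfd.RegisterKL(MyExp, MyExp)
def _kl_my_exp(dist_a, dist_b, name=None):
  del name  # unused in this sketch
  # Closed-form KL(Exp(rate_a) || Exp(rate_b)).
  return (tf.log(dist_a.rate) - tf.log(dist_b.rate)
          + dist_b.rate / dist_a.rate - 1.)

# ValueError: each (dist_cls_a, dist_cls_b) pair may be registered only once.
try:
  tfd.RegisterKL(MyExp, MyExp)(_kl_my_exp)
except ValueError:
  pass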