Decorator to register a KL divergence implementation function.
```python
tf.compat.v1.distributions.RegisterKL(
    dist_cls_a, dist_cls_b
)
```

Usage:

```python
@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_mvn(norm_a, norm_b):
  # Return KL(norm_a || norm_b)
```
Args:

- `dist_cls_a`: the class of the first argument of the KL divergence.
- `dist_cls_b`: the class of the second argument of the KL divergence.

Methods

`__call__`

```python
__call__(
    kl_fn
)
```
Perform the KL registration.
Args:

- `kl_fn`: The function to use for the KL divergence.

Returns:

- `kl_fn`

Raises:

- `TypeError`: if `kl_fn` is not a callable.
- `ValueError`: if a KL divergence function has already been registered for the given argument classes.