tf.compat.v1.distributions.RegisterKL

Decorator to register a KL divergence implementation function.

tf.compat.v1.distributions.RegisterKL(
    dist_cls_a, dist_cls_b
)

Usage:

@distributions.RegisterKL(distributions.Normal, distributions.Normal)
def _kl_normal_mvn(norm_a, norm_b):
    # Return KL(norm_a || norm_b)

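Args:

dist_cls_a: The class of the first argument of the KL divergence.

dist_cls_b: The class of the second argument of the KL divergence.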

Methods

__call__

__call__(
    kl_fn
)

Perform the KL registration.

Args:

kl_fn: The function to use for the KL divergence.

Returns:

kl_fn

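Raises:

TypeError: If kl_fn is not callable.

ValueError: If a KL divergence function has already been registered for the given argument classes.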