chainer.register_kl

chainer.register_kl(Dist1, Dist2)

Decorator to register KL divergence function.

This decorator registers a function that computes the Kullback-Leibler (KL) divergence between two distributions. kl_divergence() selects the registered function based on the types of its two arguments.
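The type-based dispatch described above can be pictured as a dictionary keyed by pairs of types. The following is a minimal pure-Python sketch of that pattern, not Chainer's actual implementation: register_kl_sketch, kl_divergence_sketch, and the Coin class are hypothetical names, and the exact-type lookup is an assumption (Chainer's own lookup rules may differ).

```python
import math

# Hypothetical registry mapping (type, type) pairs to KL functions.
# This only illustrates the dispatch pattern; it is not Chainer's code.
_KL_REGISTRY = {}


def register_kl_sketch(dist1_type, dist2_type):
    """Return a decorator that records fn under (dist1_type, dist2_type)."""
    def decorator(fn):
        _KL_REGISTRY[(dist1_type, dist2_type)] = fn
        return fn  # hand fn back so the decorated name stays usable
    return decorator


def kl_divergence_sketch(dist1, dist2):
    """Look up the function registered for the exact argument types and call it."""
    key = (type(dist1), type(dist2))
    if key not in _KL_REGISTRY:
        raise NotImplementedError(
            "no KL divergence registered for {} and {}".format(*key))
    return _KL_REGISTRY[key](dist1, dist2)


class Coin(object):
    """Hypothetical Bernoulli distribution with success probability p."""

    def __init__(self, p):
        self.p = p


@register_kl_sketch(Coin, Coin)
def _kl_coin_coin(c1, c2):
    # KL(Bern(p) || Bern(q)) = p log(p/q) + (1 - p) log((1 - p)/(1 - q))
    p, q = c1.p, c2.p
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))


print(kl_divergence_sketch(Coin(0.5), Coin(0.25)))
```

Because the registry is keyed by the ordered pair of types, registering (Coin, Coin) says nothing about (Coin, SomeOtherType); each combination needs its own registration.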

Parameters
  • Dist1 (type) – a class that inherits from Distribution; the type of the first distribution in the KL divergence.

  • Dist2 (type) – a class that inherits from Distribution; the type of the second distribution in the KL divergence.

The decorated function takes an instance of Dist1 and an instance of Dist2, in that order, and returns their KL divergence value.

Example

This simple example registers a KL divergence function. The decorated function calculates the KL divergence value between an instance of Dist1 and an instance of Dist2.

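import chainer

# Dist1 and Dist2 are Distribution subclasses; KL stands for the computed value.
@chainer.register_kl(Dist1, Dist2)
def _kl_dist1_dist2(dist1, dist2):
    # dist1 is an instance of Dist1 and dist2 is an instance of Dist2.
    return KL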