tf.compat.v1.distributions.Categorical

Categorical distribution.

Inherits From: Distribution

tf.compat.v1.distributions.Categorical(
    logits=None, probs=None, dtype=tf.dtypes.int32, validate_args=False,
    allow_nan_stats=True, name='Categorical'
)

The Categorical distribution is parameterized by either probabilities or log-probabilities of a set of K classes. It is defined over the integers {0, 1, ..., K-1}.

The Categorical distribution is closely related to the OneHotCategorical and Multinomial distributions. A sample from Categorical(probs) can be thought of as argmax{ OneHotCategorical(probs) }, which is in turn identical to argmax{ Multinomial(probs, total_count=1) }.

Mathematical Details

The probability mass function (pmf) is,

pmf(k; pi) = prod_j pi_j**[k == j]
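Here pi = (pi_0, ..., pi_{K-1}) are the class probabilities and the exponent [k == j] is 1 when j equals k and 0 otherwise. Every factor except pi_k is therefore raised to the power 0, and the product reduces to pi_k. The plain-Python sketch below evaluates the product literally. It is an illustration of the formula, not TensorFlow's implementation, and the helper name `categorical_pmf` is hypothetical.

```python
import math


def categorical_pmf(k, probs):
    """Evaluate pmf(k; pi) = prod_j pi_j**[k == j] literally (hypothetical helper)."""
    # The bracket [k == j] becomes a 0/1 exponent.
    return math.prod(p ** (1 if k == j else 0) for j, p in enumerate(probs))


probs = [0.1, 0.5, 0.4]
for k in range(len(probs)):
    # Every factor with j != k is p**0 == 1, so only probs[k] survives.
    assert categorical_pmf(k, probs) == probs[k]
print([categorical_pmf(k, probs) for k in range(3)])  # [0.1, 0.5, 0.4]
```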

Pitfalls

The number of classes, K, must not exceed:

- the largest integer exactly representable by the parameters' floating-point dtype (param.dtype), i.e., 2**(mantissa_bits+1) (IEEE 754),
- the maximum Tensor index, i.e., 2**31-1.

In other words,

K <= min(2**31-1, {
  tf.float16: 2**11,
  tf.float32: 2**24,
  tf.float64: 2**53 }[param.dtype])

Note: This condition is validated only when self.validate_args = True.
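The per-dtype limits in the table follow from IEEE 754: a float with m stored mantissa bits represents every integer up to 2**(m+1) exactly, but not 2**(m+1) + 1. The plain-Python sketch below reproduces the table and uses the standard `struct` module to show the rounding for half ('e') and single ('f') precision. The `MANTISSA_BITS` table and the `max_num_classes` and `round_trip` helpers are illustrative assumptions, not TensorFlow APIs.

```python
import struct

# Stored (explicit) mantissa bits for IEEE 754 binary16, binary32 and binary64.
MANTISSA_BITS = {"float16": 10, "float32": 23, "float64": 52}
MAX_TENSOR_INDEX = 2**31 - 1


def max_num_classes(param_dtype):
    """Largest K allowed for parameters of the given float dtype (hypothetical helper)."""
    return min(MAX_TENSOR_INDEX, 2 ** (MANTISSA_BITS[param_dtype] + 1))


def round_trip(fmt, x):
    """Store x in struct format `fmt` ('e' = half, 'f' = single) and read it back."""
    return struct.unpack(fmt, struct.pack(fmt, float(x)))[0]


print({name: max_num_classes(name) for name in MANTISSA_BITS})
# {'float16': 2048, 'float32': 16777216, 'float64': 2147483647}

# One past the limit is no longer representable and rounds back down.
print(round_trip("e", 2**11 + 1))  # 2048.0
print(round_trip("f", 2**24 + 1))  # 16777216.0
print(float(2**53 + 1) == 2**53)   # True
```

For float64 the mantissa limit (2**53) is larger than the maximum Tensor index, so the index bound is the one that applies.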

Examples

Creates a 3-class distribution with the 2nd class being most likely.

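import tensorflow as tf

dist = tf.compat.v1.distributions.Categorical(probs=[0.1, 0.5, 0.4])
n = 10000
# Samples are int32 class indices; cast them so they match the float value_range.
samples = tf.cast(dist.sample(n), tf.float32)
empirical_prob = tf.cast(
    tf.histogram_fixed_width(
      samples,
      [0., 2.],
      nbins=3),
    dtype=tf.float32) / n
# ==> approximately array([ 0.1005,  0.5037,  0.3958], dtype=float32); varies per run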

Creates a 3-class distribution with the 2nd class being most likely. Parameterized by logits rather than probabilities.

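import numpy as np
import tensorflow as tf

dist = tf.compat.v1.distributions.Categorical(logits=np.log([0.1, 0.5, 0.4]))
n = 10000
# Samples are int32 class indices; cast them so they match the float value_range.
samples = tf.cast(dist.sample(n), tf.float32)
empirical_prob = tf.cast(
    tf.histogram_fixed_width(
      samples,
      [0., 2.],
      nbins=3),
    dtype=tf.float32) / n
# ==> approximately array([0.1045,  0.5047, 0.3908], dtype=float32); varies per run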
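Passing logits=np.log(probs) yields the same distribution because the probabilities are recovered from the logits with a softmax, and a softmax is unchanged when the same constant is added to every logit, so logits need not be normalized. A plain-Python sketch of that mapping follows; the `softmax` helper is illustrative, not a TensorFlow API.

```python
import math


def softmax(logits):
    """Map unnormalized log-probabilities to probabilities (illustrative helper)."""
    # Subtracting the max avoids overflow in exp(); the result is unchanged
    # because softmax is invariant to shifting all logits by a constant.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = [0.1, 0.5, 0.4]
logits = [math.log(p) for p in probs]
print([round(p, 6) for p in softmax(logits)])  # [0.1, 0.5, 0.4]
# Shifting every logit by the same constant leaves the distribution unchanged.
print([round(p, 6) for p in softmax([z + 3.0 for z in logits])])  # [0.1, 0.5, 0.4]
```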

Creates a 3-class distribution with the 3rd class being most likely. Because Categorical values are class indices, the distribution functions are evaluated on integers in {0, 1, 2}.

import tensorflow as tf

p = [0.1, 0.4, 0.5]
dist = tf.compat.v1.distributions.Categorical(probs=p)

# The value is a scalar class index.
dist.prob(0)  # Shape []

# p will be broadcast to [[0.1, 0.4, 0.5], [0.1, 0.4, 0.5]] to match values.
values = [1, 0]
dist.prob(values)  # Shape [2]

# p will be broadcast to shape [5, 7, 3, 3] to match values.
values = [[...]]  # Placeholder for an int tensor of shape [5, 7, 3]
dist.prob(values)  # Shape [5, 7, 3]
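Mathematically, because each value is a class index, prob returns probs[value] at every position, so the result has the shape of the argument. The plain-Python sketch below shows that lookup for a single, unbatched parameter vector. The `prob` helper is hypothetical (not the TensorFlow method), and raising on out-of-range indices is a choice of this sketch.

```python
def prob(probs, value):
    """Look up probs[value] elementwise, preserving the nesting of `value`.

    Hypothetical helper: `value` is an int class index or a nested list of them.
    """
    if isinstance(value, list):
        return [prob(probs, v) for v in value]
    if not 0 <= value < len(probs):
        raise ValueError(f"class index {value} outside [0, {len(probs) - 1}]")
    return probs[value]


p = [0.1, 0.4, 0.5]
print(prob(p, 0))                       # 0.1 (scalar in, scalar out)
print(prob(p, [1, 0]))                  # [0.4, 0.1] (shape [2])
print(prob(p, [[2, 1, 0], [0, 0, 2]]))  # [[0.5, 0.4, 0.1], [0.1, 0.1, 0.5]] (shape [2, 3])
```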

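Args:

logits: An N-D Tensor, N >= 1, representing the log probabilities of a set of Categorical distributions. The first N - 1 dimensions index into a batch of independent distributions and the last dimension represents a vector of logits for each class. Only one of logits or probs should be passed in.

probs: An N-D Tensor, N >= 1, representing the probabilities of a set of Categorical distributions. The first N - 1 dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of logits or probs should be passed in.

dtype: The type of the event samples (default: int32).

validate_args: Python bool, default False. When True, distribution parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.

allow_nan_stats: Python bool, default True. When True, statistics (e.g., mean, mode, variance) use the value "NaN" to indicate the result is undefined. When False, an exception is raised if one or more of the statistic's batch members are undefined.

name: Python str name prefixed to Ops created by this class.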

Methods

batch_shape_tensor

batch_shape_tensor(
    name='batch_shape_tensor'
)

Shape of a single sample from a single event index as a 1-D Tensor.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

cdf

cdf(
    value, name='cdf'
)

Cumulative distribution function.

Given random variable X, the cumulative distribution function cdf is:

cdf(x) := P[X <= x]
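For a Categorical distribution the definition reduces to a partial sum: cdf(x) is the sum of pi_j over all classes j <= x. It is 0 for x < 0 and 1 for x >= K - 1. A plain-Python sketch of the definition follows. The helper name is hypothetical, and it makes no claim about how TensorFlow computes cdf internally.

```python
def categorical_cdf(x, probs):
    """P[X <= x] for X ~ Categorical(probs); x may be any real number (hypothetical helper)."""
    # Only classes j with j <= x contribute to the sum.
    return sum(p for j, p in enumerate(probs) if j <= x)


probs = [0.1, 0.5, 0.4]
print(categorical_cdf(-1, probs))   # 0 (no class is <= -1)
print(categorical_cdf(0, probs))    # 0.1
print(categorical_cdf(1.5, probs))  # 0.6 (classes 0 and 1)
print(categorical_cdf(2, probs))    # 1.0
```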

copy

copy(
    **override_parameters_kwargs
)

Creates a deep copy of the distribution.

Note: the copied distribution may continue to depend on the original initialization arguments.

covariance

covariance(
    name='covariance'
)

Covariance.

Covariance is (possibly) defined only for non-scalar-event distributions.

For example, for a length-k, vector-valued distribution, it is calculated as,

Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]

where Cov is a (batch of) k x k matrix, 0 <= (i, j) < k, and E denotes expectation.

Alternatively, for non-vector, multivariate distributions (e.g., matrix-valued, Wishart), Covariance shall return a (batch of) matrices under some vectorization of the events, i.e.,

Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]

where Cov is a (batch of) k' x k' matrices, 0 <= (i, j) < k' = reduce_prod(event_shape), and Vec is some function mapping indices of this distribution's event dimensions to indices of a length-k' vector.
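Categorical itself has a scalar event, so this vector-valued definition applies to its one-hot encoding (the OneHotCategorical view). For a one-hot vector X with class probabilities pi, E[X_i] = pi_i and E[X_i X_j] = pi_i [i == j], so Cov[i, j] = pi_i [i == j] - pi_i pi_j. The plain-Python sketch below evaluates the expectation definition by enumerating all K outcomes and checks it against that closed form. The helper names are hypothetical.

```python
def one_hot(k, num_classes):
    return [1.0 if j == k else 0.0 for j in range(num_classes)]


def covariance_by_enumeration(probs):
    """Cov[i, j] = E[(X_i - E[X_i]) (X_j - E[X_j])] for X = one_hot(Categorical(probs))."""
    n = len(probs)
    outcomes = [(p, one_hot(k, n)) for k, p in enumerate(probs)]
    mean = [sum(p * x[i] for p, x in outcomes) for i in range(n)]
    return [[sum(p * (x[i] - mean[i]) * (x[j] - mean[j]) for p, x in outcomes)
             for j in range(n)] for i in range(n)]


def covariance_closed_form(probs):
    """Cov[i, j] = pi_i [i == j] - pi_i pi_j."""
    n = len(probs)
    return [[probs[i] * (i == j) - probs[i] * probs[j] for j in range(n)]
            for i in range(n)]


probs = [0.1, 0.5, 0.4]
enum, closed = covariance_by_enumeration(probs), covariance_closed_form(probs)
assert all(abs(a - b) < 1e-12 for ra, rb in zip(enum, closed) for a, b in zip(ra, rb))
print([[round(v, 4) for v in row] for row in closed])
# [[0.09, -0.05, -0.04], [-0.05, 0.25, -0.2], [-0.04, -0.2, 0.24]]
```

Each row sums to zero because the one-hot components always sum to 1, so the matrix is singular.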

cross_entropy

cross_entropy(
    other, name='cross_entropy'
)

Computes the (Shannon) cross entropy.

Denote this distribution (self) by P and the other distribution by Q. Assuming P and Q are absolutely continuous with respect to one another and permit densities p(x) dr(x) and q(x) dr(x), (Shannon) cross entropy is defined as:

H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)

where F denotes the support of the random variable X ~ P.
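For two Categorical distributions over the same K classes, r is the counting measure and the integral becomes a finite sum, H[P, Q] = -sum_k p_k log q_k (in nats). A plain-Python sketch follows, with a hypothetical helper name. Classes with p_k = 0 are skipped, following the usual convention that they contribute nothing.

```python
import math


def categorical_cross_entropy(p, q):
    """H[P, Q] = -sum_k p_k * log(q_k) in nats (hypothetical helper)."""
    # Terms with p_k == 0 are skipped: they contribute 0 by convention.
    return -sum(pk * math.log(qk) for pk, qk in zip(p, q) if pk > 0)


p = [0.1, 0.5, 0.4]
uniform = [1 / 3] * 3
print(round(categorical_cross_entropy(p, uniform), 4))  # 1.0986 (= log 3)
print(round(categorical_cross_entropy(p, p), 4))        # 0.9433 (the entropy of p)
```

The second value is smaller than the first, as Gibbs' inequality requires: H[P, Q] >= H[P, P] for every Q.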

entropy

entropy(
    name='entropy'
)

Shannon entropy in nats.
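For a Categorical distribution this is H = -sum_k pi_k log pi_k, which ranges from 0 (all mass on one class) to log K (uniform over K classes). A plain-Python sketch of that sum, with a hypothetical helper name:

```python
import math


def categorical_entropy(probs):
    """Shannon entropy -sum_k pi_k log(pi_k) in nats (hypothetical helper)."""
    # Zero-probability classes are skipped (0 log 0 is taken as 0).
    return sum(-p * math.log(p) for p in probs if p > 0)


print(round(categorical_entropy([0.1, 0.5, 0.4]), 4))              # 0.9433
print(categorical_entropy([0.0, 1.0, 0.0]))                        # 0.0 (a point mass)
print(math.isclose(categorical_entropy([0.25] * 4), math.log(4)))  # True (uniform maximum)
```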

event_shape_tensor

event_shape_tensor(
    name='event_shape_tensor'
)

Shape of a single sample from a single batch as a 1-D int32 Tensor.

is_scalar_batch

is_scalar_batch(
    name='is_scalar_batch'
)

Indicates that batch_shape == [].

is_scalar_event

is_scalar_event(
    name='is_scalar_event'
)

Indicates that event_shape == [].

kl_divergence

kl_divergence(
    other, name='kl_divergence'
)

Computes the Kullback-Leibler divergence.

Denote this distribution (self) by p and the other distribution by q. Assuming p and q are absolutely continuous with respect to a reference measure r, the KL divergence is defined as:

KL[p, q] = E_p[log(p(X)/q(X))]
         = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
         = H[p, q] - H[p]

where F denotes the support of the random variable X ~ p, H[., .] denotes (Shannon) cross entropy, and H[.] denotes (Shannon) entropy.
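For two Categorical distributions over the same classes, the integrals become finite sums, giving KL[p, q] = sum_k p_k log(p_k / q_k). The plain-Python sketch below computes this directly and checks the identity KL[p, q] = H[p, q] - H[p] from the formula above. It also shows that KL is non-negative, zero when p == q, and not symmetric. The helper names are hypothetical.

```python
import math


def kl_divergence(p, q):
    """KL[p, q] = sum_k p_k log(p_k / q_k) in nats (hypothetical helper)."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def cross_entropy(p, q):
    return -sum(pk * math.log(qk) for pk, qk in zip(p, q) if pk > 0)


def entropy(p):
    return cross_entropy(p, p)


p = [0.1, 0.5, 0.4]
q = [0.3, 0.3, 0.4]
assert math.isclose(kl_divergence(p, q), cross_entropy(p, q) - entropy(p))
print(round(kl_divergence(p, q), 4))  # 0.1456
print(round(kl_divergence(q, p), 4))  # 0.1763 (KL is not symmetric)
print(kl_divergence(p, p))            # 0.0
```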

log_cdf

log_cdf(
    value, name='log_cdf'
)

Log cumulative distribution function.

Given random variable X, the log cumulative distribution function log_cdf is:

log_cdf(x) := Log[ P[X <= x] ]

Often, a numerical approximation can be used for log_cdf(x) that yields a more accurate answer than simply taking the logarithm of the cdf when x << -1.
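For a Categorical distribution parameterized by logits, one way to get this accuracy is to stay in log space: log_cdf(k) = logsumexp(logits[0..k]) - logsumexp(logits). This avoids exponentiating very negative logits and then taking the log of a cdf that has underflowed to 0. The plain-Python sketch below illustrates that approach with hypothetical helpers. It does not describe how TensorFlow implements log_cdf.

```python
import math


def logsumexp(zs):
    """log(sum(exp(z) for z in zs)), computed stably by factoring out the max."""
    m = max(zs)
    return m + math.log(sum(math.exp(z - m) for z in zs))


def categorical_log_cdf(k, logits):
    """log P[X <= k] for integer k in [0, K-1], from unnormalized logits (hypothetical helper)."""
    return logsumexp(logits[: k + 1]) - logsumexp(logits)


logits = [-1000.0, -1000.0, 0.0]
print(categorical_log_cdf(0, logits))  # -1000.0
print(categorical_log_cdf(2, logits))  # 0.0 (the cdf at the last class is 1)

# The naive route underflows: exp(-1000) is 0.0 in double precision,
# so the cdf evaluates to 0.0 and its log is undefined.
naive_cdf = math.exp(-1000.0) / (2 * math.exp(-1000.0) + 1.0)
print(naive_cdf)  # 0.0
```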

log_prob

log_prob(
    value, name='log_prob'
)

Log probability density/mass function.

log_survival_function

log_survival_function(
    value, name='log_survival_function'
)

Log survival function.

Given random variable X, the log survival function is defined as:

log_survival_function(x) = Log[ P[X > x] ]
                         = Log[ 1 - P[X <= x] ]
                         = Log[ 1 - cdf(x) ]

Typically, different numerical approximations can be used for the log survival function, which are more accurate than 1 - cdf(x) when x >> 1.

Returns:

Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.
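For a Categorical distribution with logits, the same log-space idea applies to the upper tail: log_survival_function(k) = logsumexp(logits[k+1..K-1]) - logsumexp(logits). When the tail mass is tiny, cdf(k) rounds to exactly 1.0 and log(1 - cdf(k)) loses it entirely. A plain-Python sketch with hypothetical helpers, illustrating the idea rather than TensorFlow's implementation:

```python
import math


def logsumexp(zs):
    """log(sum(exp(z) for z in zs)), computed stably by factoring out the max."""
    m = max(zs)
    return m + math.log(sum(math.exp(z - m) for z in zs))


def categorical_log_survival(k, logits):
    """log P[X > k] for integer k in [0, K-1] (hypothetical helper)."""
    tail = logits[k + 1:]
    if not tail:
        return -math.inf  # P[X > K-1] = 0
    return logsumexp(tail) - logsumexp(logits)


logits = [0.0, -40.0]
print(categorical_log_survival(0, logits))  # -40.0
print(categorical_log_survival(1, logits))  # -inf

# Naively, cdf(0) rounds to exactly 1.0, so 1 - cdf(0) loses the tail.
cdf0 = 1.0 / (1.0 + math.exp(-40.0))
print(1.0 - cdf0)  # 0.0
```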

mean

mean(
    name='mean'
)

Mean.

mode

mode(
    name='mode'
)

Mode.
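For a Categorical distribution the mode is the most probable class index, argmax_k pi_k. Softmax is monotonic, so the same index is obtained directly from unnormalized logits. A plain-Python sketch with a hypothetical helper follows. Breaking ties toward the lowest index is an assumption of this sketch, not a documented TensorFlow guarantee.

```python
def categorical_mode(params):
    """Index of the largest entry of probs or logits (hypothetical helper).

    max() returns the first maximal index, so ties go to the lowest class.
    """
    return max(range(len(params)), key=lambda k: params[k])


print(categorical_mode([0.1, 0.5, 0.4]))       # 1
print(categorical_mode([-2.3, -0.69, -0.92]))  # 1 (approximate logs of the same probs)
print(categorical_mode([0.5, 0.5]))            # 0 (tie -> lowest index)
```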

param_shapes

@classmethod
param_shapes(
    sample_shape, name='DistributionParamShapes'
)

Shapes of parameters given the desired shape of a call to sample().

This is a class method that describes what key/value arguments are required to instantiate the given Distribution so that a particular shape is returned for that instance's call to sample().

Subclasses should override class method _param_shapes.

Returns:

dict of parameter name to Tensor shapes.

param_static_shapes

@classmethod
param_static_shapes(
    sample_shape
)

param_shapes with static (i.e. TensorShape) shapes.

This is a class method that describes what key/value arguments are required to instantiate the given Distribution so that a particular shape is returned for that instance's call to sample(). Assumes that the sample's shape is known statically.

Subclasses should override class method _param_shapes to return constant-valued tensors when constant values are fed.

Returns:

dict of parameter name to TensorShape.

prob

prob(
    value, name='prob'
)

Probability density/mass function.

quantile

quantile(
    value, name='quantile'
)

Quantile function. Also known as the "inverse cdf" or "percent point function".

Given random variable X and p in [0, 1], the quantile is:

quantile(p) := x such that P[X <= x] == p
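For a discrete distribution such as Categorical, P[X <= x] jumps between values, so an x with P[X <= x] == p exists only for the p values the cdf actually attains. A common convention, used in the plain-Python sketch below, is the generalized inverse: the smallest class k with cdf(k) >= p. The helper is hypothetical and says nothing about what TensorFlow's Categorical returns from quantile.

```python
import bisect
import itertools


def categorical_quantile(p, probs):
    """Smallest class k with P[X <= k] >= p, for p in [0, 1] (hypothetical helper)."""
    if not 0.0 <= p <= 1.0:
        raise ValueError("p must lie in [0, 1]")
    cdf = list(itertools.accumulate(probs))  # cdf[k] = P[X <= k]
    # bisect_left finds the first index whose cdf value is >= p; the clamp guards
    # against the last cdf entry falling slightly below 1.0 due to rounding.
    return min(bisect.bisect_left(cdf, p), len(probs) - 1)


probs = [0.1, 0.5, 0.4]  # cdf is [0.1, 0.6, 1.0]
print(categorical_quantile(0.05, probs))  # 0
print(categorical_quantile(0.1, probs))   # 0 (cdf(0) == 0.1 exactly)
print(categorical_quantile(0.35, probs))  # 1
print(categorical_quantile(0.95, probs))  # 2
```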

sample

sample(
    sample_shape=(), seed=None, name='sample'
)

Generate samples of the specified shape.

Note that a call to sample() without arguments will generate a single sample.
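One standard way to draw Categorical samples is inverse-transform sampling: draw u uniformly from [0, 1) and return the first class whose cumulative probability exceeds u. The plain-Python sketch below uses that technique with a hypothetical `sample` helper and a seeded `random.Random`. TensorFlow's own sampler may use a different algorithm, so only the distribution of the outputs, not the individual draws, should be expected to match.

```python
import bisect
import itertools
import random


def sample(probs, n, seed=None):
    """Draw n class indices from Categorical(probs) by inverse-transform sampling.

    Hypothetical helper: returns a flat list rather than a Tensor of sample_shape.
    """
    rng = random.Random(seed)
    cdf = list(itertools.accumulate(probs))
    total = cdf[-1]  # normalizes probs that do not sum exactly to 1
    draws = []
    for _ in range(n):
        u = rng.random() * total  # uniform on [0, total)
        # First class whose cumulative probability exceeds u.
        draws.append(min(bisect.bisect_right(cdf, u), len(probs) - 1))
    return draws


draws = sample([0.1, 0.5, 0.4], 10_000, seed=0)
freqs = [draws.count(k) / len(draws) for k in range(3)]
print([round(f, 2) for f in freqs])  # close to [0.1, 0.5, 0.4]; exact values depend on the seed
```

Using bisect_right rather than bisect_left ensures a class with zero probability is never returned, even when u is exactly 0.0.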

stddev

stddev(
    name='stddev'
)

Standard deviation.

Standard deviation is defined as,

stddev = E[(X - E[X])**2]**0.5

where X is the random variable associated with this distribution, E denotes expectation, and stddev.shape = batch_shape + event_shape.
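If the class index k itself is treated as the numeric value of X, these definitions reduce to finite sums: E[X] = sum_k k pi_k, Var = sum_k (k - E[X])**2 pi_k, and stddev = Var**0.5. The plain-Python sketch below evaluates them with a hypothetical helper. It illustrates the formulas only and does not assume that a given TensorFlow version implements mean, variance, or stddev for Categorical.

```python
import math


def categorical_moments(probs):
    """Mean, variance and standard deviation of the class index (hypothetical helper)."""
    mean = sum(k * p for k, p in enumerate(probs))
    var = sum((k - mean) ** 2 * p for k, p in enumerate(probs))
    return mean, var, math.sqrt(var)


mean, var, std = categorical_moments([0.1, 0.5, 0.4])
print(round(mean, 4), round(var, 4), round(std, 4))  # 1.3 0.41 0.6403
```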

survival_function

survival_function(
    value, name='survival_function'
)

Survival function.

Given random variable X, the survival function is defined as:

survival_function(x) = P[X > x]
                     = 1 - P[X <= x]
                     = 1 - cdf(x).

Returns:

Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

variance

variance(
    name='variance'
)

Variance.

Variance is defined as,

Var = E[(X - E[X])**2]

where X is the random variable associated with this distribution, E denotes expectation, and Var.shape = batch_shape + event_shape.
