tf.compat.v1.distributions.Dirichlet


Dirichlet distribution.

Inherits From: Distribution

tf.compat.v1.distributions.Dirichlet(
    concentration, validate_args=False, allow_nan_stats=True, name='Dirichlet'
)

The Dirichlet distribution is defined over the (k-1)-simplex using a positive, length-k vector concentration (k > 1). The Dirichlet is identically the Beta distribution when k = 2.

Mathematical Details

The Dirichlet is a distribution over the open (k-1)-simplex, i.e.,

S^{k-1} = { (x_0, ..., x_{k-1}) in R^k : sum_j x_j = 1 and all_j x_j > 0 }.

The probability density function (pdf) is,

pdf(x; alpha) = prod_j x_j**(alpha_j - 1) / Z
Z = prod_j Gamma(alpha_j) / Gamma(sum_j alpha_j)

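where:

- x in S^{k-1}, i.e., the (k-1)-simplex,
- concentration = alpha = [alpha_0, ..., alpha_{k-1}], alpha_j > 0,
- Z is the normalization constant, i.e., the multivariate beta function, and
- Gamma is the gamma function.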

The concentration represents mean total counts of class occurrence, i.e.,

concentration = alpha = mean * total_concentration

where mean is in S^{k-1} and total_concentration is a positive real number representing a mean total count.
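The density above can be evaluated in log space with the standard library's log-gamma function. The following is a minimal pure-Python sketch of the formula. It is not the TensorFlow implementation, and the helper names `dirichlet_log_pdf` and `dirichlet_pdf` are illustrative. Two checks apply. With alpha = [1, 1, 1] the density is uniform on the 2-simplex and equals Gamma(3) = 2. With k = 2 it reduces to the Beta density.

```python
import math

# Illustrative sketch of the Dirichlet density formula; not the library code.
def dirichlet_log_pdf(x, alpha):
    """log pdf(x; alpha) = sum_j (alpha_j - 1) log x_j - log Z."""
    if len(x) != len(alpha):
        raise ValueError("x and alpha must have the same length k")
    # log Z = sum_j log Gamma(alpha_j) - log Gamma(sum_j alpha_j)
    log_z = sum(math.lgamma(a) for a in alpha) - math.lgamma(sum(alpha))
    log_kernel = sum((a - 1.0) * math.log(xj) for xj, a in zip(x, alpha))
    return log_kernel - log_z

def dirichlet_pdf(x, alpha):
    return math.exp(dirichlet_log_pdf(x, alpha))

# Uniform Dirichlet over the 2-simplex: the density is Gamma(3) = 2 everywhere.
print(dirichlet_pdf([0.2, 0.3, 0.5], [1.0, 1.0, 1.0]))

# k = 2 reduces to Beta(2, 3): pdf(0.4) = 0.4 * 0.6**2 / B(2, 3) = 1.728.
print(dirichlet_pdf([0.4, 0.6], [2.0, 3.0]))
```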

Distribution parameters are automatically broadcast in all functions; see examples for details.

Warning: Some components of the samples can be zero due to finite precision. This happens more often when some of the concentrations are very small. Make sure to clip the samples from below at np.finfo(dtype).tiny before computing the density.
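The clipping advice can be sketched in plain Python. Here `sys.float_info.min` is the smallest positive normal double, the same value as `np.finfo(np.float64).tiny`. The helper names are illustrative, and the log-density is the formula from Mathematical Details rather than the library call. Python's `math.log(0.0)` raises an error, so an underflowed zero component breaks the density until it is clipped.

```python
import math
import sys

# Smallest positive normal float64; equals np.finfo(np.float64).tiny.
TINY = sys.float_info.min

def clip_to_tiny(x, tiny=TINY):
    """Raise exact zeros (underflowed sample components) to `tiny`."""
    return [max(xj, tiny) for xj in x]

# Illustrative log-density (formula from Mathematical Details), not the library code.
def dirichlet_log_pdf(x, alpha):
    log_z = sum(math.lgamma(a) for a in alpha) - math.lgamma(sum(alpha))
    return sum((a - 1.0) * math.log(xj) for xj, a in zip(x, alpha)) - log_z

alpha = [0.05, 1.0, 1.0]   # one very small concentration
sample = [0.0, 0.4, 0.6]   # first component underflowed to exactly zero

try:
    dirichlet_log_pdf(sample, alpha)
except ValueError:
    print("log(0) is undefined: clip before computing the density")

print(dirichlet_log_pdf(clip_to_tiny(sample), alpha))  # finite value
```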

Samples of this distribution are reparameterized (pathwise differentiable). The derivatives are computed using the approach described in the paper

Michael Figurnov, Shakir Mohamed, Andriy Mnih. Implicit Reparameterization Gradients, 2018

Examples

import tensorflow_probability as tfp
tfd = tfp.distributions

# Create a single trivariate Dirichlet, with the 3rd class being three times
# more frequent than the first. I.e., batch_shape=[], event_shape=[3].
alpha = [1., 2, 3]
dist = tfd.Dirichlet(alpha)

dist.sample([4, 5])  # shape: [4, 5, 3]

# x has one sample, one batch, three classes:
x = [.2, .3, .5]   # shape: [3]
dist.prob(x)       # shape: []

# x has two samples from one batch:
x = [[.1, .4, .5],
     [.2, .3, .5]]
dist.prob(x)         # shape: [2]

# alpha will be broadcast to shape [5, 7, 3] to match x.
x = [[...]]   # shape: [5, 7, 3]
dist.prob(x)  # shape: [5, 7]

# Create batch_shape=[2], event_shape=[3]:
alpha = [[1., 2, 3],
         [4, 5, 6]]   # shape: [2, 3]
dist = tfd.Dirichlet(alpha)

dist.sample([4, 5])  # shape: [4, 5, 2, 3]

x = [.2, .3, .5]
# x will be broadcast as [[.2, .3, .5],
#                         [.2, .3, .5]],
# thus matching batch_shape + event_shape = [2, 3].
dist.prob(x)         # shape: [2]
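The shape bookkeeping in these examples follows a simple rule:

- The last axis of `concentration` is the event shape, and the leading axes are the batch shape.
- `sample` prepends `sample_shape` to both.
- `prob` broadcasts the leading axes of `x` against the batch shape.

The sketch below reproduces the shapes in the comments above using plain lists. The helper functions are hypothetical and model only the shape arithmetic, not the library.

```python
# Hypothetical helpers modelling Dirichlet shape semantics; not a library API.
def broadcast_shapes(a, b):
    """NumPy-style broadcasting of two shapes, aligned from the right."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(db if da == 1 else da)
    return out[::-1]

def batch_and_event(concentration_shape):
    """batch_shape = all but the last axis; event_shape = the last axis."""
    return list(concentration_shape[:-1]), list(concentration_shape[-1:])

def sample_result_shape(sample_shape, concentration_shape):
    batch, event = batch_and_event(concentration_shape)
    return list(sample_shape) + batch + event

def prob_result_shape(x_shape, concentration_shape):
    batch, event = batch_and_event(concentration_shape)
    if list(x_shape[-1:]) != event:
        raise ValueError("last axis of x must match event_shape")
    return broadcast_shapes(list(x_shape[:-1]), batch)

print(sample_result_shape([4, 5], [3]))      # [4, 5, 3]
print(prob_result_shape([2, 3], [3]))        # [2]
print(prob_result_shape([5, 7, 3], [3]))     # [5, 7]
print(sample_result_shape([4, 5], [2, 3]))   # [4, 5, 2, 3]
print(prob_result_shape([3], [2, 3]))        # [2]
```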

Compute the gradients of samples w.r.t. the parameters:

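import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

tf.disable_eager_execution()  # tf.gradients requires graph mode

alpha = tf.constant([1.0, 2.0, 3.0])
dist = tfd.Dirichlet(alpha)
samples = dist.sample(5)  # Shape [5, 3]
loss = tf.reduce_mean(tf.square(samples))  # Arbitrary loss function
# Unbiased stochastic gradients of the loss function
grads = tf.gradients(loss, alpha)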

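Args:

concentration: Positive floating-point Tensor indicating mean number of class occurrences; aka "alpha". Implies self.dtype, self.batch_shape, and self.event_shape, i.e., if concentration.shape = [N1, N2, ..., Nm, k] then batch_shape = [N1, N2, ..., Nm] and event_shape = [k].
validate_args: Python bool, default False. When True, distribution parameters are checked for validity despite possibly degrading runtime performance. When False, invalid inputs may silently render incorrect outputs.
allow_nan_stats: Python bool, default True. When True, statistics (e.g., mean, mode, variance) use the value NaN to indicate the result is undefined. When False, an exception is raised if one or more of the statistic's batch members are undefined.
name: Python str name prefixed to Ops created by this class.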

Methods

batch_shape_tensor


batch_shape_tensor(
    name='batch_shape_tensor'
)

Shape of a single sample from a single event index as a 1-D Tensor.

The batch dimensions are indexes into independent, non-identical parameterizations of this distribution.

cdf

cdf(
    value, name='cdf'
)

Cumulative distribution function.

Given random variable X, the cumulative distribution function cdf is:

cdf(x) := P[X <= x]

copy

copy(
    **override_parameters_kwargs
)

Creates a deep copy of the distribution.

Note: the copy distribution may continue to depend on the original initialization arguments.

covariance

covariance(
    name='covariance'
)

Covariance.

Covariance is (possibly) defined only for non-scalar-event distributions.

For example, for a length-k, vector-valued distribution, it is calculated as,

Cov[i, j] = Covariance(X_i, X_j) = E[(X_i - E[X_i]) (X_j - E[X_j])]

where Cov is a (batch of) k x k matrix, 0 <= (i, j) < k, and E denotes expectation.

Alternatively, for non-vector, multivariate distributions (e.g., matrix-valued, Wishart), Covariance shall return a (batch of) matrix under some vectorization of the events, i.e.,

Cov[i, j] = Covariance(Vec(X)_i, Vec(X)_j) = [as above]

where Cov is a (batch of) k' x k' matrix, 0 <= (i, j) < k' = reduce_prod(event_shape), and Vec is some function mapping indices of this distribution's event dimensions to indices of a length-k' vector.
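For the Dirichlet specifically, the covariance has a known closed form. Write alpha_0 = sum_j alpha_j and mean m = alpha / alpha_0. Then Cov[i, j] = (m_i * delta_ij - m_i * m_j) / (alpha_0 + 1). The pure-Python sketch below evaluates this formula; it illustrates the math rather than the library's code path. The components always sum to 1, so every row of the matrix sums to zero.

```python
# Illustrative closed-form Dirichlet covariance; not the TensorFlow code path.
def dirichlet_covariance(alpha):
    """Cov[i][j] = (m_i * [i == j] - m_i * m_j) / (alpha_0 + 1)."""
    alpha_0 = sum(alpha)
    m = [a / alpha_0 for a in alpha]
    k = len(alpha)
    return [
        [((m[i] if i == j else 0.0) - m[i] * m[j]) / (alpha_0 + 1.0) for j in range(k)]
        for i in range(k)
    ]

cov = dirichlet_covariance([1.0, 2.0, 3.0])
for row in cov:
    print(row, "row sum:", sum(row))
```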

cross_entropy

cross_entropy(
    other, name='cross_entropy'
)

Computes the (Shannon) cross entropy.

Denote this distribution (self) by P and the other distribution by Q. Assuming P, Q are absolutely continuous with respect to one another and permit densities p(x) dr(x) and q(x) dr(x), (Shannon) cross entropy is defined as:

H[P, Q] = E_p[-log q(X)] = -int_F p(x) log q(x) dr(x)

where F denotes the support of the random variable X ~ P.
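With a finite support and counting measure as the reference measure r, the integral becomes a sum, which makes the definition easy to check by hand. The discrete distributions below are arbitrary illustrative values, not Dirichlet densities. The example shows only the definition and the fact that H[P, Q] >= H[P], with equality when Q = P.

```python
import math

# Illustrative discrete cross entropy (counting reference measure), in nats.
def cross_entropy(p, q):
    """H[P, Q] = -sum_x p(x) log q(x)."""
    return -sum(px * math.log(qx) for px, qx in zip(p, q) if px > 0)

p = [0.5, 0.25, 0.25]
q = [0.25, 0.25, 0.5]

print(cross_entropy(p, p))  # entropy H[P]
print(cross_entropy(p, q))  # larger than H[P] because Q != P
```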

entropy

entropy(
    name='entropy'
)

Shannon entropy in nats.
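For the Dirichlet, the entropy has the closed form H = log B(alpha) + (alpha_0 - k) psi(alpha_0) - sum_j (alpha_j - 1) psi(alpha_j). Here log B(alpha) is log Z from the density, and psi is the digamma function. The Python standard library has no digamma, so the sketch below uses a hypothetical helper built from the recurrence psi(x) = psi(x + 1) - 1/x plus the asymptotic series. It illustrates the formula and is not the library's implementation. Two checks apply:

- Dirichlet([1, 1, 1]) has constant density 2, so its entropy is -log 2.
- Beta(2, 1) (k = 2) has density 2x and entropy 1/2 - log 2.

```python
import math

# Hypothetical digamma helper (stdlib has none); illustration only.
def digamma(x):
    """psi(x) for x > 0: shift x upward, then use the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x   # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    f = 1.0 / (x * x)
    series = f * (1/12 - f * (1/120 - f * (1/252 - f * (1/240 - f / 132))))
    return result + math.log(x) - 0.5 / x - series

def dirichlet_entropy(alpha):
    """Shannon entropy in nats, from the closed form above."""
    k = len(alpha)
    alpha_0 = sum(alpha)
    log_b = sum(math.lgamma(a) for a in alpha) - math.lgamma(alpha_0)
    return (log_b
            + (alpha_0 - k) * digamma(alpha_0)
            - sum((a - 1.0) * digamma(a) for a in alpha))

print(dirichlet_entropy([1.0, 1.0, 1.0]))  # uniform on the 2-simplex
print(dirichlet_entropy([2.0, 1.0]))       # Beta(2, 1)
```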

event_shape_tensor


event_shape_tensor(
    name='event_shape_tensor'
)

Shape of a single sample from a single batch as a 1-D int32 Tensor.

is_scalar_batch

is_scalar_batch(
    name='is_scalar_batch'
)

Indicates that batch_shape == [].

is_scalar_event

is_scalar_event(
    name='is_scalar_event'
)

Indicates that event_shape == [].

kl_divergence

kl_divergence(
    other, name='kl_divergence'
)

Computes the Kullback-Leibler divergence.

Denote this distribution (self) by p and the other distribution by q. Assuming p, q are absolutely continuous with respect to reference measure r, the KL divergence is defined as:

KL[p, q] = E_p[log(p(X)/q(X))]
         = -int_F p(x) log q(x) dr(x) + int_F p(x) log p(x) dr(x)
         = H[p, q] - H[p]

where F denotes the support of the random variable X ~ p, H[., .] denotes (Shannon) cross entropy, and H[.] denotes (Shannon) entropy.
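Between two Dirichlet distributions with concentrations alpha and beta, the KL divergence has the closed form:

KL = log Gamma(alpha_0) - sum_j log Gamma(alpha_j) - log Gamma(beta_0) + sum_j log Gamma(beta_j) + sum_j (alpha_j - beta_j) (psi(alpha_j) - psi(alpha_0))

The sketch below evaluates it in pure Python. It uses a hypothetical digamma helper (recurrence plus asymptotic series) because the standard library has none. It illustrates the formula and is not the library's registered KL function. As checks, KL[p, p] is zero, KL is non-negative, and KL[Uniform(0, 1) || Beta(2, 1)] = 1 - log 2.

```python
import math

# Hypothetical digamma helper (stdlib has none); illustration only.
def digamma(x):
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x   # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    f = 1.0 / (x * x)
    series = f * (1/12 - f * (1/120 - f * (1/252 - f * (1/240 - f / 132))))
    return result + math.log(x) - 0.5 / x - series

def dirichlet_kl(alpha, beta):
    """KL[Dir(alpha) || Dir(beta)] in nats, from the closed form above."""
    if len(alpha) != len(beta):
        raise ValueError("alpha and beta must have the same length")
    alpha_0, beta_0 = sum(alpha), sum(beta)
    psi_alpha_0 = digamma(alpha_0)
    return (math.lgamma(alpha_0) - sum(math.lgamma(a) for a in alpha)
            - math.lgamma(beta_0) + sum(math.lgamma(b) for b in beta)
            + sum((a - b) * (digamma(a) - psi_alpha_0) for a, b in zip(alpha, beta)))

print(dirichlet_kl([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]))  # zero: identical distributions
print(dirichlet_kl([1.0, 1.0], [2.0, 1.0]))            # uniform vs. Beta(2, 1)
```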

log_cdf

log_cdf(
    value, name='log_cdf'
)

Log cumulative distribution function.

Given random variable X, the log cumulative distribution function is:

log_cdf(x) := Log[ P[X <= x] ]

Often, a numerical approximation can be used for log_cdf(x) that yields a more accurate answer than simply taking the logarithm of the cdf when x << -1.
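This accuracy note applies to Distribution in general, so the sketch below uses the standard normal as a stand-in, not the Dirichlet. Computing the cdf as 0.5 * (1 + erf(x / sqrt(2))) evaluates to exactly 0.0 at x = -10, so its logarithm is unusable. The algebraically equal 0.5 * erfc(-x / sqrt(2)) avoids the cancellation. The function names are illustrative.

```python
import math

# Standard-normal stand-in to illustrate the log_cdf accuracy note.
def normal_cdf_naive(x):
    # 1 + erf(x / sqrt 2) cancels catastrophically for very negative x.
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def normal_log_cdf(x):
    # erfc avoids the cancellation, so the tail probability stays representable.
    return math.log(0.5 * math.erfc(-x / math.sqrt(2.0)))

x = -10.0
print(normal_cdf_naive(x))   # exactly 0.0, so log() of it is unusable
print(normal_log_cdf(x))     # about -53.23
```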

log_prob

log_prob(
    value, name='log_prob'
)

Log probability density/mass function.

Additional documentation from Dirichlet:

Note: value must be a non-negative tensor with dtype self.dtype and lie in the (k - 1)-simplex, where k = self.event_shape[-1], i.e., tf.reduce_sum(value, -1) = 1. It must have a shape compatible with self.batch_shape + self.event_shape.

log_survival_function

log_survival_function(
    value, name='log_survival_function'
)

Log survival function.

Given random variable X, the log survival function is defined as:

log_survival_function(x) = Log[ P[X > x] ]
                         = Log[ 1 - P[X <= x] ]
                         = Log[ 1 - cdf(x) ]

Typically, different numerical approximations can be used for the log survival function, which are more accurate than 1 - cdf(x) when x >> 1.


Returns:

Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

mean


mean(
    name='mean'
)

Mean.
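For the Dirichlet, the mean is the normalized concentration: mean_j = alpha_j / alpha_0, with alpha_0 = sum_j alpha_j. alpha_0 is the total_concentration in the parameterization concentration = mean * total_concentration from the class description. Here is a minimal pure-Python sketch; the helper name is illustrative.

```python
# Illustrative sketch of the Dirichlet mean; not the library implementation.
def dirichlet_mean(alpha):
    """Return (mean, total_concentration) such that alpha = mean * total."""
    total_concentration = sum(alpha)
    mean = [a / total_concentration for a in alpha]
    return mean, total_concentration

mean, total = dirichlet_mean([1.0, 2.0, 3.0])
print(mean, total)                   # mean lies on the simplex; total is 6.0
print([m * total for m in mean])     # recovers the concentration
```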

mode


mode(
    name='mode'
)

Mode.

Additional documentation from Dirichlet:

Note: The mode is undefined when any concentration <= 1. If self.allow_nan_stats is True, NaN is used for undefined modes. If self.allow_nan_stats is False an exception is raised when one or more modes are undefined.
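When every concentration exceeds 1, the mode is mode_j = (alpha_j - 1) / (alpha_0 - k). The sketch below mirrors the allow_nan_stats behaviour described in the note, using a hypothetical flag of the same name. It returns NaN for an undefined mode, or raises an error when the flag is False. It models the documented semantics and is not the library code.

```python
import math

# Illustrative sketch of Dirichlet mode semantics; `allow_nan_stats` here is a
# hypothetical stand-in for the distribution attribute of the same name.
def dirichlet_mode(alpha, allow_nan_stats=True):
    k = len(alpha)
    alpha_0 = sum(alpha)
    if any(a <= 1.0 for a in alpha):
        if not allow_nan_stats:
            raise ValueError("mode undefined: some concentration <= 1")
        return [math.nan] * k
    return [(a - 1.0) / (alpha_0 - k) for a in alpha]

print(dirichlet_mode([2.0, 3.0, 4.0]))   # interior point of the simplex
print(dirichlet_mode([1.0, 2.0, 3.0]))   # undefined: all NaN
```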

param_shapes


@classmethod
param_shapes(
    sample_shape, name='DistributionParamShapes'
)

Shapes of parameters given the desired shape of a call to sample().

This is a class method that describes what key/value arguments are required to instantiate the given Distribution so that a particular shape is returned for that instance's call to sample().

Subclasses should override class method _param_shapes.


Returns:

dict of parameter name to Tensor shapes.

param_static_shapes


@classmethod
param_static_shapes(
    sample_shape
)

param_shapes with static (i.e. TensorShape) shapes.

This is a class method that describes what key/value arguments are required to instantiate the given Distribution so that a particular shape is returned for that instance's call to sample(). Assumes that the sample's shape is known statically.

Subclasses should override class method _param_shapes to return constant-valued tensors when constant values are fed.


Returns:

dict of parameter name to TensorShape.

prob

prob(
    value, name='prob'
)

Probability density/mass function.

Additional documentation from Dirichlet:

Note: value must be a non-negative tensor with dtype self.dtype and lie in the (k - 1)-simplex, where k = self.event_shape[-1], i.e., tf.reduce_sum(value, -1) = 1. It must have a shape compatible with self.batch_shape + self.event_shape.

quantile

quantile(
    value, name='quantile'
)

Quantile function. Aka "inverse cdf" or "percent point function".

Given random variable X and p in [0, 1], the quantile is:

quantile(p) := x such that P[X <= x] == p

sample

sample(
    sample_shape=(), seed=None, name='sample'
)

Generate samples of the specified shape.

Note that a call to sample() without arguments will generate a single sample.
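A common way to draw a Dirichlet sample is to draw independent Gamma(alpha_j, 1) variates and normalize them by their sum. The sketch below uses Python's `random.gammavariate`. It shows this standard construction and does not claim to reproduce TensorFlow's sampler or its seeding. The helper name and its `n` parameter are illustrative. In this construction, gamma draws with very small alpha_j can underflow to exactly 0, which relates to the zero-component warning in the class description.

```python
import random

# Illustrative gamma-normalization sampler; not TensorFlow's implementation.
def sample_dirichlet(alpha, n, seed=None):
    """Draw n samples; each is a length-k list on the simplex."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n):
        gammas = [rng.gammavariate(a, 1.0) for a in alpha]
        total = sum(gammas)
        samples.append([g / total for g in gammas])
    return samples

draws = sample_dirichlet([1.0, 2.0, 3.0], 5, seed=42)
for d in draws:
    print(d, sum(d))
```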

stddev

stddev(
    name='stddev'
)

Standard deviation.

Standard deviation is defined as,

stddev = E[(X - E[X])**2]**0.5

where X is the random variable associated with this distribution, E denotes expectation, and stddev.shape = batch_shape + event_shape.

survival_function

survival_function(
    value, name='survival_function'
)

Survival function.

Given random variable X, the survival function is defined as:

survival_function(x) = P[X > x]
                     = 1 - P[X <= x]
                     = 1 - cdf(x).


Returns:

Tensor of shape sample_shape(x) + self.batch_shape with values of type self.dtype.

variance


variance(
    name='variance'
)

Variance.

Variance is defined as,

Var = E[(X - E[X])**2]

where X is the random variable associated with this distribution, E denotes expectation, and Var.shape = batch_shape + event_shape.
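For the Dirichlet, Var[X_j] = m_j (1 - m_j) / (alpha_0 + 1) with m_j = alpha_j / alpha_0, and stddev is its square root. The sketch below compares this closed form with a seeded Monte Carlo estimate. The estimate draws independent Gamma(alpha_j, 1) variates and normalizes them by their sum. The sample size, seed, and helper names are arbitrary illustrative choices, and the estimate agrees only approximately.

```python
import math
import random

# Illustrative closed-form Dirichlet variance; not the library implementation.
def dirichlet_variance(alpha):
    alpha_0 = sum(alpha)
    return [(a / alpha_0) * (1.0 - a / alpha_0) / (alpha_0 + 1.0) for a in alpha]

# Seeded Monte Carlo estimate via gamma normalization (illustrative only).
def monte_carlo_variance(alpha, n=20000, seed=0):
    rng = random.Random(seed)
    k = len(alpha)
    sums = [0.0] * k
    sq_sums = [0.0] * k
    for _ in range(n):
        g = [rng.gammavariate(a, 1.0) for a in alpha]
        total = sum(g)
        for j in range(k):
            x = g[j] / total
            sums[j] += x
            sq_sums[j] += x * x
    return [sq_sums[j] / n - (sums[j] / n) ** 2 for j in range(k)]

alpha = [1.0, 2.0, 3.0]
exact = dirichlet_variance(alpha)
estimate = monte_carlo_variance(alpha)
stddev = [math.sqrt(v) for v in exact]
print(exact)
print(estimate)
print(stddev)
```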
