Probability distributions - torch.distributions

The distributions package contains parameterizable probability distributions and sampling functions. This allows the construction of stochastic computation graphs and stochastic gradient estimators for optimization. This package generally follows the design of the TensorFlow Distributions package.

It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through. These are the score function estimator/likelihood ratio estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly seen as the basis for policy gradient methods in reinforcement learning, and the pathwise derivative estimator is commonly seen in the reparameterization trick in variational autoencoders. Whilst the score function only requires the value of samples \(f(x)\), the pathwise derivative requires the derivative \(f'(x)\). The next sections discuss these two in a reinforcement learning example. For more details see Gradient Estimation Using Stochastic Computation Graphs.

Score function

When the probability density function is differentiable with respect to its parameters, we only need sample() and log_prob() to implement REINFORCE:

\[\Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta}\]

where \(\theta\) are the parameters, \(\alpha\) is the learning rate, \(r\) is the reward and \(p(a|\pi^\theta(s))\) is the probability of taking action \(a\) in state \(s\) given policy \(\pi^\theta\).

In practice we would sample an action from the output of a network, apply this action in an environment, and then use log_prob to construct an equivalent loss function. Note that we use a negative because optimizers use gradient descent, whilst the rule above assumes gradient ascent. With a categorical policy, the code for implementing REINFORCE would be as follows:

from torch.distributions import Categorical

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

Pathwise derivative

The other way to implement these stochastic/policy gradients would be to use the reparameterization trick from the rsample() method, where the parameterized random variable can be constructed via a parameterized deterministic function of a parameter-free random variable. The reparameterized sample therefore becomes differentiable. The code for implementing the pathwise derivative would be as follows:

from torch.distributions import Normal

params = policy_network(state)
m = Normal(*params)
# Any distribution with .has_rsample == True could work based on the application
action = m.rsample()
next_state, reward = env.step(action)  # Assuming that reward is differentiable
loss = -reward
loss.backward()

Distribution

class torch.distributions.distribution.Distribution(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]

Bases: object

Distribution is the abstract base class for probability distributions.

arg_constraints

Returns a dictionary from argument names to Constraint objects that should be satisfied by each argument of this distribution. Args that are not tensors need not appear in this dict.

batch_shape

Returns the shape over which parameters are batched.

cdf(value)[source]

Returns the cumulative density/mass function evaluated at value.

Parameters:value (Tensor) –
entropy()[source]

Returns entropy of distribution, batched over batch_shape.

Returns:Tensor of shape batch_shape.
enumerate_support(expand=True)[source]

Returns tensor containing all values supported by a discrete distribution. The result will enumerate over dimension 0, so the shape of the result will be (cardinality,) + batch_shape + event_shape (where event_shape = () for univariate distributions).

Note that this enumerates over all batched tensors in lock-step [[0, 0], [1, 1], …]. With expand=False, enumeration happens along dim 0, but with the remaining batch dimensions being singleton dimensions, [[0], [1], …].

To iterate over the full Cartesian product use itertools.product(m.enumerate_support()).

Parameters:expand (bool) – whether to expand the support over the batch dims to match the distribution’s batch_shape.
Returns:Tensor iterating over dimension 0.
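
A minimal sketch of the lock-step behaviour described above, assuming a Bernoulli distribution with batch_shape (2,):

import torch
from torch.distributions import Bernoulli

d = Bernoulli(torch.tensor([0.1, 0.9]))   # batch_shape == (2,)
d.enumerate_support()                     # shape (2, 2): [[0., 0.], [1., 1.]]
d.enumerate_support(expand=False)         # shape (2, 1): [[0.], [1.]]
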
event_shape

Returns the shape of a single sample (without batching).

expand(batch_shape, _instance=None)[source]

Returns a new distribution instance (or populates an existing instance provided by a derived class) with batch dimensions expanded to batch_shape. This method calls expand on the distribution’s parameters. As such, this does not allocate new memory for the expanded distribution instance. Additionally, this does not repeat any args checking or parameter broadcasting in __init__.py, when an instance is first created.

Parameters:
  • batch_shape (torch.Size) – the desired expanded size.
  • _instance – new instance provided by subclasses that need to override .expand.
Returns:

New distribution instance with batch dimensions expanded to batch_shape.
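
A minimal sketch, assuming a Normal with scalar parameters:

import torch
from torch.distributions import Normal

d = Normal(torch.tensor(0.0), torch.tensor(1.0))   # batch_shape == ()
d2 = d.expand(torch.Size([3, 2]))                  # batch_shape == (3, 2)
d2.sample().shape                                  # torch.Size([3, 2])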

icdf(value)[source]

Returns the inverse cumulative density/mass function evaluated at value.

Parameters:value (Tensor) –
log_prob(value)[source]

Returns the log of the probability density/mass function evaluated at value.

Parameters:value (Tensor) –
mean

Returns the mean of the distribution.

perplexity()[source]

Returns perplexity of distribution, batched over batch_shape.

Returns:Tensor of shape batch_shape.
rsample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched.
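
A minimal sketch of why this matters for gradient-based optimization, assuming a Normal with a learnable scale; the reparameterized sample keeps the graph to scale intact, whereas sample() would not:

import torch
from torch.distributions import Normal

scale = torch.ones(1, requires_grad=True)
d = Normal(torch.zeros(1), scale)
d.rsample().sum().backward()   # populates scale.grad through the reparameterization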

sample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched.

sample_n(n)[source]

Generates n samples or n batches of samples if the distribution parameters are batched.

stddev

Returns the standard deviation of the distribution.

support

Returns a Constraint object representing this distribution’s support.

variance

Returns the variance of the distribution.

ExponentialFamily

class torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([]), event_shape=torch.Size([]), validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

ExponentialFamily is the abstract base class for probability distributions belonging to an exponential family, whose probability mass/density function has the form defined below

\[p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x))\]

where \(\theta\) denotes the natural parameters, \(t(x)\) denotes the sufficient statistic, \(F(\theta)\) is the log normalizer function for a given family and \(k(x)\) is the carrier measure.

Note

This class is an intermediary between the Distribution class and distributions which belong to an exponential family mainly to check the correctness of the .entropy() and analytic KL divergence methods. We use this class to compute the entropy and KL divergence using the AD framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and Cross-entropies of Exponential Families).

entropy()[source]

Method to compute the entropy using Bregman divergence of the log normalizer.

Bernoulli

class torch.distributions.bernoulli.Bernoulli(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates a Bernoulli distribution parameterized by probs or logits (but not both).

Samples are binary (0 or 1). They take the value 1 with probability p and 0 with probability 1 - p.

Example:

>>> m = Bernoulli(torch.tensor([0.3]))
>>> m.sample()  # 30% chance 1; 70% chance 0
tensor([ 0.])
Parameters:
  • probs (Number, Tensor) – the probability of sampling 1
  • logits (Number, Tensor) – the log-odds of sampling 1
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
logits[source]
mean
param_shape
probs[source]
sample(sample_shape=torch.Size([]))[source]
support = Boolean()
variance

Beta

class torch.distributions.beta.Beta(concentration1, concentration0, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Beta distribution parameterized by concentration1 and concentration0.

Example:

>>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5]))
>>> m.sample()  # Beta distributed with concentration concentration1 and concentration0
tensor([ 0.1046])
Parameters:
  • concentration1 (float or Tensor) – 1st concentration parameter of the distribution (often referred to as alpha)
  • concentration0 (float or Tensor) – 2nd concentration parameter of the distribution (often referred to as beta)
arg_constraints = {'concentration0': GreaterThan(lower_bound=0.0), 'concentration1': GreaterThan(lower_bound=0.0)}
concentration0
concentration1
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
rsample(sample_shape=())[source]
support = Interval(lower_bound=0.0, upper_bound=1.0)
variance

Binomial

class torch.distributions.binomial.Binomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Binomial distribution parameterized by total_count and either probs or logits (but not both). total_count must be broadcastable with probs/logits.

Example:

>>> m = Binomial(100, torch.tensor([0 , .2, .8, 1]))
>>> x = m.sample()
tensor([   0.,   22.,   71.,  100.])

>>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8]))
>>> x = m.sample()
tensor([[ 4.,  5.],
        [ 7.,  6.]])
Parameters:
  • total_count (int or Tensor) – number of Bernoulli trials
  • probs (Tensor) – Event probabilities
  • logits (Tensor) – Event log-odds
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0), 'total_count': IntegerGreaterThan(lower_bound=0)}
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
logits[source]
mean
param_shape
probs[source]
sample(sample_shape=torch.Size([]))[source]
support
variance

Categorical

class torch.distributions.categorical.Categorical(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a categorical distribution parameterized by either probs or logits (but not both).

Note

It is equivalent to the distribution that torch.multinomial() samples from.

Samples are integers from \(\{0, \ldots, K-1\}\) where K is probs.size(-1).

If probs is 1D with length-K, each element is the relative probability of sampling the class at that index.

If probs is 2D, it is treated as a batch of relative probability vectors.

Note

probs must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.

See also: torch.multinomial()

Example:

>>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor(3)
Parameters:
  • probs (Tensor) – event probabilities
  • logits (Tensor) – event log probabilities
arg_constraints = {'logits': Real(), 'probs': Simplex()}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
logits[source]
mean
param_shape
probs[source]
sample(sample_shape=torch.Size([]))[source]
support
variance

Cauchy

class torch.distributions.cauchy.Cauchy(loc, scale, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of independent normally distributed random variables with means 0 follows a Cauchy distribution.

Example:

>>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Cauchy distribution with loc=0 and scale=1
tensor([ 2.3214])
Parameters:
  • loc (float or Tensor) – mode or median of the distribution.
  • scale (float or Tensor) – half width at half maximum.
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
support = Real()
variance

Chi2

class torch.distributions.chi2.Chi2(df, validate_args=None)[source]

Bases: torch.distributions.gamma.Gamma

Creates a Chi2 distribution parameterized by shape parameter df. This is exactly equivalent to Gamma(alpha=0.5*df, beta=0.5)

Example:

>>> m = Chi2(torch.tensor([1.0]))
>>> m.sample()  # Chi2 distributed with shape df=1
tensor([ 0.1046])
Parameters:df (float or Tensor) – shape parameter of the distribution
arg_constraints = {'df': GreaterThan(lower_bound=0.0)}
df
expand(batch_shape, _instance=None)[source]

Dirichlet

class torch.distributions.dirichlet.Dirichlet(concentration, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates a Dirichlet distribution parameterized by concentration concentration.

Example:

>>> m = Dirichlet(torch.tensor([0.5, 0.5]))
>>> m.sample()  # Dirichlet distributed with concentration concentration
tensor([ 0.1046,  0.8954])
Parameters:concentration (Tensor) – concentration parameter of the distribution (often referred to as alpha)
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
rsample(sample_shape=())[source]
support = Simplex()
variance

Exponential

class torch.distributions.exponential.Exponential(rate, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates an Exponential distribution parameterized by rate.

Example:

>>> m = Exponential(torch.tensor([1.0]))
>>> m.sample()  # Exponential distributed with rate=1
tensor([ 0.1046])
Parameters:rate (float or Tensor) – rate = 1 / scale of the distribution
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
stddev
support = GreaterThan(lower_bound=0.0)
variance

FisherSnedecor

class torch.distributions.fishersnedecor.FisherSnedecor(df1, df2, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Fisher-Snedecor distribution parameterized by df1 and df2.

Example:

>>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
tensor([ 0.2453])
Parameters:
  • df1 (float or Tensor) – degrees of freedom parameter 1
  • df2 (float or Tensor) – degrees of freedom parameter 2
arg_constraints = {'df1': GreaterThan(lower_bound=0.0), 'df2': GreaterThan(lower_bound=0.0)}
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
support = GreaterThan(lower_bound=0.0)
variance

Gamma

class torch.distributions.gamma.Gamma(concentration, rate, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates a Gamma distribution parameterized by shape concentration and rate.

Example:

>>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # Gamma distributed with concentration=1 and rate=1
tensor([ 0.1046])
Parameters:
  • concentration (float or Tensor) – shape parameter of the distribution (often referred to as alpha)
  • rate (float or Tensor) – rate = 1 / scale of the distribution (often referred to as beta)
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'rate': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
support = GreaterThan(lower_bound=0.0)
variance

Geometric

class torch.distributions.geometric.Geometric(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Geometric distribution parameterized by probs, where probs is the probability of success of Bernoulli trials. It represents the probability that in \(k + 1\) Bernoulli trials, the first \(k\) trials fail before a success is seen.

Samples are non-negative integers \([0, \infty)\).

Example:

>>> m = Geometric(torch.tensor([0.3]))
>>> m.sample()  # underlying Bernoulli has 30% chance 1; 70% chance 0
tensor([ 2.])
Parameters:
  • probs (Number, Tensor) – the probability of sampling 1. Must be in range (0, 1]
  • logits (Number, Tensor) – the log-odds of sampling 1.
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
logits[source]
mean
probs[source]
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
variance

Gumbel

class torch.distributions.gumbel.Gumbel(loc, scale, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Samples from a Gumbel Distribution.

Examples:

>>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
>>> m.sample()  # sample from Gumbel distribution with loc=1, scale=2
tensor([ 1.0124])
Parameters:
  • loc (float or Tensor) – Location parameter of the distribution
  • scale (float or Tensor) – Scale parameter of the distribution
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
mean
stddev
support = Real()
variance

HalfCauchy

class torch.distributions.half_cauchy.HalfCauchy(scale, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Creates a half-Cauchy distribution parameterized by scale where:

X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)

Example:

>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample()  # half-cauchy distributed with scale=1
tensor([ 2.3214])
Parameters:scale (float or Tensor) – scale of the full Cauchy distribution
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(prob)[source]
log_prob(value)[source]
mean
scale
support = GreaterThan(lower_bound=0.0)
variance

HalfNormal

class torch.distributions.half_normal.HalfNormal(scale, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Creates a half-normal distribution parameterized by scale where:

X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)

Example:

>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample()  # half-normal distributed with scale=1
tensor([ 0.1046])
Parameters:scale (float or Tensor) – scale of the full Normal distribution
arg_constraints = {'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(prob)[source]
log_prob(value)[source]
mean
scale
support = GreaterThan(lower_bound=0.0)
variance

Independent

class torch.distributions.independent.Independent(base_distribution, reinterpreted_batch_ndims, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Reinterprets some of the batch dims of a distribution as event dims.

This is mainly useful for changing the shape of the result of log_prob(). For example to create a diagonal Normal distribution with the same shape as a Multivariate Normal distribution (so they are interchangeable), you can:

>>> loc = torch.zeros(3)
>>> scale = torch.ones(3)
>>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
>>> [mvn.batch_shape, mvn.event_shape]
[torch.Size(()), torch.Size((3,))]
>>> normal = Normal(loc, scale)
>>> [normal.batch_shape, normal.event_shape]
[torch.Size((3,)), torch.Size(())]
>>> diagn = Independent(normal, 1)
>>> [diagn.batch_shape, diagn.event_shape]
[torch.Size(()), torch.Size((3,))]
Parameters:
  • base_distribution (torch.distributions.Distribution) – a base distribution
  • reinterpreted_batch_ndims (int) – the number of batch dims to reinterpret as event dims
arg_constraints = {}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support
has_rsample
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
support
variance

Laplace

class torch.distributions.laplace.Laplace(loc, scale, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Laplace distribution parameterized by loc and scale.

Example:

>>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # Laplace distributed with loc=0, scale=1
tensor([ 0.1046])
Parameters:
  • loc (float or Tensor) – mean of the distribution
  • scale (float or Tensor) – scale of the distribution
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
stddev
support = Real()
variance

LogNormal

class torch.distributions.log_normal.LogNormal(loc, scale, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Creates a log-normal distribution parameterized by loc and scale where:

X ~ Normal(loc, scale)
Y = exp(X) ~ LogNormal(loc, scale)

Example:

>>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # log-normal distributed with mean=0 and stddev=1
tensor([ 0.1046])
Parameters:
  • loc (float or Tensor) – mean of log of distribution
  • scale (float or Tensor) – standard deviation of log of the distribution
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
loc
mean
scale
support = GreaterThan(lower_bound=0.0)
variance

LowRankMultivariateNormal

class torch.distributions.lowrank_multivariate_normal.LowRankMultivariateNormal(loc, cov_factor, cov_diag, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a multivariate normal distribution with covariance matrix having a low-rank form parameterized by cov_factor and cov_diag:

covariance_matrix = cov_factor @ cov_factor.T + cov_diag

Example

>>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.tensor([1., 1.]))
>>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
tensor([-0.2102, -0.5429])
Parameters:
  • loc (Tensor) – mean of the distribution with shape batch_shape + event_shape
  • cov_factor (Tensor) – factor part of low-rank form of covariance matrix with shape batch_shape + event_shape + (rank,)
  • cov_diag (Tensor) – diagonal part of low-rank form of covariance matrix with shape batch_shape + event_shape

Note

The computation for determinant and inverse of covariance matrix is avoided when cov_factor.shape[1] << cov_factor.shape[0] thanks to Woodbury matrix identity and matrix determinant lemma. Thanks to these formulas, we just need to compute the determinant and inverse of the small size “capacitance” matrix:

capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
arg_constraints = {'cov_diag': GreaterThan(lower_bound=0.0), 'cov_factor': Real(), 'loc': Real()}
covariance_matrix[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
precision_matrix[source]
rsample(sample_shape=torch.Size([]))[source]
scale_tril[source]
support = Real()
variance[source]

Multinomial

class torch.distributions.multinomial.Multinomial(total_count=1, probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Multinomial distribution parameterized by total_count and either probs or logits (but not both). The innermost dimension of probs indexes over categories. All other dimensions index over batches.

Note that total_count need not be specified if only log_prob() is called (see example below)

Note

probs must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.

  • sample() requires a single shared total_count for all parameters and samples.
  • log_prob() allows different total_count for each parameter and sample.

Example:

>>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
>>> x = m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 21.,  24.,  30.,  25.])

>>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
tensor([-4.1338])
Parameters:
  • total_count (int) – number of trials
  • probs (Tensor) – event probabilities
  • logits (Tensor) – event log probabilities
arg_constraints = {'logits': Real(), 'probs': Simplex()}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
logits
mean
param_shape
probs
sample(sample_shape=torch.Size([]))[source]
support
variance

MultivariateNormal

class torch.distributions.multivariate_normal.MultivariateNormal(loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a multivariate normal (also called Gaussian) distribution parameterized by a mean vector and a covariance matrix.

The multivariate normal distribution can be parameterized either in terms of a positive definite covariance matrix \(\mathbf{\Sigma}\) or a positive definite precision matrix \(\mathbf{\Sigma}^{-1}\) or a lower-triangular matrix \(\mathbf{L}\) with positive-valued diagonal entries, such that \(\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top\). This triangular matrix can be obtained via e.g. Cholesky decomposition of the covariance.

Example

>>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
>>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
tensor([-0.2102, -0.5429])
Parameters:
  • loc (Tensor) – mean of the distribution
  • covariance_matrix (Tensor) – positive-definite covariance matrix
  • precision_matrix (Tensor) – positive-definite precision matrix
  • scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal

Note

Only one of covariance_matrix or precision_matrix or scale_tril can be specified.

Using scale_tril will be more efficient: all computations internally are based on scale_tril. If covariance_matrix or precision_matrix is passed instead, it is only used to compute the corresponding lower triangular matrices using a Cholesky decomposition.
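
A minimal sketch of the more efficient path, assuming a covariance matrix cov is already available; passing its Cholesky factor as scale_tril skips the internal decomposition:

import torch
from torch.distributions import MultivariateNormal

cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
m = MultivariateNormal(torch.zeros(2), scale_tril=torch.cholesky(cov))  # lower-triangular factor
m.rsample()   # sample of shape (2,)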

arg_constraints = {'covariance_matrix': PositiveDefinite(), 'loc': RealVector(), 'precision_matrix': PositiveDefinite(), 'scale_tril': LowerCholesky()}
covariance_matrix[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
precision_matrix[source]
rsample(sample_shape=torch.Size([]))[source]
scale_tril[source]
support = Real()
variance

NegativeBinomial

class torch.distributions.negative_binomial.NegativeBinomial(total_count, probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Negative Binomial distribution, i.e. the distribution of the number of successful independent and identical Bernoulli trials before total_count failures are achieved. The probability of success of each Bernoulli trial is probs.

Parameters:
  • total_count (float or Tensor) – non-negative number of negative Bernoulli trials to stop, although the distribution is still valid for real valued count
  • probs (Tensor) – Event probabilities of success in the half open interval [0, 1)
  • logits (Tensor) – Event log-odds for probabilities of success
arg_constraints = {'logits': Real(), 'probs': HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), 'total_count': GreaterThanEq(lower_bound=0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
logits[source]
mean
param_shape
probs[source]
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
variance

Normal

class torch.distributions.normal.Normal(loc, scale, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates a normal (also called Gaussian) distribution parameterized by loc and scale.

Example:

>>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0]))
>>> m.sample()  # normally distributed with loc=0 and scale=1
tensor([ 0.1046])
Parameters:
  • loc (float or Tensor) – mean of the distribution (often referred to as mu)
  • scale (float or Tensor) – standard deviation of the distribution (often referred to as sigma)
arg_constraints = {'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
sample(sample_shape=torch.Size([]))[source]
stddev
support = Real()
variance

OneHotCategorical

class torch.distributions.one_hot_categorical.OneHotCategorical(probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a one-hot categorical distribution parameterized by probs or logits.

Samples are one-hot coded vectors of size probs.size(-1).

Note

probs must be non-negative, finite and have a non-zero sum, and it will be normalized to sum to 1.

See also: torch.distributions.Categorical() for specifications of probs and logits.

Example:

>>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
>>> m.sample()  # equal probability of 0, 1, 2, 3
tensor([ 0.,  0.,  0.,  1.])
Parameters:
  • probs (Tensor) – event probabilities
  • logits (Tensor) – event log probabilities
arg_constraints = {'logits': Real(), 'probs': Simplex()}
entropy()[source]
enumerate_support(expand=True)[source]
expand(batch_shape, _instance=None)[source]
has_enumerate_support = True
log_prob(value)[source]
logits
mean
param_shape
probs
sample(sample_shape=torch.Size([]))[source]
support = Simplex()
variance

Pareto

class torch.distributions.pareto.Pareto(scale, alpha, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Samples from a Pareto Type 1 distribution.

Example:

>>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Pareto distribution with scale=1 and alpha=1
tensor([ 1.5623])
Parameters:
  • scale (float or Tensor) – Scale parameter of the distribution
  • alpha (float or Tensor) – Shape parameter of the distribution
arg_constraints = {'alpha': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
mean
support
variance

Poisson

class torch.distributions.poisson.Poisson(rate, validate_args=None)[source]

Bases: torch.distributions.exp_family.ExponentialFamily

Creates a Poisson distribution parameterized by rate, the rate parameter.

Samples are nonnegative integers, with a pmf given by

\[\mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!} \]

Example:

>>> m = Poisson(torch.tensor([4]))
>>> m.sample()
tensor([ 3.])
Parameters:rate (Number, Tensor) – the rate parameter
arg_constraints = {'rate': GreaterThan(lower_bound=0.0)}
expand(batch_shape, _instance=None)[source]
log_prob(value)[source]
mean
sample(sample_shape=torch.Size([]))[source]
support = IntegerGreaterThan(lower_bound=0)
variance

RelaxedBernoulli

class torch.distributions.relaxed_bernoulli.RelaxedBernoulli(temperature, probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Creates a RelaxedBernoulli distribution, parametrized by temperature, and either probs or logits (but not both). This is a relaxed version of the Bernoulli distribution, so the values are in (0, 1), and it has reparametrizable samples.

Example:

>>> m = RelaxedBernoulli(torch.tensor([2.2]),
                         torch.tensor([0.1, 0.2, 0.3, 0.99]))
>>> m.sample()
tensor([ 0.2951,  0.3442,  0.8918,  0.9021])
Parameters:
  • temperature (Tensor) – relaxation temperature
  • probs (Number, Tensor) – the probability of sampling 1
  • logits (Number, Tensor) – the log-odds of sampling 1
arg_constraints = {'logits': Real(), 'probs': Interval(lower_bound=0.0, upper_bound=1.0)}
expand(batch_shape, _instance=None)[source]
has_rsample = True
logits
probs
support = Interval(lower_bound=0.0, upper_bound=1.0)
temperature

RelaxedOneHotCategorical

class torch.distributions.relaxed_categorical.RelaxedOneHotCategorical(temperature, probs=None, logits=None, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Creates a RelaxedOneHotCategorical distribution parametrized by temperature, and either probs or logits. This is a relaxed version of the OneHotCategorical distribution, so its samples are on the simplex and are reparametrizable.

Example:

>>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
                                 torch.tensor([0.1, 0.2, 0.3, 0.4]))
>>> m.sample()
tensor([ 0.1294,  0.2324,  0.3859,  0.2523])
Parameters:
  • temperature (Tensor) – relaxation temperature
  • probs (Tensor) – event probabilities
  • logits (Tensor) – the log probability of each event.
arg_constraints = {'logits': Real(), 'probs': Simplex()}
expand(batch_shape, _instance=None)[source]
has_rsample = True
logits
probs
support = Simplex()
temperature

StudentT

class torch.distributions.studentT.StudentT(df, loc=0.0, scale=1.0, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Creates a Student’s t-distribution parameterized by degrees of freedom df, mean loc and scale scale.

Example:

>>> m = StudentT(torch.tensor([2.0]))
>>> m.sample()  # Student's t-distributed with degrees of freedom=2
tensor([ 0.1046])
Parameters:
  • df (float or Tensor) – degrees of freedom
  • loc (float or Tensor) – mean of the distribution
  • scale (float or Tensor) – scale of the distribution
arg_constraints = {'df': GreaterThan(lower_bound=0.0), 'loc': Real(), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
support = Real()
variance

TransformedDistribution

class torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Extension of the Distribution class, which applies a sequence of Transforms to a base distribution. Let f be the composition of transforms applied:

X ~ BaseDistribution
Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
log p(Y) = log p(X) + log |det (dX/dY)|

Note that the .event_shape of a TransformedDistribution is the maximum shape of its base distribution and its transforms, since transforms can introduce correlations among events.

An example for the usage of TransformedDistribution would be:

# Building a Logistic Distribution
# X ~ Uniform(0, 1)
# f = a + b * logit(X)
# Y ~ f(X) ~ Logistic(a, b)
a, b = 0.0, 1.0  # location and scale of the resulting Logistic distribution
base_distribution = Uniform(0, 1)
transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
logistic = TransformedDistribution(base_distribution, transforms)

For more examples, please look at the implementations of Gumbel, HalfCauchy, HalfNormal, LogNormal, Pareto, Weibull, RelaxedBernoulli and RelaxedOneHotCategorical

arg_constraints = {}
cdf(value)[source]

Computes the cumulative distribution function by inverting the transform(s) and computing the score of the base distribution.

expand(batch_shape, _instance=None)[source]
has_rsample
icdf(value)[source]

Computes the inverse cumulative distribution function using transform(s) and computing the score of the base distribution.

log_prob(value)[source]

Scores the sample by inverting the transform(s) and computing the score using the score of the base distribution and the log abs det jacobian.

rsample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped reparameterized sample or sample_shape shaped batch of reparameterized samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.

sample(sample_shape=torch.Size([]))[source]

Generates a sample_shape shaped sample or sample_shape shaped batch of samples if the distribution parameters are batched. Samples first from base distribution and applies transform() for every transform in the list.

support

Uniform

class torch.distributions.uniform.Uniform(low, high, validate_args=None)[source]

Bases: torch.distributions.distribution.Distribution

Generates uniformly distributed random samples from the half-open interval [low, high).

Example:

>>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0]))
>>> m.sample()  # uniformly distributed in the range [0.0, 5.0)
tensor([ 2.3418])
Parameters:
  • low (float or Tensor) – lower range (inclusive).
  • high (float or Tensor) – upper range (exclusive).
arg_constraints = {'high': Dependent(), 'low': Dependent()}
cdf(value)[source]
entropy()[source]
expand(batch_shape, _instance=None)[source]
has_rsample = True
icdf(value)[source]
log_prob(value)[source]
mean
rsample(sample_shape=torch.Size([]))[source]
stddev
support
variance

Weibull

class torch.distributions.weibull.Weibull(scale, concentration, validate_args=None)[source]

Bases: torch.distributions.transformed_distribution.TransformedDistribution

Samples from a two-parameter Weibull distribution.

Example

>>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0]))
>>> m.sample()  # sample from a Weibull distribution with scale=1, concentration=1
tensor([ 0.4784])
Parameters:
  • scale (float or Tensor) – Scale parameter of distribution (lambda).
  • concentration (float or Tensor) – Concentration parameter of distribution (k/shape).
arg_constraints = {'concentration': GreaterThan(lower_bound=0.0), 'scale': GreaterThan(lower_bound=0.0)}
entropy()[source]
expand(batch_shape, _instance=None)[source]
mean
support = GreaterThan(lower_bound=0.0)
variance

KL Divergence

torch.distributions.kl.kl_divergence(p, q)[source]

Compute Kullback-Leibler divergence \(KL(p \| q)\) between two distributions.

\[KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx\]
Parameters:
  • p (Distribution) – A Distribution object.
  • q (Distribution) – A Distribution object.
Returns:

A batch of KL divergences of shape batch_shape.

Return type:

Tensor

Raises:

NotImplementedError – If the distribution types have not been registered via register_kl().
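
A minimal usage sketch, assuming two univariate Normal distributions (this pair is registered by default):

import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.zeros(1), torch.ones(1))
q = Normal(torch.ones(1), 2 * torch.ones(1))
kl_divergence(p, q)   # tensor of shape batch_shape, here (1,)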

torch.distributions.kl.register_kl(type_p, type_q)[source]

Decorator to register a pairwise function with kl_divergence(). Usage:

@register_kl(Normal, Normal)
def kl_normal_normal(p, q):
    # insert implementation here

Lookup returns the most specific (type,type) match ordered by subclass. If the match is ambiguous, a RuntimeWarning is raised. For example to resolve the ambiguous situation:

@register_kl(BaseP, DerivedQ)
def kl_version1(p, q): ...
@register_kl(DerivedP, BaseQ)
def kl_version2(p, q): ...

you should register a third most-specific implementation, e.g.:

register_kl(DerivedP, DerivedQ)(kl_version1)  # Break the tie.
Parameters:
  • type_p (type) – A subclass of Distribution.
  • type_q (type) – A subclass of Distribution.

Transforms

class torch.distributions.transforms.Transform(cache_size=0)[source]

Abstract class for invertible transformations with computable log det jacobians. They are primarily used in torch.distributions.TransformedDistribution.

Caching is useful for transforms whose inverses are either expensive or numerically unstable. Note that care must be taken with memoized values since the autograd graph may be reversed. For example, while the following works with or without caching:

y = t(x)
t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

However the following will error when caching due to dependency reversal:

y = t(x)
z = t.inv(y)
grad(z.sum(), [y])  # error because z is x

Derived classes should implement one or both of _call() or _inverse(). Derived classes that set bijective=True should also implement log_abs_det_jacobian().

Parameters:

cache_size (int) – Size of cache. If zero, no caching is done. If one, the latest single value is cached. Only 0 and 1 are supported.

Variables:
  • domain (Constraint) – The constraint representing valid inputs to this transform.
  • codomain (Constraint) – The constraint representing valid outputs to this transform which are inputs to the inverse transform.
  • bijective (bool) – Whether this transform is bijective. A transform t is bijective iff t.inv(t(x)) == x and t(t.inv(y)) == y for every x in the domain and y in the codomain. Transforms that are not bijective should at least maintain the weaker pseudoinverse properties t(t.inv(t(x))) == t(x) and t.inv(t(t.inv(y))) == t.inv(y).
  • sign (int or Tensor) – For bijective univariate transforms, this should be +1 or -1 depending on whether transform is monotone increasing or decreasing.
  • event_dim (int) – Number of dimensions that are correlated together in the transform event_shape. This should be 0 for pointwise transforms, 1 for transforms that act jointly on vectors, 2 for transforms that act jointly on matrices, etc.
inv

Returns the inverse Transform of this transform. This should satisfy t.inv.inv is t.

sign

Returns the sign of the determinant of the Jacobian, if applicable. In general this only makes sense for bijective transforms.

log_abs_det_jacobian(x, y)[source]

Computes the log det jacobian log |dy/dx| given input and output.
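
A minimal sketch, assuming the ExpTransform defined below; for \(y = \exp(x)\) the log abs det jacobian is simply \(x\):

import torch
from torch.distributions.transforms import ExpTransform

t = ExpTransform()
x = torch.tensor([0.0, 1.0])
y = t(x)                       # exp(x)
t.log_abs_det_jacobian(x, y)   # equals x, since |dy/dx| = exp(x)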

class torch.distributions.transforms.ComposeTransform(parts)[source]

Composes multiple transforms in a chain. The transforms being composed are responsible for caching.

Parameters:parts (list of Transform) – A list of transforms to compose.
class torch.distributions.transforms.ExpTransform(cache_size=0)[source]

Transform via the mapping \(y = \exp(x)\).

class torch.distributions.transforms.PowerTransform(exponent, cache_size=0)[source]

Transform via the mapping \(y = x^{\text{exponent}}\).

class torch.distributions.transforms.SigmoidTransform(cache_size=0)[source]

Transform via the mapping \(y = \frac{1}{1 + \exp(-x)}\) and \(x = \text{logit}(y)\).

class torch.distributions.transforms.AbsTransform(cache_size=0)[source]

Transform via the mapping \(y = |x|\).

class torch.distributions.transforms.AffineTransform(loc, scale, event_dim=0, cache_size=0)[source]

Transform via the pointwise affine mapping \(y = \text{loc} + \text{scale} \times x\).

Parameters:
  • loc (Tensor or float) – Location parameter.
  • scale (Tensor or float) – Scale parameter.
  • event_dim (int) – Optional size of event_shape. This should be zero for univariate random variables, 1 for distributions over vectors, 2 for distributions over matrices, etc.
class torch.distributions.transforms.SoftmaxTransform(cache_size=0)[source]

Transform from unconstrained space to the simplex via \(y = \exp(x)\) then normalizing.

This is not bijective and cannot be used for HMC. However this acts mostly coordinate-wise (except for the final normalization), and thus is appropriate for coordinate-wise optimization algorithms.

class torch.distributions.transforms.StickBreakingTransform(cache_size=0)[source]

Transform from unconstrained space to the simplex of one additional dimension via a stick-breaking process.

This transform arises as an iterated sigmoid transform in a stick-breaking construction of the Dirichlet distribution: the first logit is transformed via sigmoid to the first probability and the probability of everything else, and then the process recurses.

This is bijective and appropriate for use in HMC; however it mixes coordinates together and is less appropriate for optimization.

class torch.distributions.transforms.LowerCholeskyTransform(cache_size=0)[source]

Transform from unconstrained matrices to lower-triangular matrices with nonnegative diagonal entries.

This is useful for parameterizing positive definite matrices in terms of their Cholesky factorization.

Constraints

The following constraints are implemented:

  • constraints.boolean
  • constraints.dependent
  • constraints.greater_than(lower_bound)
  • constraints.integer_interval(lower_bound, upper_bound)
  • constraints.interval(lower_bound, upper_bound)
  • constraints.lower_cholesky
  • constraints.lower_triangular
  • constraints.nonnegative_integer
  • constraints.positive
  • constraints.positive_definite
  • constraints.positive_integer
  • constraints.real
  • constraints.real_vector
  • constraints.simplex
  • constraints.unit_interval
class torch.distributions.constraints.Constraint[source]

Abstract base class for constraints.

A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized.

check(value)[source]

Returns a byte tensor of sample_shape + batch_shape indicating whether each event in value satisfies this constraint.
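
A minimal sketch, assuming the constraints.positive singleton listed above:

import torch
from torch.distributions import constraints

constraints.positive.check(torch.tensor([1.0, -1.0, 0.5]))
# byte tensor of elementwise results, here [1, 0, 1]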

torch.distributions.constraints.dependent_property

alias of torch.distributions.constraints._DependentProperty

torch.distributions.constraints.integer_interval

alias of torch.distributions.constraints._IntegerInterval

torch.distributions.constraints.greater_than

alias of torch.distributions.constraints._GreaterThan

torch.distributions.constraints.greater_than_eq

alias of torch.distributions.constraints._GreaterThanEq

torch.distributions.constraints.less_than

alias of torch.distributions.constraints._LessThan

torch.distributions.constraints.interval

alias of torch.distributions.constraints._Interval

torch.distributions.constraints.half_open_interval

alias of torch.distributions.constraints._HalfOpenInterval

Constraint Registry

PyTorch provides two global ConstraintRegistry objects that link Constraint objects to Transform objects. These objects both accept constraints and return transforms, but they have different guarantees on bijectivity.

  1. biject_to(constraint) looks up a bijective Transform from constraints.real to the given constraint. The returned transform is guaranteed to have .bijective = True and should implement .log_abs_det_jacobian().
  2. transform_to(constraint) looks up a not-necessarily bijective Transform from constraints.real to the given constraint. The returned transform is not guaranteed to implement .log_abs_det_jacobian().

The transform_to() registry is useful for performing unconstrained optimization on constrained parameters of probability distributions, which are indicated by each distribution’s .arg_constraints dict. These transforms often overparameterize a space in order to avoid rotation; they are thus more suitable for coordinate-wise optimization algorithms like Adam:

loc = torch.zeros(100, requires_grad=True)
unconstrained = torch.zeros(100, requires_grad=True)
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
loss = -Normal(loc, scale).log_prob(data).sum()

The biject_to() registry is useful for Hamiltonian Monte Carlo, where samples from a probability distribution with constrained .support are propagated in an unconstrained space, and algorithms are typically rotation invariant:

dist = Exponential(rate)
unconstrained = torch.zeros(100, requires_grad=True)
sample = biject_to(dist.support)(unconstrained)
potential_energy = -dist.log_prob(sample).sum()

Note

An example where transform_to and biject_to differ is constraints.simplex: transform_to(constraints.simplex) returns a SoftmaxTransform that simply exponentiates and normalizes its inputs; this is a cheap and mostly coordinate-wise operation appropriate for algorithms like SVI. In contrast, biject_to(constraints.simplex) returns a StickBreakingTransform that bijects its input down to a one-fewer-dimensional space; this is a more expensive, less numerically stable transform but is needed for algorithms like HMC.
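
A minimal sketch of the difference, assuming unconstrained inputs; transform_to keeps the dimension, while biject_to maps from a space with one fewer dimension:

import torch
from torch.distributions import biject_to, transform_to, constraints

transform_to(constraints.simplex)(torch.randn(3)).shape   # torch.Size([3]), softmax-style
biject_to(constraints.simplex)(torch.randn(2)).shape      # torch.Size([3]), stick-breaking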

The biject_to and transform_to objects can be extended by user-defined constraints and transforms using their .register() method either as a function on singleton constraints:

transform_to.register(my_constraint, my_transform)

or as a decorator on parameterized constraints:

@transform_to.register(MyConstraintClass)
def my_factory(constraint):
    assert isinstance(constraint, MyConstraintClass)
    return MyTransform(constraint.param1, constraint.param2)

You can create your own registry by creating a new ConstraintRegistry object.

class torch.distributions.constraint_registry.ConstraintRegistry[source]

Registry to link constraints to transforms.

register(constraint, factory=None)[source]

Registers a Constraint subclass in this registry. Usage:

@my_registry.register(MyConstraintClass)
def construct_transform(constraint):
    assert isinstance(constraint, MyConstraint)
    return MyTransform(constraint.arg_constraints)
Parameters:
  • constraint (subclass of Constraint) – A subclass of Constraint, or a singleton object of the desired class.
  • factory (callable) – A callable that inputs a constraint object and returns a Transform object.
