Shortcuts

Source code for torch.distributions.transformed_distribution

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transforms import Transform
from torch.distributions.utils import _sum_rightmost


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution.  Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    # This class has no tensor parameters of its own to validate; the base
    # distribution carries them.
    arg_constraints = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        """
        Args:
            base_distribution: the distribution that is sampled before the
                transforms are applied.
            transforms: a single ``Transform`` or a list of ``Transform``
                objects, applied in list order.
            validate_args: forwarded to ``Distribution.__init__``.

        Raises:
            ValueError: if ``transforms`` is neither a ``Transform`` nor a
                list containing only ``Transform`` objects.
        """
        self.base_dist = base_distribution
        if isinstance(transforms, Transform):
            self.transforms = [transforms, ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms
        else:
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

        # The event dimension is the largest one required by either the base
        # distribution or any transform; the remaining leading dimensions of
        # the base distribution's full shape form the batch shape.
        shape = self.base_dist.batch_shape + self.base_dist.event_shape
        event_dim = max([len(self.base_dist.event_shape)] +
                        [t.event_dim for t in self.transforms])
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        super(TransformedDistribution, self).__init__(batch_shape, event_shape,
                                                      validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        """
        Returns a new distribution whose batch dimensions are expanded to
        ``batch_shape``; the transforms list is shared with ``self``.

        Args:
            batch_shape: the desired expanded batch shape.
            _instance: optional pre-allocated instance to populate, passed
                through to ``_get_checked_instance``.
        """
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        # The base distribution may have more batch dimensions than this
        # distribution (those that were absorbed into the event shape by the
        # transforms); keep those trailing dimensions unchanged.
        base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):]
        new.base_dist = self.base_dist.expand(base_dist_batch_shape)
        new.transforms = self.transforms
        # Skip validation while constructing, then restore the caller's setting.
        super(TransformedDistribution, new).__init__(batch_shape,
                                                     self.event_shape,
                                                     validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        """Codomain of the last transform, or the base support if there are no transforms."""
        return self.transforms[-1].codomain if self.transforms else self.base_dist.support

    @property
    def has_rsample(self):
        """Whether reparameterized sampling is available; mirrors the base distribution."""
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        base distribution and applies `transform()` for every transform in the
        list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from base distribution and applies
        `transform()` for every transform in the list.
        """
        x = self.base_dist.rsample(sample_shape)
        for transform in self.transforms:
            x = transform(x)
        return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the score of the base distribution and the log abs det jacobian.
        """
        event_dim = len(self.event_shape)
        log_prob = 0.0
        y = value
        # Walk the transforms backwards, from the observed value to the base
        # sample, subtracting each transform's log |det J| along the way.
        for transform in reversed(self.transforms):
            x = transform.inv(y)
            # Sum out the dimensions that are event dims for the whole
            # distribution but not for this particular transform.
            log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
                                                 event_dim - transform.event_dim)
            y = x

        log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
                                             event_dim - len(self.base_dist.event_shape))
        return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        # Compare by value, not identity: ``sign is 1`` depends on interpreter
        # int caching.  The isinstance guard keeps a tensor-valued sign out of
        # the boolean test and sends it through the arithmetic path below.
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the cdf of the base distribution.
        """
        for transform in reversed(self.transforms):
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by evaluating
        the icdf of the base distribution and then applying the transform(s).
        """
        value = self._monotonize_cdf(value)
        # ``value`` is a probability here, not a sample of the base
        # distribution, so it must not be checked against the base support.
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources