Source code for torch.distributions.transformed_distribution
import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.transforms import Transform
from torch.distributions.utils import _sum_rightmost


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution. Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y = f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)
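
    The change-of-variables formula above can also be checked numerically.
    A minimal sketch, assuming a standard normal base and an
    :class:`~torch.distributions.transforms.ExpTransform` (so that ``Y`` is
    log-normal)::

        base = Normal(0.0, 1.0)
        log_normal = TransformedDistribution(base, [ExpTransform()])
        y = log_normal.sample()
        # Here X = log(Y), and log |det (dX/dY)| = -log(y).
        assert torch.allclose(log_normal.log_prob(y), base.log_prob(y.log()) - y.log())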

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        self.base_dist = base_distribution
        if isinstance(transforms, Transform):
            self.transforms = [transforms, ]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms
        else:
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
        # A transform may treat trailing dims of a base sample as a single event,
        # so the event dim is the largest requirement among base and transforms.
        shape = self.base_dist.batch_shape + self.base_dist.event_shape
        event_dim = max([len(self.base_dist.event_shape)] + [t.event_dim for t in self.transforms])
        batch_shape = shape[:len(shape) - event_dim]
        event_shape = shape[len(shape) - event_dim:]
        super(TransformedDistribution, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        base_dist_batch_shape = batch_shape + self.base_dist.batch_shape[len(self.batch_shape):]
        new.base_dist = self.base_dist.expand(base_dist_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        return self.transforms[-1].codomain if self.transforms else self.base_dist.support

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution, then applies each transform in the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution, then applies
        each transform in the list.
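
        Because the sample is reparameterized, gradients flow back through
        transform parameters. A minimal sketch (the affine scale here is an
        illustrative choice)::

            scale = torch.tensor(2.0, requires_grad=True)
            d = TransformedDistribution(Normal(0.0, 1.0),
                                        [AffineTransform(loc=0.0, scale=scale)])
            d.rsample().backward()   # populates scale.grad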
"""
x = self.base_dist.rsample(sample_shape)
for transform in self.transforms:
x = transform(x)
return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and combining the score
        of the base distribution with the log abs det jacobian of each transform.
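
        For a single transform ``f``, scoring a value ``y`` amounts to the
        following (a minimal sketch; ``base_dist`` is the wrapped base
        distribution)::

            x = f.inv(y)
            log_prob = base_dist.log_prob(x) - f.log_abs_det_jacobian(x, y)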
"""
event_dim = len(self.event_shape)
log_prob = 0.0
y = value
for transform in reversed(self.transforms):
x = transform.inv(y)
log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
event_dim - transform.event_dim)
y = x
log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
event_dim - len(self.base_dist.event_shape))
return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
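
        For example, with the strictly decreasing transform
        ``AffineTransform(loc=0., scale=-1.)`` the product of signs is ``-1``,
        so a base cdf value ``u`` is mapped to ``-(u - 0.5) + 0.5 == 1 - u``.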
"""
sign = 1
for transform in self.transforms:
sign = sign * transform.sign
if sign is 1:
return value
return sign * (value - 0.5) + 0.5

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the cdf of the base distribution.
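
        For an increasing transform this reduces to the base cdf at the
        inverted value. A minimal sketch, with ``base`` and ``log_normal`` as
        in the class docstring example::

            assert torch.allclose(log_normal.cdf(y), base.cdf(y.log()))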
"""
for transform in self.transforms[::-1]:
value = transform.inv(value)
if self._validate_args:
self.base_dist._validate_sample(value)
value = self.base_dist.cdf(value)
value = self._monotonize_cdf(value)
return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by applying the
        icdf of the base distribution and then the transform(s).
        """
        value = self._monotonize_cdf(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.icdf(value)
        for transform in self.transforms:
            value = transform(value)
        return value