Shortcuts

Source code for torch.distributions.transforms

import math
import numbers
import weakref

import torch
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
                                       lazy_property)
from torch.nn.functional import pad

# Public names exported by this module via ``import *``.
__all__ = [
    'AbsTransform',
    'AffineTransform',
    'ComposeTransform',
    'ExpTransform',
    'LowerCholeskyTransform',
    'PowerTransform',
    'SigmoidTransform',
    'SoftmaxTransform',
    'StickBreakingTransform',
    'Transform',
    'identity_transform',
]


class Transform(object):
    """
    Abstract base class for invertible transformations with computable
    log det jacobians. These are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Care must be taken with memoized values because
    the autograd graph may be reversed. For example, the following works
    with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    whereas the following errors when caching, due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform,
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether the transform is monotone
            increasing or decreasing.
        event_dim (int): Number of dimensions that are correlated together in
            the transform ``event_shape``. This should be 0 for pointwise
            transforms, 1 for transforms that act jointly on vectors, 2 for
            transforms that act jointly on matrices, etc.
    """
    bijective = False
    event_dim = 0

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 1:
            # Single-slot memo holding the most recent (x, y) pair.
            self._cached_x_y = None, None
        elif cache_size != 0:
            raise ValueError('cache_size must be 0 or 1')

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        # Only a weak reference to the inverse is kept, so dereference it
        # first and rebuild the inverse if it has been collected.
        inverse = self._inv() if self._inv is not None else None
        if inverse is None:
            inverse = _InverseTransform(self)
            self._inv = weakref.ref(inverse)
        return inverse

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        cached_x, cached_y = self._cached_x_y
        if x is not cached_x:
            # Cache miss: compute and remember the new pair.
            cached_y = self._call(x)
            self._cached_x_y = x, cached_y
        return cached_y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        cached_x, cached_y = self._cached_x_y
        if y is not cached_y:
            # Cache miss: compute and remember the new pair.
            cached_x = self._inverse(y)
            self._cached_x_y = cached_x, y
        return cached_x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
class _InverseTransform(Transform):
    """
    Wraps a single :class:`Transform` and exposes its inverse.
    This class is private; please instead use the ``Transform.inv`` property.
    """
    def __init__(self, transform):
        super(_InverseTransform, self).__init__()
        # Unlike the base class, a strong reference to the wrapped forward
        # transform is stored here (not a weakref).
        self._inv = transform

    @constraints.dependent_property
    def domain(self):
        # Inputs of the inverse are outputs of the wrapped transform.
        return self._inv.codomain

    @constraints.dependent_property
    def codomain(self):
        # Outputs of the inverse are inputs of the wrapped transform.
        return self._inv.domain

    @property
    def bijective(self):
        return self._inv.bijective

    @property
    def sign(self):
        return self._inv.sign

    @property
    def event_dim(self):
        return self._inv.event_dim

    @property
    def inv(self):
        return self._inv

    def __eq__(self, other):
        return isinstance(other, _InverseTransform) and self._inv == other._inv

    def __call__(self, x):
        # Delegate to the wrapped transform's inverse path.
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        # The wrapped transform sees the arguments in swapped roles; its
        # log-determinant is negated to obtain the inverse's.
        forward_ladj = self._inv.log_abs_det_jacobian(y, x)
        return -forward_ladj
class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
            They are applied in list order; an empty list yields the
            identity.
    """
    def __init__(self, parts):
        super(ComposeTransform, self).__init__()
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property
    def domain(self):
        # An empty composition accepts any real input.
        if not self.parts:
            return constraints.real
        return self.parts[0].domain

    @constraints.dependent_property
    def codomain(self):
        # An empty composition produces any real output.
        if not self.parts:
            return constraints.real
        return self.parts[-1].codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        # The sign of a composition is the product of the parts' signs.
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @lazy_property
    def event_dim(self):
        return max(p.event_dim for p in self.parts) if self.parts else 0

    @property
    def inv(self):
        """
        Returns the inverse composition: each part inverted, in reverse
        order. The two compositions hold weak references to each other.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def __call__(self, x):
        """
        Applies every part in order and returns the final output.
        """
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` of the whole chain.

        Args:
            x: Input to the first part.
            y: Output of the last part, i.e. ``self(x)``.

        Returns:
            The sum of the parts' log det jacobians, each reduced over its
            rightmost dims so that all terms share this transform's
            ``event_dim``. For an empty chain, zeros shaped like ``x``.
        """
        if not self.parts:
            return torch.zeros_like(x)
        result = 0
        last = len(self.parts) - 1
        for index, part in enumerate(self.parts):
            # Intermediate outputs must be recomputed, but the final output
            # was supplied by the caller, so the last forward pass is not
            # repeated.
            part_y = y if index == last else part(x)
            result = result + _sum_rightmost(part.log_abs_det_jacobian(x, part_y),
                                             self.event_dim - part.event_dim)
            x = part_y
        return result

    def __repr__(self):
        fmt_string = self.__class__.__name__ + '(\n '
        fmt_string += ',\n '.join([p.__repr__() for p in self.parts])
        fmt_string += '\n)'
        return fmt_string
# Module-level identity transform: a composition with no parts, which returns
# its input unchanged when called.
identity_transform = ComposeTransform([])
class ExpTransform(Transform):
    r"""
    Exponential transform :math:`y = \exp(x)`, with inverse
    :math:`x = \log(y)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return torch.exp(x)

    def _inverse(self, y):
        return torch.log(y)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = exp(x), hence log|dy/dx| is x itself.
        return x
class PowerTransform(Transform):
    r"""
    Power transform :math:`y = x^{\text{exponent}}` on the positive reals.

    Args:
        exponent: The power to raise inputs to; passed through
            ``broadcast_all`` before being stored.
        cache_size (int): Forwarded to :class:`Transform`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __init__(self, exponent, cache_size=0):
        super(PowerTransform, self).__init__(cache_size=cache_size)
        (self.exponent,) = broadcast_all(exponent)

    def __eq__(self, other):
        if isinstance(other, PowerTransform):
            # Equal only when every exponent entry matches.
            return self.exponent.eq(other.exponent).all().item()
        return False

    def _call(self, x):
        return torch.pow(x, self.exponent)

    def _inverse(self, y):
        inverse_exponent = 1 / self.exponent
        return torch.pow(y, inverse_exponent)

    def log_abs_det_jacobian(self, x, y):
        # dy/dx = exponent * x^(exponent - 1) = exponent * y / x.
        derivative = self.exponent * y / x
        return torch.log(torch.abs(derivative))
class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and
    :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return torch.sigmoid(x)

    def _inverse(self, y):
        # logit(y) = log(y) - log(1 - y)
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        """
        Computes ``log |dy/dx| = log(y) + log(1 - y)``.

        The value is evaluated from ``x`` as
        ``-softplus(-x) - softplus(x)`` rather than from ``y``: once the
        sigmoid saturates to exactly 0 or 1 in floating point, a formula
        based on ``1 / y`` and ``1 / (1 - y)`` overflows to ``-inf``, whereas
        the softplus form stays finite.
        """
        softplus = torch.nn.functional.softplus
        return -softplus(-x) - softplus(x)
class AbsTransform(Transform):
    r"""
    Absolute-value transform :math:`y = |x|`.

    ``bijective`` is not set here, and the inverse simply returns its input
    unchanged.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return torch.abs(x)

    def _inverse(self, y):
        # Pseudoinverse: y itself is one valid preimage of |.|.
        return y
class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping
    :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over
            vectors, 2 for distributions over matrices, etc.
        cache_size (int): Forwarded to :class:`Transform`.
    """
    domain = constraints.real
    codomain = constraints.real
    bijective = True
    # Class-level default; every instance overrides it in ``__init__``.
    event_dim = 0

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super(AffineTransform, self).__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self.event_dim = event_dim

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        # Plain numbers compare directly; if either side is a tensor the
        # comparison is elementwise and every entry must match.
        if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        """
        Sign of ``scale``: an int for numeric scales, otherwise the
        elementwise ``scale.sign()`` tensor.
        """
        if isinstance(self.scale, numbers.Number):
            return 1 if self.scale > 0 else -1 if self.scale < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        """
        Computes ``log |dy/dx|``, i.e. ``log |scale|`` summed over the
        rightmost ``event_dim`` dimensions of ``x``.

        Returns:
            A tensor shaped like ``x`` with its last ``event_dim``
            dimensions removed.
        """
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Number):
            result = x.new_empty(shape).fill_(math.log(abs(scale)))
        else:
            # Broadcast to x's full shape *before* reducing, so a scale that
            # is broadcast along an event dimension contributes once per
            # element of that dimension rather than only once.
            result = torch.abs(scale).log().expand(shape)
        if self.event_dim:
            shape = shape[:-self.event_dim]
            # reshape (not view): the expanded tensor may be non-contiguous.
            result = result.reshape(tuple(shape) + (-1,)).sum(-1)
        return result.expand(shape)
class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)`
    followed by normalization.

    This is not bijective and cannot be used for HMC. However, it acts mostly
    coordinate-wise (except for the final normalization), and is therefore
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real
    codomain = constraints.simplex
    event_dim = 1

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        # Subtract the per-row maximum before exponentiating.
        row_max = x.max(-1, True)[0]
        unnormalized = (x - row_max).exp()
        total = unnormalized.sum(-1, True)
        return unnormalized / total

    def _inverse(self, y):
        # Pseudoinverse: the log-probabilities.
        return y.log()
class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """
    domain = constraints.real
    codomain = constraints.simplex
    bijective = True
    event_dim = 1

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        # offset counts down K, K-1, ..., 1 along the last axis (K = size of
        # that axis).
        ones = x.new([1]).expand(x.shape)
        offset = (x.shape[-1] + 1) - ones.cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        # Running product of (1 - z) along the last axis.
        remaining = (1 - z).cumprod(-1)
        z_padded = pad(z, (0, 1), value=1)
        remaining_padded = pad(remaining, (1, 0), value=1)
        return z_padded * remaining_padded

    def _inverse(self, y):
        # The unconstrained vector is one element shorter than y.
        x_shape = y.shape[:-1] + (y.shape[-1] - 1,)
        ones = y.new([1]).expand(x_shape)
        offset = (x_shape[-1] + 1) - ones.cumsum(-1)
        # One minus the cumulative sum of y, dropping the last entry.
        survival = (1 - y.cumsum(-1))[..., :-1]
        return y[..., :-1].log() - survival.log() + offset.log()

    def log_abs_det_jacobian(self, x, y):
        ones = x.new([1]).expand(x.shape)
        offset = (x.shape[-1] + 1) - ones.cumsum(-1)
        z = torch.sigmoid(x - offset.log())
        per_coordinate = (1 - z).log() + y[..., :-1].log()
        return per_coordinate.sum(-1)
class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """
    domain = constraints.real
    codomain = constraints.lower_cholesky
    event_dim = 2

    def __eq__(self, other):
        # Parameter-free: any two instances are interchangeable.
        return isinstance(other, LowerCholeskyTransform)

    def _call_on_event(self, x):
        # Keep the strictly-lower triangle; exponentiate the diagonal.
        strict_lower = x.tril(-1)
        exp_diagonal = x.diag().exp().diag()
        return strict_lower + exp_diagonal

    def _inverse_on_event(self, y):
        # Keep the strictly-lower triangle; take the log of the diagonal.
        strict_lower = y.tril(-1)
        log_diagonal = y.diag().log().diag()
        return strict_lower + log_diagonal

    def _call(self, x):
        # Collapse all leading dims into one so each matrix is handled alone.
        matrices = x.contiguous().view((-1,) + x.shape[-2:])
        mapped = [self._call_on_event(matrix) for matrix in matrices]
        return torch.stack(mapped).view(x.shape)

    def _inverse(self, y):
        # Collapse all leading dims into one so each matrix is handled alone.
        matrices = y.contiguous().view((-1,) + y.shape[-2:])
        mapped = [self._inverse_on_event(matrix) for matrix in matrices]
        return torch.stack(mapped).view(y.shape)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources