Shortcuts

Source code for torch.distributions.constraints

r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.nonnegative_integer``
- ``constraints.positive``
- ``constraints.positive_definite``
- ``constraints.positive_integer``
- ``constraints.real``
- ``constraints.real_vector``
- ``constraints.simplex``
- ``constraints.unit_interval``
"""

import torch
from torch.distributions.utils import batch_tril

__all__ = [
    'Constraint',
    'boolean',
    'dependent',
    'dependent_property',
    'greater_than',
    'greater_than_eq',
    'integer_interval',
    'interval',
    'half_open_interval',
    'is_dependent',
    'less_than',
    'lower_cholesky',
    'lower_triangular',
    'nonnegative_integer',
    'positive',
    'positive_definite',
    'positive_integer',
    'real',
    'real_vector',
    'simplex',
    'unit_interval',
]


[docs]class Constraint(object): """ Abstract base class for constraints. A constraint object represents a region over which a variable is valid, e.g. within which a variable can be optimized. """
[docs] def check(self, value): """ Returns a byte tensor of `sample_shape + batch_shape` indicating whether each event in value satisfies this constraint. """ raise NotImplementedError
def __repr__(self): return self.__class__.__name__[1:] + '()'
class _Dependent(Constraint): """ Placeholder for variables whose support depends on other variables. These variables obey no simple coordinate-wise constraints. """ def check(self, x): raise ValueError('Cannot determine validity of dependent constraint') def is_dependent(constraint): return isinstance(constraint, _Dependent) class _DependentProperty(property, _Dependent): """ Decorator that extends @property to act like a `Dependent` constraint when called on a class and act like a property when called on an object. Example:: class Uniform(Distribution): def __init__(self, low, high): self.low = low self.high = high @constraints.dependent_property def support(self): return constraints.interval(self.low, self.high) """ pass class _Boolean(Constraint): """ Constrain to the two values `{0, 1}`. """ def check(self, value): return (value == 0) | (value == 1) class _IntegerInterval(Constraint): """ Constrain to an integer interval `[lower_bound, upper_bound]`. """ def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound def check(self, value): return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound) return fmt_string class _IntegerLessThan(Constraint): """ Constrain to an integer interval `(-inf, upper_bound]`. """ def __init__(self, upper_bound): self.upper_bound = upper_bound def check(self, value): return (value % 1 == 0) & (value <= self.upper_bound) def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(upper_bound={})'.format(self.upper_bound) return fmt_string class _IntegerGreaterThan(Constraint): """ Constrain to an integer interval `[lower_bound, inf)`. """ def __init__(self, lower_bound): self.lower_bound = lower_bound def check(self, value): return (value % 1 == 0) & (value >= self.lower_bound) def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={})'.format(self.lower_bound) return fmt_string class _Real(Constraint): """ Trivially constrain to the extended real line `[-inf, inf]`. """ def check(self, value): return value == value # False for NANs. class _GreaterThan(Constraint): """ Constrain to a real half line `(lower_bound, inf]`. """ def __init__(self, lower_bound): self.lower_bound = lower_bound def check(self, value): return self.lower_bound < value def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={})'.format(self.lower_bound) return fmt_string class _GreaterThanEq(Constraint): """ Constrain to a real half line `[lower_bound, inf)`. """ def __init__(self, lower_bound): self.lower_bound = lower_bound def check(self, value): return self.lower_bound <= value def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={})'.format(self.lower_bound) return fmt_string class _LessThan(Constraint): """ Constrain to a real half line `[-inf, upper_bound)`. """ def __init__(self, upper_bound): self.upper_bound = upper_bound def check(self, value): return value < self.upper_bound def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(upper_bound={})'.format(self.upper_bound) return fmt_string class _Interval(Constraint): """ Constrain to a real interval `[lower_bound, upper_bound]`. """ def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound def check(self, value): return (self.lower_bound <= value) & (value <= self.upper_bound) def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound) return fmt_string class _HalfOpenInterval(Constraint): """ Constrain to a real interval `[lower_bound, upper_bound)`. """ def __init__(self, lower_bound, upper_bound): self.lower_bound = lower_bound self.upper_bound = upper_bound def check(self, value): return (self.lower_bound <= value) & (value < self.upper_bound) def __repr__(self): fmt_string = self.__class__.__name__[1:] fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound) return fmt_string class _Simplex(Constraint): """ Constrain to the unit simplex in the innermost (rightmost) dimension. Specifically: `x >= 0` and `x.sum(-1) == 1`. """ def check(self, value): return (value >= 0).all() & ((value.sum(-1, True) - 1).abs() < 1e-6).all() class _LowerTriangular(Constraint): """ Constrain to lower-triangular square matrices. """ def check(self, value): value_tril = batch_tril(value) return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] class _LowerCholesky(Constraint): """ Constrain to lower-triangular square matrices with positive diagonals. """ def check(self, value): value_tril = batch_tril(value) lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] n = value.size(-1) diag_mask = torch.eye(n, n, dtype=value.dtype, device=value.device) positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0] return lower_triangular & positive_diagonal class _PositiveDefinite(Constraint): """ Constrain to positive-definite matrices. """ def check(self, value): matrix_shape = value.shape[-2:] batch_shape = value.unsqueeze(0).shape[:-2] # TODO: replace with batched linear algebra routine when one becomes available # note that `symeig()` returns eigenvalues in ascending order flattened_value = value.reshape((-1,) + matrix_shape) return torch.stack([v.symeig(eigenvectors=False)[0][:1] > 0.0 for v in flattened_value]).view(batch_shape) class _RealVector(Constraint): """ Constrain to real-valued vectors. This is the same as `constraints.real`, but additionally reduces across the `event_shape` dimension. """ def check(self, value): return (value == value).all() # False for NANs. # Public interface. dependent = _Dependent() dependent_property = _DependentProperty boolean = _Boolean() nonnegative_integer = _IntegerGreaterThan(0) positive_integer = _IntegerGreaterThan(1) integer_interval = _IntegerInterval real = _Real() real_vector = _RealVector() positive = _GreaterThan(0.) greater_than = _GreaterThan greater_than_eq = _GreaterThanEq less_than = _LessThan unit_interval = _Interval(0., 1.) interval = _Interval half_open_interval = _HalfOpenInterval simplex = _Simplex() lower_triangular = _LowerTriangular() lower_cholesky = _LowerCholesky() positive_definite = _PositiveDefinite()

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources