123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271 |
- r"""
- PyTorch provides two global :class:`ConstraintRegistry` objects that link
- :class:`~torch.distributions.constraints.Constraint` objects to
- :class:`~torch.distributions.transforms.Transform` objects. Both registries
- take a constraint as input and return a transform, but they offer different
- guarantees on bijectivity.
- 1. ``biject_to(constraint)`` looks up a bijective
- :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
- to the given ``constraint``. The returned transform is guaranteed to have
- ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
- 2. ``transform_to(constraint)`` looks up a not-necessarily bijective
- :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
- to the given ``constraint``. The returned transform is not guaranteed to
- implement ``.log_abs_det_jacobian()``.
- The ``transform_to()`` registry is useful for performing unconstrained
- optimization on constrained parameters of probability distributions, which are
- indicated by each distribution's ``.arg_constraints`` dict. These transforms often
- overparameterize a space in order to avoid rotation; they are thus more
- suitable for coordinate-wise optimization algorithms like Adam::
- loc = torch.zeros(100, requires_grad=True)
- unconstrained = torch.zeros(100, requires_grad=True)
- scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
- loss = -Normal(loc, scale).log_prob(data).sum()
- The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
- samples from a probability distribution with constrained ``.support`` are
- propagated in an unconstrained space, and algorithms are typically rotation
- invariant::
- dist = Exponential(rate)
- unconstrained = torch.zeros(100, requires_grad=True)
- sample = biject_to(dist.support)(unconstrained)
- potential_energy = -dist.log_prob(sample).sum()
- .. note::
- An example where ``transform_to`` and ``biject_to`` differ is
- ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
- :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
- exponentiates and normalizes its inputs; this is a cheap and mostly
- coordinate-wise operation appropriate for algorithms like SVI. In
- contrast, ``biject_to(constraints.simplex)`` returns a
- :class:`~torch.distributions.transforms.StickBreakingTransform` that
- bijects its input down to a one-fewer-dimensional space; this is a more
- expensive, less numerically stable transform but is needed for algorithms
- like HMC.
- The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
- constraints and transforms using their ``.register()`` method either as a
- function on singleton constraints::
- transform_to.register(my_constraint, my_transform)
- or as a decorator on parameterized constraints::
- @transform_to.register(MyConstraintClass)
- def my_factory(constraint):
- assert isinstance(constraint, MyConstraintClass)
- return MyTransform(constraint.param1, constraint.param2)
- You can create your own registry by creating a new :class:`ConstraintRegistry`
- object.
- """
- import numbers
- from torch.distributions import constraints, transforms
# Names exported by a star-import of this module.
__all__ = ['ConstraintRegistry', 'biject_to', 'transform_to']
class ConstraintRegistry:
    """
    Registry mapping constraint classes to transform factories.

    Entries are keyed by
    :class:`~torch.distributions.constraints.Constraint` subclass. Calling
    the registry with a constraint object looks up the factory stored under
    that object's exact type and returns the factory's result.
    """

    def __init__(self):
        # Maps Constraint subclass -> factory callable.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Associates a factory with a
        :class:`~torch.distributions.constraints.Constraint` subclass.

        May be called directly with both arguments, or with only
        ``constraint`` to obtain a decorator::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`,
                or a singleton instance, in which case the instance's class
                is used as the key.
            factory (Callable): A callable that takes a constraint object and
                returns a :class:`~torch.distributions.transforms.Transform`
                object. When omitted, a decorator expecting the factory is
                returned instead.

        Returns:
            ``factory`` itself (so decorators can be stacked), or a decorator
            when ``factory`` is ``None``.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass
                nor a ``Constraint`` instance.
        """
        if factory is None:
            # Decorator form: registration happens once the factory arrives.
            def decorator(fn):
                return self.register(constraint, fn)
            return decorator

        # A singleton instance is registered under its class.
        if isinstance(constraint, constraints.Constraint):
            key = type(constraint)
        else:
            key = constraint

        # isinstance(key, type) guards issubclass() against non-class input.
        is_constraint_class = isinstance(key, type) and issubclass(key, constraints.Constraint)
        if not is_constraint_class:
            raise TypeError('Expected constraint to be either a Constraint subclass or instance, '
                            'but got {}'.format(key))

        self._registry[key] = factory
        return factory

    def __call__(self, constraint):
        """
        Returns the transform to constrained space for a constraint object.

        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)           # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            The object produced by the registered factory, expected to be a
            :class:`~torch.distributions.transforms.Transform`.

        Raises:
            NotImplementedError: if nothing is registered for the exact type
                of ``constraint``.
        """
        # The lookup key is the exact class of the constraint; base classes
        # are not consulted.
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            # "from None" hides the internal KeyError from the traceback.
            raise NotImplementedError(
                f'Cannot transform {type(constraint).__name__} constraints') from None
        else:
            return factory(constraint)
# Global registry whose transforms are documented (module docstring) as
# bijective.
biject_to = ConstraintRegistry()
# Global registry whose transforms are documented as not necessarily
# bijective.
transform_to = ConstraintRegistry()
- ################################################################################
- # Registration Table
- ################################################################################
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """
    Factory for ``constraints.real`` in both registries.

    Returns the module-level ``transforms.identity_transform`` object;
    ``constraint`` itself is not inspected.
    """
    return transforms.identity_transform
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """
    Factory for ``constraints.independent`` in ``biject_to``.

    Looks up the ``biject_to`` transform for ``constraint.base_constraint``
    and wraps it in an ``IndependentTransform`` using
    ``constraint.reinterpreted_batch_ndims``.
    """
    inner = biject_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, ndims)
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """
    Factory for ``constraints.independent`` in ``transform_to``.

    Looks up the ``transform_to`` transform for ``constraint.base_constraint``
    and wraps it in an ``IndependentTransform`` using
    ``constraint.reinterpreted_batch_ndims``.
    """
    inner = transform_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, ndims)
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """
    Factory for ``constraints.positive`` and ``constraints.nonnegative`` in
    both registries.

    Returns a new ``ExpTransform``; ``constraint`` itself is not inspected.
    """
    return transforms.ExpTransform()
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """
    Factory for ``constraints.greater_than`` and
    ``constraints.greater_than_eq`` in both registries.

    Composes an ``ExpTransform`` with
    ``AffineTransform(constraint.lower_bound, 1)``, in that order.
    """
    exp_part = transforms.ExpTransform()
    shift_part = transforms.AffineTransform(constraint.lower_bound, 1)
    return transforms.ComposeTransform([exp_part, shift_part])
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """
    Factory for ``constraints.less_than`` in both registries.

    Composes an ``ExpTransform`` with
    ``AffineTransform(constraint.upper_bound, -1)``, in that order.
    """
    exp_part = transforms.ExpTransform()
    flip_part = transforms.AffineTransform(constraint.upper_bound, -1)
    return transforms.ComposeTransform([exp_part, flip_part])
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """
    Factory for ``constraints.interval`` and
    ``constraints.half_open_interval`` in both registries.

    Returns a bare ``SigmoidTransform`` when the bounds are the plain numbers
    0 and 1; otherwise composes a ``SigmoidTransform`` with
    ``AffineTransform(lower_bound, upper_bound - lower_bound)``.
    """
    lo = constraint.lower_bound
    hi = constraint.upper_bound

    # The shortcut applies only to plain Python numbers; bounds of any other
    # type always take the general path.
    is_unit_interval = (
        isinstance(lo, numbers.Number) and lo == 0
        and isinstance(hi, numbers.Number) and hi == 1
    )

    sigmoid = transforms.SigmoidTransform()
    if is_unit_interval:
        return sigmoid

    rescale = transforms.AffineTransform(lo, hi - lo)
    return transforms.ComposeTransform([sigmoid, rescale])
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """
    Factory for ``constraints.simplex`` in ``biject_to``.

    Returns a new ``StickBreakingTransform``; ``constraint`` is not inspected.
    """
    return transforms.StickBreakingTransform()
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """
    Factory for ``constraints.simplex`` in ``transform_to``.

    Returns a new ``SoftmaxTransform``; ``constraint`` is not inspected.
    """
    return transforms.SoftmaxTransform()
# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """
    Factory for ``constraints.lower_cholesky`` in ``transform_to`` only.

    Returns a new ``LowerCholeskyTransform``; ``constraint`` is not inspected.
    """
    return transforms.LowerCholeskyTransform()
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """
    Factory for ``constraints.positive_definite`` and
    ``constraints.positive_semidefinite`` in ``transform_to`` only.

    Returns a new ``PositiveDefiniteTransform``; ``constraint`` is not
    inspected.
    """
    return transforms.PositiveDefiniteTransform()
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """
    Factory for ``constraints.corr_cholesky`` in both registries.

    Returns a new ``CorrCholeskyTransform``; ``constraint`` is not inspected.
    """
    return transforms.CorrCholeskyTransform()
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """
    Factory for ``constraints.cat`` in ``biject_to``.

    Looks up the ``biject_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``CatTransform``, forwarding
    ``constraint.dim`` and ``constraint.lengths``.
    """
    parts = [biject_to(sub) for sub in constraint.cseq]
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """
    Factory for ``constraints.cat`` in ``transform_to``.

    Looks up the ``transform_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``CatTransform``, forwarding
    ``constraint.dim`` and ``constraint.lengths``.
    """
    parts = [transform_to(sub) for sub in constraint.cseq]
    return transforms.CatTransform(parts, constraint.dim, constraint.lengths)
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """
    Factory for ``constraints.stack`` in ``biject_to``.

    Looks up the ``biject_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``StackTransform`` along
    ``constraint.dim``.
    """
    parts = [biject_to(sub) for sub in constraint.cseq]
    return transforms.StackTransform(parts, constraint.dim)
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """
    Factory for ``constraints.stack`` in ``transform_to``.

    Looks up the ``transform_to`` transform of every constraint in
    ``constraint.cseq`` and combines them in a ``StackTransform`` along
    ``constraint.dim``.
    """
    parts = [transform_to(sub) for sub in constraint.cseq]
    return transforms.StackTransform(parts, constraint.dim)
|