constraint_registry.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. r"""
  2. PyTorch provides two global :class:`ConstraintRegistry` objects that link
  3. :class:`~torch.distributions.constraints.Constraint` objects to
  4. :class:`~torch.distributions.transforms.Transform` objects. Both objects
  5. take constraints as input and return transforms, but they offer different
  6. guarantees about bijectivity.
  7. 1. ``biject_to(constraint)`` looks up a bijective
  8. :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
  9. to the given ``constraint``. The returned transform is guaranteed to have
  10. ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
  11. 2. ``transform_to(constraint)`` looks up a not-necessarily bijective
  12. :class:`~torch.distributions.transforms.Transform` from ``constraints.real``
  13. to the given ``constraint``. The returned transform is not guaranteed to
  14. implement ``.log_abs_det_jacobian()``.
  15. The ``transform_to()`` registry is useful for performing unconstrained
  16. optimization on constrained parameters of probability distributions, which are
  17. indicated by each distribution's ``.arg_constraints`` dict. These transforms often
  18. overparameterize a space in order to avoid rotation; they are thus more
  19. suitable for coordinate-wise optimization algorithms like Adam::
  20. loc = torch.zeros(100, requires_grad=True)
  21. unconstrained = torch.zeros(100, requires_grad=True)
  22. scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
  23. loss = -Normal(loc, scale).log_prob(data).sum()
  24. The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
  25. samples from a probability distribution with constrained ``.support`` are
  26. propagated in an unconstrained space, and algorithms are typically rotation
  27. invariant::
  28. dist = Exponential(rate)
  29. unconstrained = torch.zeros(100, requires_grad=True)
  30. sample = biject_to(dist.support)(unconstrained)
  31. potential_energy = -dist.log_prob(sample).sum()
  32. .. note::
  33. An example where ``transform_to`` and ``biject_to`` differ is
  34. ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
  35. :class:`~torch.distributions.transforms.SoftmaxTransform` that simply
  36. exponentiates and normalizes its inputs; this is a cheap and mostly
  37. coordinate-wise operation appropriate for algorithms like SVI. In
  38. contrast, ``biject_to(constraints.simplex)`` returns a
  39. :class:`~torch.distributions.transforms.StickBreakingTransform` that
  40. bijects its input down to a one-fewer-dimensional space; this is a more
  41. expensive, less numerically stable transform, but it is needed for algorithms
  42. like HMC.
  43. The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
  44. constraints and transforms using their ``.register()`` method either as a
  45. function on singleton constraints::
  46. transform_to.register(my_constraint, my_transform)
  47. or as a decorator on parameterized constraints::
  48. @transform_to.register(MyConstraintClass)
  49. def my_factory(constraint):
  50. assert isinstance(constraint, MyConstraintClass)
  51. return MyTransform(constraint.param1, constraint.param2)
  52. You can create your own registry by creating a new :class:`ConstraintRegistry`
  53. object.
  54. """
  55. import numbers
  56. from torch.distributions import constraints, transforms
  57. __all__ = [
  58. 'ConstraintRegistry',
  59. 'biject_to',
  60. 'transform_to',
  61. ]
  62. class ConstraintRegistry:
  63. """
  64. Registry to link constraints to transforms.
  65. """
  66. def __init__(self):
  67. self._registry = {}
  68. super().__init__()
  69. def register(self, constraint, factory=None):
  70. """
  71. Registers a :class:`~torch.distributions.constraints.Constraint`
  72. subclass in this registry. Usage::
  73. @my_registry.register(MyConstraintClass)
  74. def construct_transform(constraint):
  75. assert isinstance(constraint, MyConstraint)
  76. return MyTransform(constraint.arg_constraints)
  77. Args:
  78. constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
  79. A subclass of :class:`~torch.distributions.constraints.Constraint`, or
  80. a singleton object of the desired class.
  81. factory (Callable): A callable that inputs a constraint object and returns
  82. a :class:`~torch.distributions.transforms.Transform` object.
  83. """
  84. # Support use as decorator.
  85. if factory is None:
  86. return lambda factory: self.register(constraint, factory)
  87. # Support calling on singleton instances.
  88. if isinstance(constraint, constraints.Constraint):
  89. constraint = type(constraint)
  90. if not isinstance(constraint, type) or not issubclass(constraint, constraints.Constraint):
  91. raise TypeError('Expected constraint to be either a Constraint subclass or instance, '
  92. 'but got {}'.format(constraint))
  93. self._registry[constraint] = factory
  94. return factory
  95. def __call__(self, constraint):
  96. """
  97. Looks up a transform to constrained space, given a constraint object.
  98. Usage::
  99. constraint = Normal.arg_constraints['scale']
  100. scale = transform_to(constraint)(torch.zeros(1)) # constrained
  101. u = transform_to(constraint).inv(scale) # unconstrained
  102. Args:
  103. constraint (:class:`~torch.distributions.constraints.Constraint`):
  104. A constraint object.
  105. Returns:
  106. A :class:`~torch.distributions.transforms.Transform` object.
  107. Raises:
  108. `NotImplementedError` if no transform has been registered.
  109. """
  110. # Look up by Constraint subclass.
  111. try:
  112. factory = self._registry[type(constraint)]
  113. except KeyError:
  114. raise NotImplementedError(
  115. f'Cannot transform {type(constraint).__name__} constraints') from None
  116. return factory(constraint)
  117. biject_to = ConstraintRegistry()
  118. transform_to = ConstraintRegistry()
  119. ################################################################################
  120. # Registration Table
  121. ################################################################################
  122. @biject_to.register(constraints.real)
  123. @transform_to.register(constraints.real)
  124. def _transform_to_real(constraint):
  125. return transforms.identity_transform
  126. @biject_to.register(constraints.independent)
  127. def _biject_to_independent(constraint):
  128. base_transform = biject_to(constraint.base_constraint)
  129. return transforms.IndependentTransform(
  130. base_transform, constraint.reinterpreted_batch_ndims)
  131. @transform_to.register(constraints.independent)
  132. def _transform_to_independent(constraint):
  133. base_transform = transform_to(constraint.base_constraint)
  134. return transforms.IndependentTransform(
  135. base_transform, constraint.reinterpreted_batch_ndims)
  136. @biject_to.register(constraints.positive)
  137. @biject_to.register(constraints.nonnegative)
  138. @transform_to.register(constraints.positive)
  139. @transform_to.register(constraints.nonnegative)
  140. def _transform_to_positive(constraint):
  141. return transforms.ExpTransform()
  142. @biject_to.register(constraints.greater_than)
  143. @biject_to.register(constraints.greater_than_eq)
  144. @transform_to.register(constraints.greater_than)
  145. @transform_to.register(constraints.greater_than_eq)
  146. def _transform_to_greater_than(constraint):
  147. return transforms.ComposeTransform([transforms.ExpTransform(),
  148. transforms.AffineTransform(constraint.lower_bound, 1)])
  149. @biject_to.register(constraints.less_than)
  150. @transform_to.register(constraints.less_than)
  151. def _transform_to_less_than(constraint):
  152. return transforms.ComposeTransform([transforms.ExpTransform(),
  153. transforms.AffineTransform(constraint.upper_bound, -1)])
  154. @biject_to.register(constraints.interval)
  155. @biject_to.register(constraints.half_open_interval)
  156. @transform_to.register(constraints.interval)
  157. @transform_to.register(constraints.half_open_interval)
  158. def _transform_to_interval(constraint):
  159. # Handle the special case of the unit interval.
  160. lower_is_0 = isinstance(constraint.lower_bound, numbers.Number) and constraint.lower_bound == 0
  161. upper_is_1 = isinstance(constraint.upper_bound, numbers.Number) and constraint.upper_bound == 1
  162. if lower_is_0 and upper_is_1:
  163. return transforms.SigmoidTransform()
  164. loc = constraint.lower_bound
  165. scale = constraint.upper_bound - constraint.lower_bound
  166. return transforms.ComposeTransform([transforms.SigmoidTransform(),
  167. transforms.AffineTransform(loc, scale)])
  168. @biject_to.register(constraints.simplex)
  169. def _biject_to_simplex(constraint):
  170. return transforms.StickBreakingTransform()
  171. @transform_to.register(constraints.simplex)
  172. def _transform_to_simplex(constraint):
  173. return transforms.SoftmaxTransform()
  174. # TODO define a bijection for LowerCholeskyTransform
  175. @transform_to.register(constraints.lower_cholesky)
  176. def _transform_to_lower_cholesky(constraint):
  177. return transforms.LowerCholeskyTransform()
  178. @transform_to.register(constraints.positive_definite)
  179. @transform_to.register(constraints.positive_semidefinite)
  180. def _transform_to_positive_definite(constraint):
  181. return transforms.PositiveDefiniteTransform()
  182. @biject_to.register(constraints.corr_cholesky)
  183. @transform_to.register(constraints.corr_cholesky)
  184. def _transform_to_corr_cholesky(constraint):
  185. return transforms.CorrCholeskyTransform()
  186. @biject_to.register(constraints.cat)
  187. def _biject_to_cat(constraint):
  188. return transforms.CatTransform([biject_to(c)
  189. for c in constraint.cseq],
  190. constraint.dim,
  191. constraint.lengths)
  192. @transform_to.register(constraints.cat)
  193. def _transform_to_cat(constraint):
  194. return transforms.CatTransform([transform_to(c)
  195. for c in constraint.cseq],
  196. constraint.dim,
  197. constraint.lengths)
  198. @biject_to.register(constraints.stack)
  199. def _biject_to_stack(constraint):
  200. return transforms.StackTransform(
  201. [biject_to(c)
  202. for c in constraint.cseq], constraint.dim)
  203. @transform_to.register(constraints.stack)
  204. def _transform_to_stack(constraint):
  205. return transforms.StackTransform(
  206. [transform_to(c)
  207. for c in constraint.cseq], constraint.dim)