# transformed_distribution.py (8.2 KB)
  1. import torch
  2. from torch.distributions import constraints
  3. from torch.distributions.distribution import Distribution
  4. from torch.distributions.independent import Independent
  5. from torch.distributions.transforms import ComposeTransform, Transform
  6. from torch.distributions.utils import _sum_rightmost
  7. from typing import Dict
  8. __all__ = ['TransformedDistribution']
  9. class TransformedDistribution(Distribution):
  10. r"""
  11. Extension of the Distribution class, which applies a sequence of Transforms
  12. to a base distribution. Let f be the composition of transforms applied::
  13. X ~ BaseDistribution
  14. Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
  15. log p(Y) = log p(X) + log |det (dX/dY)|
  16. Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
  17. maximum shape of its base distribution and its transforms, since transforms
  18. can introduce correlations among events.
  19. An example for the usage of :class:`TransformedDistribution` would be::
  20. # Building a Logistic Distribution
  21. # X ~ Uniform(0, 1)
  22. # f = a + b * logit(X)
  23. # Y ~ f(X) ~ Logistic(a, b)
  24. base_distribution = Uniform(0, 1)
  25. transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
  26. logistic = TransformedDistribution(base_distribution, transforms)
  27. For more examples, please look at the implementations of
  28. :class:`~torch.distributions.gumbel.Gumbel`,
  29. :class:`~torch.distributions.half_cauchy.HalfCauchy`,
  30. :class:`~torch.distributions.half_normal.HalfNormal`,
  31. :class:`~torch.distributions.log_normal.LogNormal`,
  32. :class:`~torch.distributions.pareto.Pareto`,
  33. :class:`~torch.distributions.weibull.Weibull`,
  34. :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
  35. :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
  36. """
  37. arg_constraints: Dict[str, constraints.Constraint] = {}
  38. def __init__(self, base_distribution, transforms, validate_args=None):
  39. if isinstance(transforms, Transform):
  40. self.transforms = [transforms, ]
  41. elif isinstance(transforms, list):
  42. if not all(isinstance(t, Transform) for t in transforms):
  43. raise ValueError("transforms must be a Transform or a list of Transforms")
  44. self.transforms = transforms
  45. else:
  46. raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))
  47. # Reshape base_distribution according to transforms.
  48. base_shape = base_distribution.batch_shape + base_distribution.event_shape
  49. base_event_dim = len(base_distribution.event_shape)
  50. transform = ComposeTransform(self.transforms)
  51. if len(base_shape) < transform.domain.event_dim:
  52. raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
  53. .format(transform.domain.event_dim, base_shape))
  54. forward_shape = transform.forward_shape(base_shape)
  55. expanded_base_shape = transform.inverse_shape(forward_shape)
  56. if base_shape != expanded_base_shape:
  57. base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
  58. base_distribution = base_distribution.expand(base_batch_shape)
  59. reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
  60. if reinterpreted_batch_ndims > 0:
  61. base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
  62. self.base_dist = base_distribution
  63. # Compute shapes.
  64. transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim
  65. event_dim = max(
  66. transform.codomain.event_dim, # the transform is coupled
  67. base_event_dim + transform_change_in_event_dim # the base dist is coupled
  68. )
  69. assert len(forward_shape) >= event_dim
  70. cut = len(forward_shape) - event_dim
  71. batch_shape = forward_shape[:cut]
  72. event_shape = forward_shape[cut:]
  73. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  74. def expand(self, batch_shape, _instance=None):
  75. new = self._get_checked_instance(TransformedDistribution, _instance)
  76. batch_shape = torch.Size(batch_shape)
  77. shape = batch_shape + self.event_shape
  78. for t in reversed(self.transforms):
  79. shape = t.inverse_shape(shape)
  80. base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)]
  81. new.base_dist = self.base_dist.expand(base_batch_shape)
  82. new.transforms = self.transforms
  83. super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
  84. new._validate_args = self._validate_args
  85. return new
  86. @constraints.dependent_property(is_discrete=False)
  87. def support(self):
  88. if not self.transforms:
  89. return self.base_dist.support
  90. support = self.transforms[-1].codomain
  91. if len(self.event_shape) > support.event_dim:
  92. support = constraints.independent(support, len(self.event_shape) - support.event_dim)
  93. return support
  94. @property
  95. def has_rsample(self):
  96. return self.base_dist.has_rsample
  97. def sample(self, sample_shape=torch.Size()):
  98. """
  99. Generates a sample_shape shaped sample or sample_shape shaped batch of
  100. samples if the distribution parameters are batched. Samples first from
  101. base distribution and applies `transform()` for every transform in the
  102. list.
  103. """
  104. with torch.no_grad():
  105. x = self.base_dist.sample(sample_shape)
  106. for transform in self.transforms:
  107. x = transform(x)
  108. return x
  109. def rsample(self, sample_shape=torch.Size()):
  110. """
  111. Generates a sample_shape shaped reparameterized sample or sample_shape
  112. shaped batch of reparameterized samples if the distribution parameters
  113. are batched. Samples first from base distribution and applies
  114. `transform()` for every transform in the list.
  115. """
  116. x = self.base_dist.rsample(sample_shape)
  117. for transform in self.transforms:
  118. x = transform(x)
  119. return x
  120. def log_prob(self, value):
  121. """
  122. Scores the sample by inverting the transform(s) and computing the score
  123. using the score of the base distribution and the log abs det jacobian.
  124. """
  125. if self._validate_args:
  126. self._validate_sample(value)
  127. event_dim = len(self.event_shape)
  128. log_prob = 0.0
  129. y = value
  130. for transform in reversed(self.transforms):
  131. x = transform.inv(y)
  132. event_dim += transform.domain.event_dim - transform.codomain.event_dim
  133. log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
  134. event_dim - transform.domain.event_dim)
  135. y = x
  136. log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
  137. event_dim - len(self.base_dist.event_shape))
  138. return log_prob
  139. def _monotonize_cdf(self, value):
  140. """
  141. This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
  142. monotone increasing.
  143. """
  144. sign = 1
  145. for transform in self.transforms:
  146. sign = sign * transform.sign
  147. if isinstance(sign, int) and sign == 1:
  148. return value
  149. return sign * (value - 0.5) + 0.5
  150. def cdf(self, value):
  151. """
  152. Computes the cumulative distribution function by inverting the
  153. transform(s) and computing the score of the base distribution.
  154. """
  155. for transform in self.transforms[::-1]:
  156. value = transform.inv(value)
  157. if self._validate_args:
  158. self.base_dist._validate_sample(value)
  159. value = self.base_dist.cdf(value)
  160. value = self._monotonize_cdf(value)
  161. return value
  162. def icdf(self, value):
  163. """
  164. Computes the inverse cumulative distribution function using
  165. transform(s) and computing the score of the base distribution.
  166. """
  167. value = self._monotonize_cdf(value)
  168. value = self.base_dist.icdf(value)
  169. for transform in self.transforms:
  170. value = transform(value)
  171. return value