import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.independent import Independent
from torch.distributions.transforms import ComposeTransform, Transform
from torch.distributions.utils import _sum_rightmost
from typing import Dict

__all__ = ['TransformedDistribution']


class TransformedDistribution(Distribution):
    r"""
    Extension of the Distribution class, which applies a sequence of Transforms
    to a base distribution. Let f be the composition of transforms applied::

        X ~ BaseDistribution
        Y = f(X) ~ TransformedDistribution(BaseDistribution, f)
        log p(Y) = log p(X) + log |det (dX/dY)|

    Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the
    maximum shape of its base distribution and its transforms, since transforms
    can introduce correlations among events.

    An example for the usage of :class:`TransformedDistribution` would be::

        # Building a Logistic Distribution
        # X ~ Uniform(0, 1)
        # f = a + b * logit(X)
        # Y ~ f(X) ~ Logistic(a, b)
        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)]
        logistic = TransformedDistribution(base_distribution, transforms)
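
    The snippet above assumes ``a`` and ``b`` are already defined; a
    self-contained sketch of the same construction (with illustrative values
    ``a=0.0``, ``b=1.0``, not part of the original example) would be::

        import torch
        from torch.distributions import (AffineTransform, SigmoidTransform,
                                         TransformedDistribution, Uniform)

        base_distribution = Uniform(0, 1)
        transforms = [SigmoidTransform().inv,
                      AffineTransform(loc=0.0, scale=1.0)]
        logistic = TransformedDistribution(base_distribution, transforms)
        y = logistic.sample((5,))     # logit of uniform draws, then affine
        logistic.log_prob(y)          # change-of-variables log-density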

    For more examples, please look at the implementations of
    :class:`~torch.distributions.gumbel.Gumbel`,
    :class:`~torch.distributions.half_cauchy.HalfCauchy`,
    :class:`~torch.distributions.half_normal.HalfNormal`,
    :class:`~torch.distributions.log_normal.LogNormal`,
    :class:`~torch.distributions.pareto.Pareto`,
    :class:`~torch.distributions.weibull.Weibull`,
    :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and
    :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical`
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}

    def __init__(self, base_distribution, transforms, validate_args=None):
        if isinstance(transforms, Transform):
            self.transforms = [transforms]
        elif isinstance(transforms, list):
            if not all(isinstance(t, Transform) for t in transforms):
                raise ValueError("transforms must be a Transform or a list of Transforms")
            self.transforms = transforms
        else:
            raise ValueError("transforms must be a Transform or list, but was {}".format(transforms))

        # Reshape base_distribution according to transforms.
        base_shape = base_distribution.batch_shape + base_distribution.event_shape
        base_event_dim = len(base_distribution.event_shape)
        transform = ComposeTransform(self.transforms)
        if len(base_shape) < transform.domain.event_dim:
            raise ValueError("base_distribution needs to have shape with size at least {}, but got {}."
                             .format(transform.domain.event_dim, base_shape))
        forward_shape = transform.forward_shape(base_shape)
        expanded_base_shape = transform.inverse_shape(forward_shape)
        if base_shape != expanded_base_shape:
            base_batch_shape = expanded_base_shape[:len(expanded_base_shape) - base_event_dim]
            base_distribution = base_distribution.expand(base_batch_shape)
        reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim
        if reinterpreted_batch_ndims > 0:
            base_distribution = Independent(base_distribution, reinterpreted_batch_ndims)
        self.base_dist = base_distribution

        # Compute shapes.
        transform_change_in_event_dim = transform.codomain.event_dim - transform.domain.event_dim
        event_dim = max(
            transform.codomain.event_dim,  # the transform is coupled
            base_event_dim + transform_change_in_event_dim,  # the base dist is coupled
        )
        assert len(forward_shape) >= event_dim
        cut = len(forward_shape) - event_dim
        batch_shape = forward_shape[:cut]
        event_shape = forward_shape[cut:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)
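
    # Illustrative shape walk-through (assumed example, not from the original
    # source): with base_distribution = Normal(torch.zeros(3, 4), 1.)
    # (batch_shape=(3, 4), event_shape=()) and transforms = [SoftmaxTransform()]
    # (domain and codomain both have event_dim=1), the transform's domain
    # event_dim (1) exceeds base_event_dim (0), so __init__ wraps the base in
    # Independent(base_distribution, 1); the resulting distribution has
    # batch_shape=(3,) and event_shape=(4,).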

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(TransformedDistribution, _instance)
        batch_shape = torch.Size(batch_shape)
        shape = batch_shape + self.event_shape
        for t in reversed(self.transforms):
            shape = t.inverse_shape(shape)
        base_batch_shape = shape[:len(shape) - len(self.base_dist.event_shape)]
        new.base_dist = self.base_dist.expand(base_batch_shape)
        new.transforms = self.transforms
        super(TransformedDistribution, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        if not self.transforms:
            return self.base_dist.support
        support = self.transforms[-1].codomain
        if len(self.event_shape) > support.event_dim:
            support = constraints.independent(support, len(self.event_shape) - support.event_dim)
        return support
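
    # For example (illustrative): TransformedDistribution(Normal(0., 1.),
    # ExpTransform()) reports constraints.positive, the codomain of the last
    # transform; any extra event dims are re-wrapped via constraints.independent.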

    @property
    def has_rsample(self):
        return self.base_dist.has_rsample

    def sample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped sample or sample_shape shaped batch of
        samples if the distribution parameters are batched. Samples first from
        the base distribution and applies `transform()` for every transform in
        the list.
        """
        with torch.no_grad():
            x = self.base_dist.sample(sample_shape)
            for transform in self.transforms:
                x = transform(x)
            return x

    def rsample(self, sample_shape=torch.Size()):
        """
        Generates a sample_shape shaped reparameterized sample or sample_shape
        shaped batch of reparameterized samples if the distribution parameters
        are batched. Samples first from the base distribution and applies
        `transform()` for every transform in the list.
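
        Example (illustrative sketch, not part of the original docstring;
        shows that gradients flow back to transform parameters)::

            >>> loc = torch.zeros(1, requires_grad=True)
            >>> d = TransformedDistribution(Uniform(0., 1.),
            ...                             AffineTransform(loc=loc, scale=1.))
            >>> d.rsample().sum().backward()  # populates loc.grad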
- """
- x = self.base_dist.rsample(sample_shape)
- for transform in self.transforms:
- x = transform(x)
- return x

    def log_prob(self, value):
        """
        Scores the sample by inverting the transform(s) and computing the score
        using the log probability of the base distribution and the log abs det
        jacobian of each transform.
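
        Example (illustrative; in PyTorch, ``LogNormal`` is itself an
        exponentiated ``Normal``, so the two log-probs below agree)::

            >>> d = TransformedDistribution(Normal(0., 1.), ExpTransform())
            >>> value = torch.rand(3) + 0.1
            >>> torch.allclose(d.log_prob(value), LogNormal(0., 1.).log_prob(value))
            True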
- """
- if self._validate_args:
- self._validate_sample(value)
- event_dim = len(self.event_shape)
- log_prob = 0.0
- y = value
- for transform in reversed(self.transforms):
- x = transform.inv(y)
- event_dim += transform.domain.event_dim - transform.codomain.event_dim
- log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
- event_dim - transform.domain.event_dim)
- y = x
- log_prob = log_prob + _sum_rightmost(self.base_dist.log_prob(y),
- event_dim - len(self.base_dist.event_shape))
- return log_prob

    def _monotonize_cdf(self, value):
        """
        This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is
        monotone increasing.
        """
        sign = 1
        for transform in self.transforms:
            sign = sign * transform.sign
        if isinstance(sign, int) and sign == 1:
            return value
        return sign * (value - 0.5) + 0.5
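
    # For example (illustrative): with transforms = [AffineTransform(0., -1.)]
    # the composed map is decreasing (sign == -1), so sign * (value - 0.5) + 0.5
    # flips the base CDF to 1 - value, keeping :meth:`cdf` monotone increasing.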

    def cdf(self, value):
        """
        Computes the cumulative distribution function by inverting the
        transform(s) and evaluating the CDF of the base distribution.
        """
        for transform in self.transforms[::-1]:
            value = transform.inv(value)
        if self._validate_args:
            self.base_dist._validate_sample(value)
        value = self.base_dist.cdf(value)
        value = self._monotonize_cdf(value)
        return value

    def icdf(self, value):
        """
        Computes the inverse cumulative distribution function by evaluating the
        inverse CDF of the base distribution and applying the transform(s).
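
        Example (illustrative round trip for a monotone transform; reuses the
        logistic construction from the class docstring)::

            >>> d = TransformedDistribution(Uniform(0., 1.), SigmoidTransform().inv)
            >>> u = torch.rand(4)
            >>> torch.allclose(d.cdf(d.icdf(u)), u, atol=1e-6)
            True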
- """
- value = self._monotonize_cdf(value)
- value = self.base_dist.icdf(value)
- for transform in self.transforms:
- value = transform(value)
- return value