mixture_same_family.py

import torch
from torch.distributions.distribution import Distribution
from torch.distributions import Categorical
from torch.distributions import constraints
from typing import Dict

__all__ = ['MixtureSameFamily']


class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations
    of the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D, each
        >>> # consisting of 5 randomly weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...           torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Must have either
            scalar `batch_shape` or `batch_shape` matching
            `component_distribution.batch_shape[:-1]`.
        component_distribution: `torch.distributions.Distribution`-like
            instance. The rightmost batch dimension indexes the components.
    """
    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(self,
                 mixture_distribution,
                 component_distribution,
                 validate_args=None):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError("The Mixture distribution needs to be an "
                             "instance of torch.distributions.Categorical")

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError("The Component distribution needs to be an "
                             "instance of torch.distributions.Distribution")

        # Check that the batch shapes are broadcast-compatible
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError("`mixture_distribution.batch_shape` ({0}) is not "
                                 "compatible with `component_distribution."
                                 "batch_shape` ({1})".format(mdbs, cdbs))

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError("`mixture_distribution` component count ({0}) does"
                             " not equal `component_distribution.batch_shape[-1]`"
                             " ({1})".format(km, kc))
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(batch_shape=cdbs,
                         event_shape=event_shape,
                         validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = \
            self._component_distribution.expand(batch_shape_comp)
        new._mixture_distribution = \
            self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(batch_shape=batch_shape,
                                               event_shape=event_shape,
                                               validate_args=False)
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(probs * self.component_distribution.mean,
                         dim=-1 - self._event_ndims)  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(probs * self.component_distribution.variance,
                                  dim=-1 - self._event_ndims)
        var_cond_mean = torch.sum(probs * (self.component_distribution.mean -
                                           self._pad(self.mean)).pow(2.0),
                                  dim=-1 - self._event_ndims)
        return mean_cond_var + var_cond_mean
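
    # Spelled out, the two sums in `variance` compute, elementwise over the
    # event shape, with mixture weights pi_k, component means mu_k and
    # component variances sigma_k^2:
    #
    #     Var(Y) = sum_k pi_k * sigma_k^2        (E[Var(Y|X)], `mean_cond_var`)
    #            + sum_k pi_k * (mu_k - mu)^2    (Var(E[Y|X]), `var_cond_mean`)
    #
    # where mu = sum_k pi_k * mu_k is `self.mean`.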

    def cdf(self, x):
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs
        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(self.mixture_distribution.logits,
                                         dim=-1)  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]
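
    # `log_prob` evaluates the mixture density entirely in log space:
    #
    #     log p(x) = logsumexp_k( log pi_k + log p_k(x) )
    #
    # where log pi_k comes from log_softmax of the Categorical logits and
    # p_k is the k-th component density; logsumexp keeps the reduction over
    # components numerically stable.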

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # mixture samples [n, B]
            mix_sample = self.mixture_distribution.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1)))
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es)

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)
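
    # Shape walkthrough for `sample`: the categorical indices of shape [n, B]
    # are reshaped to [n, B, 1, 1, ..., 1] and then repeated over the event
    # dimensions, giving [n, B, 1, *E]. torch.gather selects, for every sample
    # and batch element, the slice of `comp_samples` ([n, B, k, *E]) belonging
    # to the chosen component along `gather_dim` (the k dimension), and the
    # remaining singleton k dimension is squeezed away, leaving [n, B, *E].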

    def _pad(self, x):
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        # Number of batch *dimensions* (not elements) of the full distribution
        # and of the mixing Categorical.
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = 0 if cat_batch_ndims == 1 else \
            dist_batch_ndims - cat_batch_ndims
        xs = x.shape
        x = x.reshape(xs[:-1] + torch.Size(pad_ndims * [1]) +
                      xs[-1:] + torch.Size(self._event_ndims * [1]))
        return x

    def __repr__(self):
        args_string = '\n {},\n {}'.format(self.mixture_distribution,
                                           self.component_distribution)
        return 'MixtureSameFamily' + '(' + args_string + ')'
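

if __name__ == "__main__":
    # Minimal usage sketch (not exercised by the module itself): build the 2D,
    # 5-component Gaussian mixture from the docstring examples and touch the
    # public surface of the class.
    import torch.distributions as D

    mix = D.Categorical(torch.ones(5,))
    comp = D.Independent(D.Normal(torch.randn(5, 2), torch.rand(5, 2)), 1)
    gmm = MixtureSameFamily(mix, comp)

    x = gmm.sample(torch.Size([4]))             # 4 draws, event shape [2]
    print("samples:", x.shape)                  # torch.Size([4, 2])
    print("log_prob:", gmm.log_prob(x).shape)   # torch.Size([4])
    print("mean:", gmm.mean)                    # weighted mean of component means
    print("variance:", gmm.variance)            # law-of-total-variance combination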