multinomial.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import torch
  2. from torch import inf
  3. from torch.distributions.binomial import Binomial
  4. from torch.distributions.distribution import Distribution
  5. from torch.distributions import Categorical
  6. from torch.distributions import constraints
  7. from torch.distributions.utils import broadcast_all
  8. __all__ = ['Multinomial']
  9. class Multinomial(Distribution):
  10. r"""
  11. Creates a Multinomial distribution parameterized by :attr:`total_count` and
  12. either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of
  13. :attr:`probs` indexes over categories. All other dimensions index over batches.
  14. Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is
  15. called (see example below)
  16. .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
  17. and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
  18. will return this normalized value.
  19. The `logits` argument will be interpreted as unnormalized log probabilities
  20. and can therefore be any real number. It will likewise be normalized so that
  21. the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
  22. will return this normalized value.
  23. - :meth:`sample` requires a single shared `total_count` for all
  24. parameters and samples.
  25. - :meth:`log_prob` allows different `total_count` for each parameter and
  26. sample.
  27. Example::
  28. >>> # xdoctest: +SKIP("FIXME: found invalid values")
  29. >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.]))
  30. >>> x = m.sample() # equal probability of 0, 1, 2, 3
  31. tensor([ 21., 24., 30., 25.])
  32. >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x)
  33. tensor([-4.1338])
  34. Args:
  35. total_count (int): number of trials
  36. probs (Tensor): event probabilities
  37. logits (Tensor): event log probabilities (unnormalized)
  38. """
  39. arg_constraints = {'probs': constraints.simplex,
  40. 'logits': constraints.real_vector}
  41. total_count: int
  42. @property
  43. def mean(self):
  44. return self.probs * self.total_count
  45. @property
  46. def variance(self):
  47. return self.total_count * self.probs * (1 - self.probs)
  48. def __init__(self, total_count=1, probs=None, logits=None, validate_args=None):
  49. if not isinstance(total_count, int):
  50. raise NotImplementedError('inhomogeneous total_count is not supported')
  51. self.total_count = total_count
  52. self._categorical = Categorical(probs=probs, logits=logits)
  53. self._binomial = Binomial(total_count=total_count, probs=self.probs)
  54. batch_shape = self._categorical.batch_shape
  55. event_shape = self._categorical.param_shape[-1:]
  56. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  57. def expand(self, batch_shape, _instance=None):
  58. new = self._get_checked_instance(Multinomial, _instance)
  59. batch_shape = torch.Size(batch_shape)
  60. new.total_count = self.total_count
  61. new._categorical = self._categorical.expand(batch_shape)
  62. super(Multinomial, new).__init__(batch_shape, self.event_shape, validate_args=False)
  63. new._validate_args = self._validate_args
  64. return new
  65. def _new(self, *args, **kwargs):
  66. return self._categorical._new(*args, **kwargs)
  67. @constraints.dependent_property(is_discrete=True, event_dim=1)
  68. def support(self):
  69. return constraints.multinomial(self.total_count)
  70. @property
  71. def logits(self):
  72. return self._categorical.logits
  73. @property
  74. def probs(self):
  75. return self._categorical.probs
  76. @property
  77. def param_shape(self):
  78. return self._categorical.param_shape
  79. def sample(self, sample_shape=torch.Size()):
  80. sample_shape = torch.Size(sample_shape)
  81. samples = self._categorical.sample(torch.Size((self.total_count,)) + sample_shape)
  82. # samples.shape is (total_count, sample_shape, batch_shape), need to change it to
  83. # (sample_shape, batch_shape, total_count)
  84. shifted_idx = list(range(samples.dim()))
  85. shifted_idx.append(shifted_idx.pop(0))
  86. samples = samples.permute(*shifted_idx)
  87. counts = samples.new(self._extended_shape(sample_shape)).zero_()
  88. counts.scatter_add_(-1, samples, torch.ones_like(samples))
  89. return counts.type_as(self.probs)
  90. def entropy(self):
  91. n = torch.tensor(self.total_count)
  92. cat_entropy = self._categorical.entropy()
  93. term1 = n * cat_entropy - torch.lgamma(n + 1)
  94. support = self._binomial.enumerate_support(expand=False)[1:]
  95. binomial_probs = torch.exp(self._binomial.log_prob(support))
  96. weights = torch.lgamma(support + 1)
  97. term2 = (binomial_probs * weights).sum([0, -1])
  98. return term1 + term2
  99. def log_prob(self, value):
  100. if self._validate_args:
  101. self._validate_sample(value)
  102. logits, value = broadcast_all(self.logits, value)
  103. logits = logits.clone(memory_format=torch.contiguous_format)
  104. log_factorial_n = torch.lgamma(value.sum(-1) + 1)
  105. log_factorial_xs = torch.lgamma(value + 1).sum(-1)
  106. logits[(value == 0) & (logits == -inf)] = 0
  107. log_powers = (logits * value).sum(-1)
  108. return log_factorial_n - log_factorial_xs + log_powers