one_hot_categorical.py

import torch
from torch.distributions import constraints
from torch.distributions.categorical import Categorical
from torch.distributions.distribution import Distribution

__all__ = ['OneHotCategorical', 'OneHotCategoricalStraightThrough']


class OneHotCategorical(Distribution):
    r"""
    Creates a one-hot categorical distribution parameterized by :attr:`probs` or
    :attr:`logits`.

    Samples are one-hot coded vectors of size ``probs.size(-1)``.

    .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum,
              and it will be normalized to sum to 1 along the last dimension. :attr:`probs`
              will return this normalized value.
              The `logits` argument will be interpreted as unnormalized log probabilities
              and can therefore be any real number. It will likewise be normalized so that
              the resulting probabilities sum to 1 along the last dimension. :attr:`logits`
              will return this normalized value.

    See also: :func:`torch.distributions.Categorical` for specifications of
    :attr:`probs` and :attr:`logits`.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]))
        >>> m.sample()  # equal probability of 0, 1, 2, 3
        tensor([ 0., 0., 0., 1.])

    Args:
        probs (Tensor): event probabilities
        logits (Tensor): event log probabilities (unnormalized)
    """
    arg_constraints = {'probs': constraints.simplex,
                       'logits': constraints.real_vector}
    support = constraints.one_hot
    has_enumerate_support = True

    def __init__(self, probs=None, logits=None, validate_args=None):
        self._categorical = Categorical(probs, logits)
        batch_shape = self._categorical.batch_shape
        event_shape = self._categorical.param_shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(OneHotCategorical, _instance)
        batch_shape = torch.Size(batch_shape)
        new._categorical = self._categorical.expand(batch_shape)
        super(OneHotCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
    def _new(self, *args, **kwargs):
        return self._categorical._new(*args, **kwargs)

    @property
    def _param(self):
        return self._categorical._param

    @property
    def probs(self):
        return self._categorical.probs

    @property
    def logits(self):
        return self._categorical.logits

    @property
    def mean(self):
        return self._categorical.probs

    @property
    def mode(self):
        probs = self._categorical.probs
        mode = probs.argmax(axis=-1)
        return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs)

    @property
    def variance(self):
        return self._categorical.probs * (1 - self._categorical.probs)

    @property
    def param_shape(self):
        return self._categorical.param_shape
    def sample(self, sample_shape=torch.Size()):
        sample_shape = torch.Size(sample_shape)
        probs = self._categorical.probs
        num_events = self._categorical._num_events
        indices = self._categorical.sample(sample_shape)
        return torch.nn.functional.one_hot(indices, num_events).to(probs)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        indices = value.max(-1)[1]
        return self._categorical.log_prob(indices)

    def entropy(self):
        return self._categorical.entropy()

    def enumerate_support(self, expand=True):
        n = self.event_shape[0]
        values = torch.eye(n, dtype=self._param.dtype, device=self._param.device)
        values = values.view((n,) + (1,) * len(self.batch_shape) + (n,))
        if expand:
            values = values.expand((n,) + self.batch_shape + (n,))
        return values
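

# Minimal usage sketch of OneHotCategorical (illustrative only; the helper name
# below is not part of the distribution API). It draws one-hot samples, scores
# them with log_prob, and enumerates the support, using only the methods
# defined above.
def _demo_one_hot_categorical():
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
    m = OneHotCategorical(probs)
    sample = m.sample((2,))          # shape (2, 4); each row is a one-hot vector
    log_p = m.log_prob(sample)       # shape (2,); log-probability of each row
    support = m.enumerate_support()  # shape (4, 4); the 4x4 identity matrix
    return sample, log_p, support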


class OneHotCategoricalStraightThrough(OneHotCategorical):
    r"""
    Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight-
    through gradient estimator from [1].

    [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional
        Computation (Bengio et al., 2013)
    """
    has_rsample = True
    def rsample(self, sample_shape=torch.Size()):
        samples = self.sample(sample_shape)
        probs = self._categorical.probs  # cached via @lazy_property
        # Straight-through trick: the returned value equals the hard one-hot
        # sample, while the (probs - probs.detach()) term lets gradients flow
        # back to the distribution parameters.
        return samples + (probs - probs.detach())
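

# Minimal sketch of the straight-through estimator in use (illustrative only;
# this __main__ block is not part of the distribution API). rsample() returns
# a hard one-hot sample in the forward pass, while gradients reach the
# parameters through the `probs - probs.detach()` term.
if __name__ == "__main__":
    logits = torch.zeros(4, requires_grad=True)
    dist = OneHotCategoricalStraightThrough(logits=logits)
    sample = dist.rsample()                        # hard one-hot vector
    loss = (sample * torch.arange(4.0)).sum()      # arbitrary downstream loss
    loss.backward()                                # gradient flows to `logits`
    print(sample, logits.grad)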