relaxed_categorical.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import torch
  2. from torch.distributions import constraints
  3. from torch.distributions.categorical import Categorical
  4. from torch.distributions.utils import clamp_probs, broadcast_all
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.transformed_distribution import TransformedDistribution
  7. from torch.distributions.transforms import ExpTransform
  8. __all__ = ['ExpRelaxedCategorical', 'RelaxedOneHotCategorical']
  9. class ExpRelaxedCategorical(Distribution):
  10. r"""
  11. Creates a ExpRelaxedCategorical parameterized by
  12. :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both).
  13. Returns the log of a point in the simplex. Based on the interface to
  14. :class:`OneHotCategorical`.
  15. Implementation based on [1].
  16. See also: :func:`torch.distributions.OneHotCategorical`
  17. Args:
  18. temperature (Tensor): relaxation temperature
  19. probs (Tensor): event probabilities
  20. logits (Tensor): unnormalized log probability for each event
  21. [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables
  22. (Maddison et al, 2017)
  23. [2] Categorical Reparametrization with Gumbel-Softmax
  24. (Jang et al, 2017)
  25. """
  26. arg_constraints = {'probs': constraints.simplex,
  27. 'logits': constraints.real_vector}
  28. support = constraints.real_vector # The true support is actually a submanifold of this.
  29. has_rsample = True
  30. def __init__(self, temperature, probs=None, logits=None, validate_args=None):
  31. self._categorical = Categorical(probs, logits)
  32. self.temperature = temperature
  33. batch_shape = self._categorical.batch_shape
  34. event_shape = self._categorical.param_shape[-1:]
  35. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  36. def expand(self, batch_shape, _instance=None):
  37. new = self._get_checked_instance(ExpRelaxedCategorical, _instance)
  38. batch_shape = torch.Size(batch_shape)
  39. new.temperature = self.temperature
  40. new._categorical = self._categorical.expand(batch_shape)
  41. super(ExpRelaxedCategorical, new).__init__(batch_shape, self.event_shape, validate_args=False)
  42. new._validate_args = self._validate_args
  43. return new
  44. def _new(self, *args, **kwargs):
  45. return self._categorical._new(*args, **kwargs)
  46. @property
  47. def param_shape(self):
  48. return self._categorical.param_shape
  49. @property
  50. def logits(self):
  51. return self._categorical.logits
  52. @property
  53. def probs(self):
  54. return self._categorical.probs
  55. def rsample(self, sample_shape=torch.Size()):
  56. shape = self._extended_shape(sample_shape)
  57. uniforms = clamp_probs(torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device))
  58. gumbels = -((-(uniforms.log())).log())
  59. scores = (self.logits + gumbels) / self.temperature
  60. return scores - scores.logsumexp(dim=-1, keepdim=True)
  61. def log_prob(self, value):
  62. K = self._categorical._num_events
  63. if self._validate_args:
  64. self._validate_sample(value)
  65. logits, value = broadcast_all(self.logits, value)
  66. log_scale = (torch.full_like(self.temperature, float(K)).lgamma() -
  67. self.temperature.log().mul(-(K - 1)))
  68. score = logits - value.mul(self.temperature)
  69. score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1)
  70. return score + log_scale
  71. class RelaxedOneHotCategorical(TransformedDistribution):
  72. r"""
  73. Creates a RelaxedOneHotCategorical distribution parametrized by
  74. :attr:`temperature`, and either :attr:`probs` or :attr:`logits`.
  75. This is a relaxed version of the :class:`OneHotCategorical` distribution, so
  76. its samples are on simplex, and are reparametrizable.
  77. Example::
  78. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  79. >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]),
  80. ... torch.tensor([0.1, 0.2, 0.3, 0.4]))
  81. >>> m.sample()
  82. tensor([ 0.1294, 0.2324, 0.3859, 0.2523])
  83. Args:
  84. temperature (Tensor): relaxation temperature
  85. probs (Tensor): event probabilities
  86. logits (Tensor): unnormalized log probability for each event
  87. """
  88. arg_constraints = {'probs': constraints.simplex,
  89. 'logits': constraints.real_vector}
  90. support = constraints.simplex
  91. has_rsample = True
  92. def __init__(self, temperature, probs=None, logits=None, validate_args=None):
  93. base_dist = ExpRelaxedCategorical(temperature, probs, logits, validate_args=validate_args)
  94. super().__init__(base_dist, ExpTransform(), validate_args=validate_args)
  95. def expand(self, batch_shape, _instance=None):
  96. new = self._get_checked_instance(RelaxedOneHotCategorical, _instance)
  97. return super().expand(batch_shape, _instance=new)
  98. @property
  99. def temperature(self):
  100. return self.base_dist.temperature
  101. @property
  102. def logits(self):
  103. return self.base_dist.logits
  104. @property
  105. def probs(self):
  106. return self.base_dist.probs