bernoulli.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from numbers import Number
  2. import torch
  3. from torch import nan
  4. from torch.distributions import constraints
  5. from torch.distributions.exp_family import ExponentialFamily
  6. from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
  7. from torch.nn.functional import binary_cross_entropy_with_logits
  8. __all__ = ['Bernoulli']
  9. class Bernoulli(ExponentialFamily):
  10. r"""
  11. Creates a Bernoulli distribution parameterized by :attr:`probs`
  12. or :attr:`logits` (but not both).
  13. Samples are binary (0 or 1). They take the value `1` with probability `p`
  14. and `0` with probability `1 - p`.
  15. Example::
  16. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  17. >>> m = Bernoulli(torch.tensor([0.3]))
  18. >>> m.sample() # 30% chance 1; 70% chance 0
  19. tensor([ 0.])
  20. Args:
  21. probs (Number, Tensor): the probability of sampling `1`
  22. logits (Number, Tensor): the log-odds of sampling `1`
  23. """
  24. arg_constraints = {'probs': constraints.unit_interval,
  25. 'logits': constraints.real}
  26. support = constraints.boolean
  27. has_enumerate_support = True
  28. _mean_carrier_measure = 0
  29. def __init__(self, probs=None, logits=None, validate_args=None):
  30. if (probs is None) == (logits is None):
  31. raise ValueError("Either `probs` or `logits` must be specified, but not both.")
  32. if probs is not None:
  33. is_scalar = isinstance(probs, Number)
  34. self.probs, = broadcast_all(probs)
  35. else:
  36. is_scalar = isinstance(logits, Number)
  37. self.logits, = broadcast_all(logits)
  38. self._param = self.probs if probs is not None else self.logits
  39. if is_scalar:
  40. batch_shape = torch.Size()
  41. else:
  42. batch_shape = self._param.size()
  43. super().__init__(batch_shape, validate_args=validate_args)
  44. def expand(self, batch_shape, _instance=None):
  45. new = self._get_checked_instance(Bernoulli, _instance)
  46. batch_shape = torch.Size(batch_shape)
  47. if 'probs' in self.__dict__:
  48. new.probs = self.probs.expand(batch_shape)
  49. new._param = new.probs
  50. if 'logits' in self.__dict__:
  51. new.logits = self.logits.expand(batch_shape)
  52. new._param = new.logits
  53. super(Bernoulli, new).__init__(batch_shape, validate_args=False)
  54. new._validate_args = self._validate_args
  55. return new
  56. def _new(self, *args, **kwargs):
  57. return self._param.new(*args, **kwargs)
  58. @property
  59. def mean(self):
  60. return self.probs
  61. @property
  62. def mode(self):
  63. mode = (self.probs >= 0.5).to(self.probs)
  64. mode[self.probs == 0.5] = nan
  65. return mode
  66. @property
  67. def variance(self):
  68. return self.probs * (1 - self.probs)
  69. @lazy_property
  70. def logits(self):
  71. return probs_to_logits(self.probs, is_binary=True)
  72. @lazy_property
  73. def probs(self):
  74. return logits_to_probs(self.logits, is_binary=True)
  75. @property
  76. def param_shape(self):
  77. return self._param.size()
  78. def sample(self, sample_shape=torch.Size()):
  79. shape = self._extended_shape(sample_shape)
  80. with torch.no_grad():
  81. return torch.bernoulli(self.probs.expand(shape))
  82. def log_prob(self, value):
  83. if self._validate_args:
  84. self._validate_sample(value)
  85. logits, value = broadcast_all(self.logits, value)
  86. return -binary_cross_entropy_with_logits(logits, value, reduction='none')
  87. def entropy(self):
  88. return binary_cross_entropy_with_logits(self.logits, self.probs, reduction='none')
  89. def enumerate_support(self, expand=True):
  90. values = torch.arange(2, dtype=self._param.dtype, device=self._param.device)
  91. values = values.view((-1,) + (1,) * len(self._batch_shape))
  92. if expand:
  93. values = values.expand((-1,) + self._batch_shape)
  94. return values
  95. @property
  96. def _natural_params(self):
  97. return (torch.logit(self.probs), )
  98. def _log_normalizer(self, x):
  99. return torch.log1p(torch.exp(x))