poisson.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from numbers import Number
  2. import torch
  3. from torch.distributions import constraints
  4. from torch.distributions.exp_family import ExponentialFamily
  5. from torch.distributions.utils import broadcast_all
  6. __all__ = ['Poisson']
  7. class Poisson(ExponentialFamily):
  8. r"""
  9. Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter.
  10. Samples are nonnegative integers, with a pmf given by
  11. .. math::
  12. \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!}
  13. Example::
  14. >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'")
  15. >>> m = Poisson(torch.tensor([4]))
  16. >>> m.sample()
  17. tensor([ 3.])
  18. Args:
  19. rate (Number, Tensor): the rate parameter
  20. """
  21. arg_constraints = {'rate': constraints.nonnegative}
  22. support = constraints.nonnegative_integer
  23. @property
  24. def mean(self):
  25. return self.rate
  26. @property
  27. def mode(self):
  28. return self.rate.floor()
  29. @property
  30. def variance(self):
  31. return self.rate
  32. def __init__(self, rate, validate_args=None):
  33. self.rate, = broadcast_all(rate)
  34. if isinstance(rate, Number):
  35. batch_shape = torch.Size()
  36. else:
  37. batch_shape = self.rate.size()
  38. super().__init__(batch_shape, validate_args=validate_args)
  39. def expand(self, batch_shape, _instance=None):
  40. new = self._get_checked_instance(Poisson, _instance)
  41. batch_shape = torch.Size(batch_shape)
  42. new.rate = self.rate.expand(batch_shape)
  43. super(Poisson, new).__init__(batch_shape, validate_args=False)
  44. new._validate_args = self._validate_args
  45. return new
  46. def sample(self, sample_shape=torch.Size()):
  47. shape = self._extended_shape(sample_shape)
  48. with torch.no_grad():
  49. return torch.poisson(self.rate.expand(shape))
  50. def log_prob(self, value):
  51. if self._validate_args:
  52. self._validate_sample(value)
  53. rate, value = broadcast_all(self.rate, value)
  54. return value.xlogy(rate) - rate - (value + 1).lgamma()
  55. @property
  56. def _natural_params(self):
  57. return (torch.log(self.rate), )
  58. def _log_normalizer(self, x):
  59. return torch.exp(x)