exponential.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from numbers import Number
  2. import torch
  3. from torch.distributions import constraints
  4. from torch.distributions.exp_family import ExponentialFamily
  5. from torch.distributions.utils import broadcast_all
  6. __all__ = ['Exponential']
  7. class Exponential(ExponentialFamily):
  8. r"""
  9. Creates a Exponential distribution parameterized by :attr:`rate`.
  10. Example::
  11. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  12. >>> m = Exponential(torch.tensor([1.0]))
  13. >>> m.sample() # Exponential distributed with rate=1
  14. tensor([ 0.1046])
  15. Args:
  16. rate (float or Tensor): rate = 1 / scale of the distribution
  17. """
  18. arg_constraints = {'rate': constraints.positive}
  19. support = constraints.nonnegative
  20. has_rsample = True
  21. _mean_carrier_measure = 0
  22. @property
  23. def mean(self):
  24. return self.rate.reciprocal()
  25. @property
  26. def mode(self):
  27. return torch.zeros_like(self.rate)
  28. @property
  29. def stddev(self):
  30. return self.rate.reciprocal()
  31. @property
  32. def variance(self):
  33. return self.rate.pow(-2)
  34. def __init__(self, rate, validate_args=None):
  35. self.rate, = broadcast_all(rate)
  36. batch_shape = torch.Size() if isinstance(rate, Number) else self.rate.size()
  37. super().__init__(batch_shape, validate_args=validate_args)
  38. def expand(self, batch_shape, _instance=None):
  39. new = self._get_checked_instance(Exponential, _instance)
  40. batch_shape = torch.Size(batch_shape)
  41. new.rate = self.rate.expand(batch_shape)
  42. super(Exponential, new).__init__(batch_shape, validate_args=False)
  43. new._validate_args = self._validate_args
  44. return new
  45. def rsample(self, sample_shape=torch.Size()):
  46. shape = self._extended_shape(sample_shape)
  47. return self.rate.new(shape).exponential_() / self.rate
  48. def log_prob(self, value):
  49. if self._validate_args:
  50. self._validate_sample(value)
  51. return self.rate.log() - self.rate * value
  52. def cdf(self, value):
  53. if self._validate_args:
  54. self._validate_sample(value)
  55. return 1 - torch.exp(-self.rate * value)
  56. def icdf(self, value):
  57. return -torch.log1p(-value) / self.rate
  58. def entropy(self):
  59. return 1.0 - torch.log(self.rate)
  60. @property
  61. def _natural_params(self):
  62. return (-self.rate, )
  63. def _log_normalizer(self, x):
  64. return -torch.log(-x)