# cauchy.py (2.8 KB)
import math
from numbers import Number

import torch
from torch import inf, nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all

__all__ = ['Cauchy']
  9. class Cauchy(Distribution):
  10. r"""
  11. Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of
  12. independent normally distributed random variables with means `0` follows a
  13. Cauchy distribution.
  14. Example::
  15. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  16. >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0]))
  17. >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1
  18. tensor([ 2.3214])
  19. Args:
  20. loc (float or Tensor): mode or median of the distribution.
  21. scale (float or Tensor): half width at half maximum.
  22. """
  23. arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
  24. support = constraints.real
  25. has_rsample = True
  26. def __init__(self, loc, scale, validate_args=None):
  27. self.loc, self.scale = broadcast_all(loc, scale)
  28. if isinstance(loc, Number) and isinstance(scale, Number):
  29. batch_shape = torch.Size()
  30. else:
  31. batch_shape = self.loc.size()
  32. super().__init__(batch_shape, validate_args=validate_args)
  33. def expand(self, batch_shape, _instance=None):
  34. new = self._get_checked_instance(Cauchy, _instance)
  35. batch_shape = torch.Size(batch_shape)
  36. new.loc = self.loc.expand(batch_shape)
  37. new.scale = self.scale.expand(batch_shape)
  38. super(Cauchy, new).__init__(batch_shape, validate_args=False)
  39. new._validate_args = self._validate_args
  40. return new
  41. @property
  42. def mean(self):
  43. return torch.full(self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device)
  44. @property
  45. def mode(self):
  46. return self.loc
  47. @property
  48. def variance(self):
  49. return torch.full(self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device)
  50. def rsample(self, sample_shape=torch.Size()):
  51. shape = self._extended_shape(sample_shape)
  52. eps = self.loc.new(shape).cauchy_()
  53. return self.loc + eps * self.scale
  54. def log_prob(self, value):
  55. if self._validate_args:
  56. self._validate_sample(value)
  57. return -math.log(math.pi) - self.scale.log() - (((value - self.loc) / self.scale)**2).log1p()
  58. def cdf(self, value):
  59. if self._validate_args:
  60. self._validate_sample(value)
  61. return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5
  62. def icdf(self, value):
  63. return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc
  64. def entropy(self):
  65. return math.log(4 * math.pi) + self.scale.log()