von_mises.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import math
  2. import torch
  3. import torch.jit
  4. from torch.distributions import constraints
  5. from torch.distributions.distribution import Distribution
  6. from torch.distributions.utils import broadcast_all, lazy_property
  7. __all__ = ['VonMises']
  8. def _eval_poly(y, coef):
  9. coef = list(coef)
  10. result = coef.pop()
  11. while coef:
  12. result = coef.pop() + y * result
  13. return result
  14. _I0_COEF_SMALL = [1.0, 3.5156229, 3.0899424, 1.2067492, 0.2659732, 0.360768e-1, 0.45813e-2]
  15. _I0_COEF_LARGE = [0.39894228, 0.1328592e-1, 0.225319e-2, -0.157565e-2, 0.916281e-2,
  16. -0.2057706e-1, 0.2635537e-1, -0.1647633e-1, 0.392377e-2]
  17. _I1_COEF_SMALL = [0.5, 0.87890594, 0.51498869, 0.15084934, 0.2658733e-1, 0.301532e-2, 0.32411e-3]
  18. _I1_COEF_LARGE = [0.39894228, -0.3988024e-1, -0.362018e-2, 0.163801e-2, -0.1031555e-1,
  19. 0.2282967e-1, -0.2895312e-1, 0.1787654e-1, -0.420059e-2]
  20. _COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
  21. _COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
  22. def _log_modified_bessel_fn(x, order=0):
  23. """
  24. Returns ``log(I_order(x))`` for ``x > 0``,
  25. where `order` is either 0 or 1.
  26. """
  27. assert order == 0 or order == 1
  28. # compute small solution
  29. y = (x / 3.75)
  30. y = y * y
  31. small = _eval_poly(y, _COEF_SMALL[order])
  32. if order == 1:
  33. small = x.abs() * small
  34. small = small.log()
  35. # compute large solution
  36. y = 3.75 / x
  37. large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log()
  38. result = torch.where(x < 3.75, small, large)
  39. return result
  40. @torch.jit.script_if_tracing
  41. def _rejection_sample(loc, concentration, proposal_r, x):
  42. done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
  43. while not done.all():
  44. u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
  45. u1, u2, u3 = u.unbind()
  46. z = torch.cos(math.pi * u1)
  47. f = (1 + proposal_r * z) / (proposal_r + z)
  48. c = concentration * (proposal_r - f)
  49. accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
  50. if accept.any():
  51. x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
  52. done = done | accept
  53. return (x + math.pi + loc) % (2 * math.pi) - math.pi
  54. class VonMises(Distribution):
  55. """
  56. A circular von Mises distribution.
  57. This implementation uses polar coordinates. The ``loc`` and ``value`` args
  58. can be any real number (to facilitate unconstrained optimization), but are
  59. interpreted as angles modulo 2 pi.
  60. Example::
  61. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  62. >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
  63. >>> m.sample() # von Mises distributed with loc=1 and concentration=1
  64. tensor([1.9777])
  65. :param torch.Tensor loc: an angle in radians.
  66. :param torch.Tensor concentration: concentration parameter
  67. """
  68. arg_constraints = {'loc': constraints.real, 'concentration': constraints.positive}
  69. support = constraints.real
  70. has_rsample = False
  71. def __init__(self, loc, concentration, validate_args=None):
  72. self.loc, self.concentration = broadcast_all(loc, concentration)
  73. batch_shape = self.loc.shape
  74. event_shape = torch.Size()
  75. # Parameters for sampling
  76. tau = 1 + (1 + 4 * self.concentration ** 2).sqrt()
  77. rho = (tau - (2 * tau).sqrt()) / (2 * self.concentration)
  78. self._proposal_r = (1 + rho ** 2) / (2 * rho)
  79. super().__init__(batch_shape, event_shape, validate_args)
  80. def log_prob(self, value):
  81. if self._validate_args:
  82. self._validate_sample(value)
  83. log_prob = self.concentration * torch.cos(value - self.loc)
  84. log_prob = log_prob - math.log(2 * math.pi) - _log_modified_bessel_fn(self.concentration, order=0)
  85. return log_prob
  86. @torch.no_grad()
  87. def sample(self, sample_shape=torch.Size()):
  88. """
  89. The sampling algorithm for the von Mises distribution is based on the following paper:
  90. Best, D. J., and Nicholas I. Fisher.
  91. "Efficient simulation of the von Mises distribution." Applied Statistics (1979): 152-157.
  92. """
  93. shape = self._extended_shape(sample_shape)
  94. x = torch.empty(shape, dtype=self.loc.dtype, device=self.loc.device)
  95. return _rejection_sample(self.loc, self.concentration, self._proposal_r, x)
  96. def expand(self, batch_shape):
  97. try:
  98. return super().expand(batch_shape)
  99. except NotImplementedError:
  100. validate_args = self.__dict__.get('_validate_args')
  101. loc = self.loc.expand(batch_shape)
  102. concentration = self.concentration.expand(batch_shape)
  103. return type(self)(loc, concentration, validate_args=validate_args)
  104. @property
  105. def mean(self):
  106. """
  107. The provided mean is the circular one.
  108. """
  109. return self.loc
  110. @property
  111. def mode(self):
  112. return self.loc
  113. @lazy_property
  114. def variance(self):
  115. """
  116. The provided variance is the circular one.
  117. """
  118. return 1 - (_log_modified_bessel_fn(self.concentration, order=1) -
  119. _log_modified_bessel_fn(self.concentration, order=0)).exp()