gumbel.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from numbers import Number
  2. import math
  3. import torch
  4. from torch.distributions import constraints
  5. from torch.distributions.uniform import Uniform
  6. from torch.distributions.transformed_distribution import TransformedDistribution
  7. from torch.distributions.transforms import AffineTransform, ExpTransform
  8. from torch.distributions.utils import broadcast_all, euler_constant
  9. __all__ = ['Gumbel']
  10. class Gumbel(TransformedDistribution):
  11. r"""
  12. Samples from a Gumbel Distribution.
  13. Examples::
  14. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  15. >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0]))
  16. >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2
  17. tensor([ 1.0124])
  18. Args:
  19. loc (float or Tensor): Location parameter of the distribution
  20. scale (float or Tensor): Scale parameter of the distribution
  21. """
  22. arg_constraints = {'loc': constraints.real, 'scale': constraints.positive}
  23. support = constraints.real
  24. def __init__(self, loc, scale, validate_args=None):
  25. self.loc, self.scale = broadcast_all(loc, scale)
  26. finfo = torch.finfo(self.loc.dtype)
  27. if isinstance(loc, Number) and isinstance(scale, Number):
  28. base_dist = Uniform(finfo.tiny, 1 - finfo.eps)
  29. else:
  30. base_dist = Uniform(torch.full_like(self.loc, finfo.tiny),
  31. torch.full_like(self.loc, 1 - finfo.eps))
  32. transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-torch.ones_like(self.scale)),
  33. ExpTransform().inv, AffineTransform(loc=loc, scale=-self.scale)]
  34. super().__init__(base_dist, transforms, validate_args=validate_args)
  35. def expand(self, batch_shape, _instance=None):
  36. new = self._get_checked_instance(Gumbel, _instance)
  37. new.loc = self.loc.expand(batch_shape)
  38. new.scale = self.scale.expand(batch_shape)
  39. return super().expand(batch_shape, _instance=new)
  40. # Explicitly defining the log probability function for Gumbel due to precision issues
  41. def log_prob(self, value):
  42. if self._validate_args:
  43. self._validate_sample(value)
  44. y = (self.loc - value) / self.scale
  45. return (y - y.exp()) - self.scale.log()
  46. @property
  47. def mean(self):
  48. return self.loc + self.scale * euler_constant
  49. @property
  50. def mode(self):
  51. return self.loc
  52. @property
  53. def stddev(self):
  54. return (math.pi / math.sqrt(6)) * self.scale
  55. @property
  56. def variance(self):
  57. return self.stddev.pow(2)
  58. def entropy(self):
  59. return self.scale.log() + (1 + euler_constant)