# fishersnedecor.py
from numbers import Number

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
from torch.distributions.utils import broadcast_all

# Public API of this module.
__all__ = ['FisherSnedecor']
  9. class FisherSnedecor(Distribution):
  10. r"""
  11. Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`.
  12. Example::
  13. >>> # xdoctest: +IGNORE_WANT("non-deterinistic")
  14. >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0]))
  15. >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2
  16. tensor([ 0.2453])
  17. Args:
  18. df1 (float or Tensor): degrees of freedom parameter 1
  19. df2 (float or Tensor): degrees of freedom parameter 2
  20. """
  21. arg_constraints = {'df1': constraints.positive, 'df2': constraints.positive}
  22. support = constraints.positive
  23. has_rsample = True
  24. def __init__(self, df1, df2, validate_args=None):
  25. self.df1, self.df2 = broadcast_all(df1, df2)
  26. self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
  27. self._gamma2 = Gamma(self.df2 * 0.5, self.df2)
  28. if isinstance(df1, Number) and isinstance(df2, Number):
  29. batch_shape = torch.Size()
  30. else:
  31. batch_shape = self.df1.size()
  32. super().__init__(batch_shape, validate_args=validate_args)
  33. def expand(self, batch_shape, _instance=None):
  34. new = self._get_checked_instance(FisherSnedecor, _instance)
  35. batch_shape = torch.Size(batch_shape)
  36. new.df1 = self.df1.expand(batch_shape)
  37. new.df2 = self.df2.expand(batch_shape)
  38. new._gamma1 = self._gamma1.expand(batch_shape)
  39. new._gamma2 = self._gamma2.expand(batch_shape)
  40. super(FisherSnedecor, new).__init__(batch_shape, validate_args=False)
  41. new._validate_args = self._validate_args
  42. return new
  43. @property
  44. def mean(self):
  45. df2 = self.df2.clone(memory_format=torch.contiguous_format)
  46. df2[df2 <= 2] = nan
  47. return df2 / (df2 - 2)
  48. @property
  49. def mode(self):
  50. mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2)
  51. mode[self.df1 <= 2] = nan
  52. return mode
  53. @property
  54. def variance(self):
  55. df2 = self.df2.clone(memory_format=torch.contiguous_format)
  56. df2[df2 <= 4] = nan
  57. return 2 * df2.pow(2) * (self.df1 + df2 - 2) / (self.df1 * (df2 - 2).pow(2) * (df2 - 4))
  58. def rsample(self, sample_shape=torch.Size(())):
  59. shape = self._extended_shape(sample_shape)
  60. # X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2)
  61. # Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2)
  62. X1 = self._gamma1.rsample(sample_shape).view(shape)
  63. X2 = self._gamma2.rsample(sample_shape).view(shape)
  64. tiny = torch.finfo(X2.dtype).tiny
  65. X2.clamp_(min=tiny)
  66. Y = X1 / X2
  67. Y.clamp_(min=tiny)
  68. return Y
  69. def log_prob(self, value):
  70. if self._validate_args:
  71. self._validate_sample(value)
  72. ct1 = self.df1 * 0.5
  73. ct2 = self.df2 * 0.5
  74. ct3 = self.df1 / self.df2
  75. t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
  76. t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
  77. t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
  78. return t1 + t2 - t3