# independent.py
  1. import torch
  2. from torch.distributions import constraints
  3. from torch.distributions.distribution import Distribution
  4. from torch.distributions.utils import _sum_rightmost
  5. from typing import Dict
  6. __all__ = ['Independent']
  7. class Independent(Distribution):
  8. r"""
  9. Reinterprets some of the batch dims of a distribution as event dims.
  10. This is mainly useful for changing the shape of the result of
  11. :meth:`log_prob`. For example to create a diagonal Normal distribution with
  12. the same shape as a Multivariate Normal distribution (so they are
  13. interchangeable), you can::
  14. >>> from torch.distributions.multivariate_normal import MultivariateNormal
  15. >>> from torch.distributions.normal import Normal
  16. >>> loc = torch.zeros(3)
  17. >>> scale = torch.ones(3)
  18. >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale))
  19. >>> [mvn.batch_shape, mvn.event_shape]
  20. [torch.Size([]), torch.Size([3])]
  21. >>> normal = Normal(loc, scale)
  22. >>> [normal.batch_shape, normal.event_shape]
  23. [torch.Size([3]), torch.Size([])]
  24. >>> diagn = Independent(normal, 1)
  25. >>> [diagn.batch_shape, diagn.event_shape]
  26. [torch.Size([]), torch.Size([3])]
  27. Args:
  28. base_distribution (torch.distributions.distribution.Distribution): a
  29. base distribution
  30. reinterpreted_batch_ndims (int): the number of batch dims to
  31. reinterpret as event dims
  32. """
  33. arg_constraints: Dict[str, constraints.Constraint] = {}
  34. def __init__(self, base_distribution, reinterpreted_batch_ndims, validate_args=None):
  35. if reinterpreted_batch_ndims > len(base_distribution.batch_shape):
  36. raise ValueError("Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
  37. "actual {} vs {}".format(reinterpreted_batch_ndims,
  38. len(base_distribution.batch_shape)))
  39. shape = base_distribution.batch_shape + base_distribution.event_shape
  40. event_dim = reinterpreted_batch_ndims + len(base_distribution.event_shape)
  41. batch_shape = shape[:len(shape) - event_dim]
  42. event_shape = shape[len(shape) - event_dim:]
  43. self.base_dist = base_distribution
  44. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  45. super().__init__(batch_shape, event_shape, validate_args=validate_args)
  46. def expand(self, batch_shape, _instance=None):
  47. new = self._get_checked_instance(Independent, _instance)
  48. batch_shape = torch.Size(batch_shape)
  49. new.base_dist = self.base_dist.expand(batch_shape +
  50. self.event_shape[:self.reinterpreted_batch_ndims])
  51. new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims
  52. super(Independent, new).__init__(batch_shape, self.event_shape, validate_args=False)
  53. new._validate_args = self._validate_args
  54. return new
  55. @property
  56. def has_rsample(self):
  57. return self.base_dist.has_rsample
  58. @property
  59. def has_enumerate_support(self):
  60. if self.reinterpreted_batch_ndims > 0:
  61. return False
  62. return self.base_dist.has_enumerate_support
  63. @constraints.dependent_property
  64. def support(self):
  65. result = self.base_dist.support
  66. if self.reinterpreted_batch_ndims:
  67. result = constraints.independent(result, self.reinterpreted_batch_ndims)
  68. return result
  69. @property
  70. def mean(self):
  71. return self.base_dist.mean
  72. @property
  73. def mode(self):
  74. return self.base_dist.mode
  75. @property
  76. def variance(self):
  77. return self.base_dist.variance
  78. def sample(self, sample_shape=torch.Size()):
  79. return self.base_dist.sample(sample_shape)
  80. def rsample(self, sample_shape=torch.Size()):
  81. return self.base_dist.rsample(sample_shape)
  82. def log_prob(self, value):
  83. log_prob = self.base_dist.log_prob(value)
  84. return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
  85. def entropy(self):
  86. entropy = self.base_dist.entropy()
  87. return _sum_rightmost(entropy, self.reinterpreted_batch_ndims)
  88. def enumerate_support(self, expand=True):
  89. if self.reinterpreted_batch_ndims > 0:
  90. raise NotImplementedError("Enumeration over cartesian product is not implemented")
  91. return self.base_dist.enumerate_support(expand=expand)
  92. def __repr__(self):
  93. return self.__class__.__name__ + '({}, {})'.format(self.base_dist, self.reinterpreted_batch_ndims)