import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ['MultivariateNormal']


def _batch_mv(bmat, bvec):
  8. r"""
  9. Performs a batched matrix-vector product, with compatible but different batch shapes.
  10. This function takes as input `bmat`, containing :math:`n \times n` matrices, and
  11. `bvec`, containing length :math:`n` vectors.
  12. Both `bmat` and `bvec` may have any number of leading dimensions, which correspond
  13. to a batch shape. They are not necessarily assumed to have the same batch shape,
  14. just ones which can be broadcasted.
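
    Example (a minimal sketch; the shapes are chosen for illustration only):

        >>> bmat = torch.randn(3, 4, 4)   # batch shape (3,)
        >>> bvec = torch.randn(2, 3, 4)   # batch shape (2, 3)
        >>> _batch_mv(bmat, bvec).shape   # batch shapes broadcast to (2, 3)
        torch.Size([2, 3, 4])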
  15. """
  16. return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
  17. def _batch_mahalanobis(bL, bx):
  18. r"""
  19. Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}`
  20. for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`.

    Accepts batches for both `bL` and `bx`. They are not necessarily assumed to have
    the same batch shape, but the batch shape of `bL` must be broadcastable to the
    batch shape of `bx`.
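
    Example (a minimal sketch; the numbers are illustrative, and the check against
    the direct quadratic form assumes `L` is the Cholesky factor of :math:`\mathbf{M}`):

        >>> L = torch.linalg.cholesky(torch.tensor([[2.0, 0.3], [0.3, 1.0]]))
        >>> x = torch.randn(5, 2)
        >>> M = _batch_mahalanobis(L, x)
        >>> direct = (x * (x @ torch.cholesky_inverse(L))).sum(-1)
        >>> torch.allclose(M, direct, atol=1e-5)
        True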
  23. """
  24. n = bx.size(-1)
  25. bx_batch_shape = bx.shape[:-1]
  26. # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
  27. # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve
  28. bx_batch_dims = len(bx_batch_shape)
  29. bL_batch_dims = bL.dim() - 2
  30. outer_batch_dims = bx_batch_dims - bL_batch_dims
  31. old_batch_dims = outer_batch_dims + bL_batch_dims
  32. new_batch_dims = outer_batch_dims + 2 * bL_batch_dims
  33. # Reshape bx with the shape (..., 1, i, j, 1, n)
  34. bx_new_shape = bx.shape[:outer_batch_dims]
  35. for (sL, sx) in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]):
  36. bx_new_shape += (sx // sL, sL)
  37. bx_new_shape += (n,)
  38. bx = bx.reshape(bx_new_shape)
  39. # Permute bx to make it have shape (..., 1, j, i, 1, n)
  40. permute_dims = (list(range(outer_batch_dims)) +
  41. list(range(outer_batch_dims, new_batch_dims, 2)) +
  42. list(range(outer_batch_dims + 1, new_batch_dims, 2)) +
  43. [new_batch_dims])
  44. bx = bx.permute(permute_dims)
  45. flat_L = bL.reshape(-1, n, n) # shape = b x n x n
  46. flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
  47. flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
  48. M_swap = torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) # shape = b x c
  49. M = M_swap.t() # shape = c x b
  50. # Now we revert the above reshape and permute operators.
  51. permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1)
  52. permute_inv_dims = list(range(outer_batch_dims))
  53. for i in range(bL_batch_dims):
  54. permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i]
  55. reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1)
  56. return reshaped_M.reshape(bx_batch_shape)
  57. def _precision_to_scale_tril(P):
  58. # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
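    # The flip trick: with J the exchange (reversal) matrix, flip(P) = J P J.
    # If Lf is the Cholesky factor of flip(P), then U = flip(Lf) is *upper*
    # triangular with P = U U^T, so L_inv = U^T is lower triangular with
    # P = L_inv^T L_inv. Inverting L_inv yields a lower-triangular L with
    # L L^T = P^{-1}, i.e. a valid scale_tril for the covariance.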
    Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1)))
    L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1)
    Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device)
    L = torch.linalg.solve_triangular(L_inv, Id, upper=False)
    return L


class MultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal (also called Gaussian) distribution
    parameterized by a mean vector and a covariance matrix.

    The multivariate normal distribution can be parameterized either
    in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}`
    or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}`
    or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued
    diagonal entries, such that
    :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix
    can be obtained via e.g. Cholesky decomposition of the covariance.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]` and covariance_matrix=`I`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.

        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
    """
    arg_constraints = {'loc': constraints.real_vector,
                       'covariance_matrix': constraints.positive_definite,
                       'precision_matrix': constraints.positive_definite,
                       'scale_tril': constraints.lower_cholesky}
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, covariance_matrix=None, precision_matrix=None, scale_tril=None, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        if (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) != 1:
            raise ValueError("Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified.")

        if scale_tril is not None:
            if scale_tril.dim() < 2:
                raise ValueError("scale_tril matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
            self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            if covariance_matrix.dim() < 2:
                raise ValueError("covariance_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            batch_shape = torch.broadcast_shapes(covariance_matrix.shape[:-2], loc.shape[:-1])
            self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
        else:
            if precision_matrix.dim() < 2:
                raise ValueError("precision_matrix must be at least two-dimensional, "
                                 "with optional leading batch dimensions")
            batch_shape = torch.broadcast_shapes(precision_matrix.shape[:-2], loc.shape[:-1])
            self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
        self.loc = loc.expand(batch_shape + (-1,))

        event_shape = self.loc.shape[-1:]
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)
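
    # All three parameterizations construct the same distribution, up to
    # numerical precision; a minimal sketch (values illustrative only):
    #
    #     cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
    #     MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
    #     MultivariateNormal(torch.zeros(2), precision_matrix=torch.inverse(cov))
    #     MultivariateNormal(torch.zeros(2), scale_tril=torch.linalg.cholesky(cov))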

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(MultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        cov_shape = batch_shape + self.event_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)
        super(MultivariateNormal, new).__init__(batch_shape,
                                                self.event_shape,
                                                validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (torch.matmul(self._unbroadcasted_scale_tril,
                             self._unbroadcasted_scale_tril.mT)
                .expand(self._batch_shape + self._event_shape + self._event_shape))

    @lazy_property
    def precision_matrix(self):
        return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape + self._event_shape)

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @property
    def variance(self):
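        # diag(Sigma) = diag(L L^T) equals the row-wise sum of squares of L.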
        return self._unbroadcasted_scale_tril.pow(2).sum(-1).expand(
            self._batch_shape + self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps)
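
    # rsample() is reparameterized: it draws eps ~ N(0, I) and returns
    # loc + L @ eps, so gradients flow back into loc and the scale factor.
    # A minimal sketch (names illustrative):
    #
    #     loc = torch.zeros(2, requires_grad=True)
    #     MultivariateNormal(loc, torch.eye(2)).rsample().sum().backward()
    #     assert loc.grad is not None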

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
        half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
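        # log N(x; loc, Sigma) = -0.5 * (k * log(2*pi) + (x - loc)^T Sigma^{-1} (x - loc))
        # - 0.5 * log|Sigma|, where k is the event size; with Sigma = L L^T,
        # the 0.5 * log|Sigma| term is exactly half_log_det = sum(log(diag(L))).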
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det

    def entropy(self):
        half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
        H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
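

if __name__ == "__main__":
    # A minimal smoke check (illustrative values only): equivalent covariance
    # and scale_tril parameterizations should agree on sample shapes and
    # log-densities.
    cov = torch.tensor([[2.0, 0.3], [0.3, 1.0]])
    d_cov = MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
    d_tril = MultivariateNormal(torch.zeros(2), scale_tril=torch.linalg.cholesky(cov))
    x = d_cov.rsample(torch.Size((5,)))
    print(x.shape)  # torch.Size([5, 2])
    print(torch.allclose(d_cov.log_prob(x), d_tril.log_prob(x)))  # True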