import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ['LowRankMultivariateNormal']


def _batch_capacitance_tril(W, D):
    r"""
    Computes the Cholesky factor of :math:`I + W.T @ inv(D) @ W` for a batch of
    matrices :math:`W` and a batch of vectors :math:`D`.
    """
    m = W.size(-1)
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    K = torch.matmul(Wt_Dinv, W).contiguous()
    K.view(-1, m * m)[:, ::m + 1] += 1  # add identity matrix to K
    return torch.linalg.cholesky(K)
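

# A minimal sanity-check sketch (added for illustration; `_demo_capacitance_tril`
# is a hypothetical name, not part of the original module): for a small unbatched
# case, the Cholesky factor returned above should reproduce the dense capacitance
# matrix I + W.T @ inv(D) @ W.
def _demo_capacitance_tril():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5  # keep the diagonal strictly positive
    L = _batch_capacitance_tril(W, D)
    K_dense = torch.eye(2) + W.mT @ torch.diag(D.reciprocal()) @ W
    torch.testing.assert_close(L @ L.mT, K_dense)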


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses the "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
    """
    return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)
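

# A hedged check sketch (illustrative only; `_demo_lowrank_logdet` is a
# hypothetical helper): the lemma-based log-determinant should agree with
# `torch.logdet` applied to the dense covariance W @ W.T + diag(D).
def _demo_lowrank_logdet():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    tril = _batch_capacitance_tril(W, D)
    lemma = _batch_lowrank_logdet(W, D, tril)
    dense = torch.logdet(W @ W.mT + torch.diag(D))
    torch.testing.assert_close(lemma, dense)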


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
    """
    Wt_Dinv = W.mT / D.unsqueeze(-2)
    Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
    mahalanobis_term1 = (x.pow(2) / D).sum(-1)
    mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
    return mahalanobis_term1 - mahalanobis_term2
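

# A hedged check sketch (illustrative only; `_demo_lowrank_mahalanobis` is a
# hypothetical helper): the Woodbury-based squared Mahalanobis distance should
# agree with a dense solve against the full covariance matrix.
def _demo_lowrank_mahalanobis():
    W = torch.randn(5, 2)
    D = torch.rand(5) + 0.5
    x = torch.randn(5)
    tril = _batch_capacitance_tril(W, D)
    woodbury = _batch_lowrank_mahalanobis(W, D, x, tril)
    cov = W @ W.mT + torch.diag(D)
    dense = x @ torch.linalg.solve(cov, x)
    torch.testing.assert_close(woodbury, dense)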


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank
    form parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        Computing the determinant and inverse of the covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to the `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we only need to compute the determinant and inverse of
        the small-size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
    """

    arg_constraints = {"loc": constraints.real_vector,
                       "cov_factor": constraints.independent(constraints.real, 2),
                       "cov_diag": constraints.independent(constraints.positive, 1)}
    support = constraints.real_vector
    has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError("cov_diag must be a batch of vectors with shape {}"
                             .format(event_shape))

        # Broadcast loc and cov_diag against cov_factor by padding a trailing
        # singleton dimension, then strip it again after broadcasting.
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
        except RuntimeError as e:
            raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                             .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to increase the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix I + D^{-1/2} @ W @ W.T @ D^{-1/2} has eigenvalues bounded from below
        # by 1, hence it is well-conditioned and safe to take the Cholesky decomposition of.
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                          self._unbroadcasted_cov_factor.mT)
                             + torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # Uses the "Woodbury matrix identity" to take advantage of the low-rank form::
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where :math:`C` is the capacitance matrix.
        Wt_Dinv = (self._unbroadcasted_cov_factor.mT
                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        # Reparameterized sample: loc + W @ eps_W + sqrt(D) * eps_D has covariance
        # W @ W.T + diag(D), matching the low-rank parameterization.
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
                + self._unbroadcasted_cov_diag.sqrt() * eps_D)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)
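

# A hedged end-to-end usage sketch (illustrative only; `_demo_low_rank_mvn` is a
# hypothetical helper, not part of the original module): builds a rank-1 instance,
# draws a reparameterized sample, and cross-checks log_prob, entropy, scale_tril,
# and precision_matrix against the dense MultivariateNormal on the same covariance.
def _demo_low_rank_mvn():
    from torch.distributions.multivariate_normal import MultivariateNormal

    loc = torch.zeros(3)
    cov_factor = torch.randn(3, 1)
    cov_diag = torch.rand(3) + 0.5
    low_rank = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    dense = MultivariateNormal(loc, covariance_matrix=low_rank.covariance_matrix)

    x = low_rank.rsample()
    torch.testing.assert_close(low_rank.log_prob(x), dense.log_prob(x))
    torch.testing.assert_close(low_rank.entropy(), dense.entropy())
    torch.testing.assert_close(low_rank.scale_tril @ low_rank.scale_tril.mT,
                               low_rank.covariance_matrix)
    torch.testing.assert_close(low_rank.precision_matrix,
                               torch.linalg.inv(low_rank.covariance_matrix))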