import math

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv
from torch.distributions.utils import _standard_normal, lazy_property

__all__ = ['LowRankMultivariateNormal']


def _batch_capacitance_tril(W, D):
    r"""
    Computes the Cholesky factor of :math:`I + W.T @ inv(D) @ W` for a batch of
    matrices :math:`W` and a batch of vectors :math:`D`.
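
    A sanity-check sketch (a hypothetical usage, not part of the original module;
    values are illustrative)::

        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> C = torch.eye(2) + W.mT @ (W / D.unsqueeze(-1))
        >>> torch.allclose(_batch_capacitance_tril(W, D), torch.linalg.cholesky(C))
        True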
- """
- m = W.size(-1)
- Wt_Dinv = W.mT / D.unsqueeze(-2)
- K = torch.matmul(Wt_Dinv, W).contiguous()
- K.view(-1, m * m)[:, ::m + 1] += 1 # add identity matrix to K
- return torch.linalg.cholesky(K)


def _batch_lowrank_logdet(W, D, capacitance_tril):
    r"""
    Uses the "matrix determinant lemma"::

        log|W @ W.T + D| = log|C| + log|D|,

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the log determinant.
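
    A quick numerical check of the lemma (a sketch with illustrative values, not
    part of the original module)::

        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> tril = _batch_capacitance_tril(W, D)
        >>> full = W @ W.mT + torch.diag(D)
        >>> torch.allclose(_batch_lowrank_logdet(W, D, tril), torch.logdet(full))
        True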
- """
- return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum(-1)


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
    r"""
    Uses the "Woodbury matrix identity"::

        inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D),

    where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
    the squared Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`.
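
    A sanity-check sketch against the dense computation (illustrative values, not
    part of the original module)::

        >>> W = torch.randn(5, 2)
        >>> D = torch.rand(5) + 0.5
        >>> x = torch.randn(5)
        >>> tril = _batch_capacitance_tril(W, D)
        >>> dense = x @ torch.inverse(W @ W.mT + torch.diag(D)) @ x
        >>> torch.allclose(_batch_lowrank_mahalanobis(W, D, x, tril), dense, atol=1e-5)
        True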
- """
- Wt_Dinv = W.mT / D.unsqueeze(-2)
- Wt_Dinv_x = _batch_mv(Wt_Dinv, x)
- mahalanobis_term1 = (x.pow(2) / D).sum(-1)
- mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x)
- return mahalanobis_term1 - mahalanobis_term2


class LowRankMultivariateNormal(Distribution):
    r"""
    Creates a multivariate normal distribution with covariance matrix having a low-rank form
    parameterized by :attr:`cov_factor` and :attr:`cov_diag`::

        covariance_matrix = cov_factor @ cov_factor.T + cov_diag

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2))
        >>> m.sample()  # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]`
        tensor([-0.2102, -0.5429])

    Args:
        loc (Tensor): mean of the distribution with shape `batch_shape + event_shape`
        cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape + (rank,)`
        cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape
            `batch_shape + event_shape`

    Note:
        The computation for determinant and inverse of the covariance matrix is avoided when
        `cov_factor.shape[1] << cov_factor.shape[0]` thanks to the `Woodbury matrix identity
        <https://en.wikipedia.org/wiki/Woodbury_matrix_identity>`_ and the
        `matrix determinant lemma <https://en.wikipedia.org/wiki/Matrix_determinant_lemma>`_.
        Thanks to these formulas, we only need to compute the determinant and inverse of
        the small-size "capacitance" matrix::

            capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
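
    A batched usage sketch (shapes are illustrative, not from the example above)::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> m = LowRankMultivariateNormal(torch.zeros(3, 5), torch.randn(3, 5, 2), torch.ones(3, 5))
        >>> m.log_prob(m.sample()).shape  # one log-density per batch element
        torch.Size([3])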
- """
- arg_constraints = {"loc": constraints.real_vector,
- "cov_factor": constraints.independent(constraints.real, 2),
- "cov_diag": constraints.independent(constraints.positive, 1)}
- support = constraints.real_vector
- has_rsample = True

    def __init__(self, loc, cov_factor, cov_diag, validate_args=None):
        if loc.dim() < 1:
            raise ValueError("loc must be at least one-dimensional.")
        event_shape = loc.shape[-1:]
        if cov_factor.dim() < 2:
            raise ValueError("cov_factor must be at least two-dimensional, "
                             "with optional leading batch dimensions")
        if cov_factor.shape[-2:-1] != event_shape:
            raise ValueError("cov_factor must be a batch of matrices with shape {} x m"
                             .format(event_shape[0]))
        if cov_diag.shape[-1:] != event_shape:
            raise ValueError("cov_diag must be a batch of vectors with shape {}".format(event_shape))
        loc_ = loc.unsqueeze(-1)
        cov_diag_ = cov_diag.unsqueeze(-1)
        try:
            loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors(loc_, cov_factor, cov_diag_)
        except RuntimeError as e:
            raise ValueError("Incompatible batch shapes: loc {}, cov_factor {}, cov_diag {}"
                             .format(loc.shape, cov_factor.shape, cov_diag.shape)) from e
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape = self.loc.shape[:-1]

        self._unbroadcasted_cov_factor = cov_factor
        self._unbroadcasted_cov_diag = cov_diag
        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
        super().__init__(batch_shape, event_shape, validate_args=validate_args)

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(LowRankMultivariateNormal, _instance)
        batch_shape = torch.Size(batch_shape)
        loc_shape = batch_shape + self.event_shape
        new.loc = self.loc.expand(loc_shape)
        new.cov_diag = self.cov_diag.expand(loc_shape)
        new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
        new._capacitance_tril = self._capacitance_tril
        super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                       self.event_shape,
                                                       validate_args=False)
        new._validate_args = self._validate_args
        return new

    @property
    def mean(self):
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
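        # diag(W @ W.T + D) is the row-wise squared norm of W plus D, so the full
        # covariance matrix never needs to be materialized.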
        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def scale_tril(self):
        # The following identity is used to improve the numerical stability of the
        # Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3):
        #     W @ W.T + D = D^{1/2} @ (I + D^{-1/2} @ W @ W.T @ D^{-1/2}) @ D^{1/2}
        # The matrix "I + D^{-1/2} @ W @ W.T @ D^{-1/2}" has eigenvalues bounded from below
        # by 1, hence it is well-conditioned and safe to take the Cholesky decomposition of.
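        # If L = cholesky(I + D^{-1/2} @ W @ W.T @ D^{-1/2}), then D^{1/2} @ L is itself
        # lower-triangular with (D^{1/2} @ L) @ (D^{1/2} @ L).T = W @ W.T + D, so it is a
        # valid Cholesky factor of the full covariance.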
        n = self._event_shape[0]
        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
        K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous()
        K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
        scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K)
        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
                                          self._unbroadcasted_cov_factor.mT)
                             + torch.diag_embed(self._unbroadcasted_cov_diag))
        return covariance_matrix.expand(self._batch_shape + self._event_shape +
                                        self._event_shape)

    @lazy_property
    def precision_matrix(self):
        # We use the "Woodbury matrix identity" to take advantage of the low-rank form:
        #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
        # where C is the capacitance matrix.
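        # With C = L @ L.T (L = self._capacitance_tril) and A = inv(L) @ W.T @ inv(D),
        # the correction term inv(D) @ W @ inv(C) @ W.T @ inv(D) equals A.T @ A.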
        Wt_Dinv = (self._unbroadcasted_cov_factor.mT
                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
        A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False)
        precision_matrix = torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A
        return precision_matrix.expand(self._batch_shape + self._event_shape +
                                       self._event_shape)

    def rsample(self, sample_shape=torch.Size()):
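        # Reparameterized sampling: x = loc + W @ eps_W + sqrt(D) * eps_D has
        # Cov(x) = W @ W.T + D, and gradients flow through loc, W and D.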
        shape = self._extended_shape(sample_shape)
        W_shape = shape[:-1] + self.cov_factor.shape[-1:]
        eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
        eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
                + self._unbroadcasted_cov_diag.sqrt() * eps_D)

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
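        # log N(x; loc, Sigma) = -0.5 * (k * log(2*pi) + log|Sigma| + squared Mahalanobis),
        # with log|Sigma| and the Mahalanobis term both evaluated in low-rank form.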
        diff = value - self.loc
        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
                                       self._unbroadcasted_cov_diag,
                                       diff,
                                       self._capacitance_tril)
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)

    def entropy(self):
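        # Gaussian differential entropy: H = 0.5 * (k * (1 + log(2*pi)) + log|Sigma|).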
        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
                                        self._unbroadcasted_cov_diag,
                                        self._capacitance_tril)
        H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
        if len(self._batch_shape) == 0:
            return H
        else:
            return H.expand(self._batch_shape)