# wishart.py

import math
import warnings
from numbers import Number
from typing import Union

import torch
from torch import nan
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import lazy_property
from torch.distributions.multivariate_normal import _precision_to_scale_tril

__all__ = ['Wishart']

_log_2 = math.log(2)


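# The multivariate digamma function is the derivative of the log multivariate
# gamma function: psi_p(x) = sum_{i=0}^{p-1} psi(x - i/2). It appears in the
# entropy of the Wishart distribution below.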
def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
    return torch.digamma(
        x.unsqueeze(-1)
        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
    ).sum(-1)


def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
    # We assume positive input for this function
    return x.clamp(min=torch.finfo(x.dtype).eps)


class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`.

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal

    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        :class:`torch.distributions.LKJCholesky` is a restricted Wishart distribution. [1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    arg_constraints = {
        'covariance_matrix': constraints.positive_definite,
        'precision_matrix': constraints.positive_definite,
        'scale_tril': constraints.lower_cholesky,
        'df': constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(self,
                 df: Union[torch.Tensor, Number],
                 covariance_matrix: torch.Tensor = None,
                 precision_matrix: torch.Tensor = None,
                 scale_tril: torch.Tensor = None,
                 validate_args=None):
        assert (covariance_matrix is not None) + (scale_tril is not None) + (precision_matrix is not None) == 1, \
            "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        param = next(p for p in (covariance_matrix, precision_matrix, scale_tril) if p is not None)

        if param.dim() < 2:
            raise ValueError("scale_tril must be at least two-dimensional, with optional leading batch dimensions")

        if isinstance(df, Number):
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}.")

        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        self.arg_constraints['df'] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn("Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim.")

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling:
        # the i-th diagonal entry of the Bartlett factor is chi(df - i) distributed.
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)
        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if 'covariance_matrix' in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if 'scale_tril' in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if 'precision_matrix' in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape)

    @lazy_property
    def covariance_matrix(self):
        return (
            self._unbroadcasted_scale_tril @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(
            identity, self._unbroadcasted_scale_tril
        ).expand(self._batch_shape + self._event_shape)

    @property
    def mean(self):
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

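    # The mode (df - p - 1) * Sigma only exists for df > p + 1; batch entries
    # where df <= p + 1 are filled with NaN.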
    @property
    def mode(self):
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

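    # Var(X_ij) = df * (Sigma_ij^2 + Sigma_ii * Sigma_jj)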
    @property
    def variance(self):
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V))

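    # Bartlett decomposition: if A is lower triangular with A_ii^2 ~ Chi2(df - i)
    # (0-indexed) and A_ij ~ N(0, 1) for i > j, and Sigma = L L^T, then
    # (L A)(L A)^T is Wishart(df, Sigma) distributed.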
    def _bartlett_sampling(self, sample_shape=torch.Size()):
        p = self._event_shape[-1]  # has singleton shape

        # Implemented sampling using the Bartlett decomposition
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(self, sample_shape=torch.Size(), max_try_correction=None):
        r"""
        .. warning::
            In some cases, the sampling algorithm based on the Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may still end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust the `max_try_correction` argument of `.rsample` accordingly.
        """
        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporarily and should be removed in the future
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)
        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

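    # log p(X) = (df - p - 1)/2 * logdet(X) - tr(Sigma^{-1} X) / 2
    #            - df/2 * logdet(Sigma) - df * p/2 * log(2) - log Gamma_p(df/2)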
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            - nu * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril).diagonal(dim1=-2, dim2=-1).sum(dim=-1) / 2
        )

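    # H = (p + 1)/2 * logdet(Sigma) + p * (p + 1)/2 * log(2) + log Gamma_p(df/2)
    #     - (df - p - 1)/2 * psi_p(df/2) + df * p / 2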
    def entropy(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1) * (p * _log_2 / 2 + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1))
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

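    # Exponential-family form: natural parameters (-Sigma^{-1}/2, (df - p - 1)/2)
    # paired with sufficient statistics (X, logdet X).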
    @property
    def _natural_params(self):
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return - self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        p = self._event_shape[-1]
        return (
            (y + (p + 1) / 2) * (- torch.linalg.slogdet(- 2 * x).logabsdet + _log_2 * p)
            + torch.mvlgamma(y + (p + 1) / 2, p=p)
        )

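# A minimal usage sketch (the 3x3 scale matrix and df value are illustrative):
#
#     m = Wishart(df=torch.tensor(4.0), covariance_matrix=torch.eye(3))
#     x = m.rsample((10,))   # 10 draws, each a (3, 3) positive definite matrix
#     lp = m.log_prob(x)     # log-density of each draw, shape (10,)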