  1. """Implement various linear algebra algorithms for low rank matrices.
  2. """
  3. __all__ = ["svd_lowrank", "pca_lowrank"]
  4. from typing import Optional, Tuple
  5. import torch
  6. from torch import Tensor
  7. from . import _linalg_utils as _utils
  8. from .overrides import handle_torch_function, has_torch_function


def get_approximate_basis(
    A: Tensor, q: int, niter: Optional[int] = 2, M: Optional[Tensor] = None
) -> Tensor:
    """Return tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`. If :math:`M` is
    specified, then :math:`Q` is such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`.

    .. note:: The implementation is based on Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a rank-:math:`k` matrix
              :math:`A`, where :math:`k` is not known in advance but
              could be estimated, the number of :math:`Q` columns,
              :math:`q`, can be chosen according to the following
              criteria: in general, :math:`k <= q <= min(2*k, m, n)`.
              For large low-rank matrices, take :math:`q = k + 5..10`.
              If :math:`k` is relatively small compared to
              :math:`min(m, n)`, choosing :math:`q = k + 0..2` may be
              sufficient.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of the subspace spanned by the
                 :math:`Q` columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, 1, n)`.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
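
    Example (a minimal sketch; the rank-5 test matrix below is
    illustrative only, and the output is not verified doctest output)::

        >>> A = torch.randn(100, 5) @ torch.randn(5, 40)  # rank <= 5
        >>> Q = get_approximate_basis(A, q=7)  # q slightly above the rank
        >>> Q.shape
        torch.Size([100, 7])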
  44. """
  45. niter = 2 if niter is None else niter
  46. m, n = A.shape[-2:]
  47. dtype = _utils.get_floating_dtype(A)
  48. matmul = _utils.matmul
  49. R = torch.randn(n, q, dtype=dtype, device=A.device)
  50. # The following code could be made faster using torch.geqrf + torch.ormqr
  51. # but geqrf is not differentiable
  52. A_H = _utils.transjugate(A)
  53. if M is None:
  54. Q = torch.linalg.qr(matmul(A, R)).Q
  55. for i in range(niter):
  56. Q = torch.linalg.qr(matmul(A_H, Q)).Q
  57. Q = torch.linalg.qr(matmul(A, Q)).Q
  58. else:
  59. M_H = _utils.transjugate(M)
  60. Q = torch.linalg.qr(matmul(A, R) - matmul(M, R)).Q
  61. for i in range(niter):
  62. Q = torch.linalg.qr(matmul(A_H, Q) - matmul(M_H, Q)).Q
  63. Q = torch.linalg.qr(matmul(A, Q) - matmul(M, Q)).Q
  64. return Q


def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U diag(S) V^T`. If :math:`M` is given, the SVD is
    computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    .. note:: The input is assumed to be a low-rank matrix.

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices, as it is
              typically about 10-fold faster. The low-rank SVD is
              useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of :math:`A`.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, 1, n)`.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
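
    Example (a small sketch; the rank-4 test matrix and the printed
    shapes are illustrative only)::

        >>> A = torch.randn(100, 4) @ torch.randn(4, 50)  # rank <= 4
        >>> U, S, V = torch.svd_lowrank(A, q=6)
        >>> U.shape, S.shape, V.shape
        (torch.Size([100, 6]), torch.Size([6]), torch.Size([50, 6]))
        >>> # reconstruct the low-rank approximation of A
        >>> A_approx = U @ torch.diag(S) @ V.mT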
  99. """
  100. if not torch.jit.is_scripting():
  101. tensor_ops = (A, M)
  102. if not set(map(type, tensor_ops)).issubset(
  103. (torch.Tensor, type(None))
  104. ) and has_torch_function(tensor_ops):
  105. return handle_torch_function(
  106. svd_lowrank, tensor_ops, A, q=q, niter=niter, M=M
  107. )
  108. return _svd_lowrank(A, q=q, niter=niter, M=M)


def _svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor, Tensor]:
    q = 6 if q is None else q
    m, n = A.shape[-2:]
    matmul = _utils.matmul
    if M is None:
        M_t = None
    else:
        M_t = _utils.transpose(M)
    A_t = _utils.transpose(A)

    # Algorithm 5.1 in Halko et al., 2009, slightly modified to reduce
    # the number of conjugate and transpose operations
    if m < n or n > q:
        # computing the SVD approximation of a transpose in
        # order to keep B shape minimal (the m < n case) or the V
        # shape small (the n > q case)
        Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
        Q_c = _utils.conjugate(Q)
        if M is None:
            B_t = matmul(A, Q_c)
        else:
            B_t = matmul(A, Q_c) - matmul(M, Q_c)
        assert B_t.shape[-2] == m, (B_t.shape, m)
        assert B_t.shape[-1] == q, (B_t.shape, q)
        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
        # B_t is m x q, so its full SVD is cheap when q is small
        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
        V = Vh.mH
        V = Q.matmul(V)
    else:
        Q = get_approximate_basis(A, q, niter=niter, M=M)
        Q_c = _utils.conjugate(Q)
        if M is None:
            B = matmul(A_t, Q_c)
        else:
            B = matmul(A_t, Q_c) - matmul(M_t, Q_c)
        B_t = _utils.transpose(B)
        assert B_t.shape[-2] == q, (B_t.shape, q)
        assert B_t.shape[-1] == n, (B_t.shape, n)
        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
        U, S, Vh = torch.linalg.svd(B_t, full_matrices=False)
        V = Vh.mH
        U = Q.matmul(U)

    return U, S, V


def pca_lowrank(
    A: Tensor, q: Optional[int] = None, center: bool = True, niter: int = 2
) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or a sparse matrix.

    This function returns a namedtuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A = U diag(S) V^T`.

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

              - :math:`A` is a data matrix with ``m`` samples and
                ``n`` features

              - the :math:`V` columns represent the principal directions

              - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                :math:`A^T A / (m - 1)`, which is the covariance of
                ``A`` when ``center=True`` is provided.

              - ``matmul(A, V[:, :k])`` projects data to the first k
                principal components

    .. note:: Different from the standard SVD, the sizes of the
              returned matrices depend on the specified rank and q
              values as follows:

              - :math:`U` is an m x q matrix

              - :math:`S` is a q-vector

              - :math:`V` is an n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args:
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default,
                           ``q = min(6, m, n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
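
    Example (a brief sketch; the data and printed shapes are
    illustrative, and the default ``q = min(6, m, n)`` is assumed)::

        >>> A = torch.randn(1000, 20)  # 1000 samples, 20 features
        >>> U, S, V = torch.pca_lowrank(A)
        >>> U.shape, S.shape, V.shape
        (torch.Size([1000, 6]), torch.Size([6]), torch.Size([20, 6]))
        >>> # eigenvalues of the covariance matrix, per the note above
        >>> explained_variance = S**2 / (A.shape[-2] - 1)
        >>> # project the data onto the first two principal components
        >>> A_pca = A.matmul(V[:, :2])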
  198. """
  199. if not torch.jit.is_scripting():
  200. if type(A) is not torch.Tensor and has_torch_function((A,)):
  201. return handle_torch_function(
  202. pca_lowrank, (A,), A, q=q, center=center, niter=niter
  203. )
  204. (m, n) = A.shape[-2:]
  205. if q is None:
  206. q = min(6, m, n)
  207. elif not (q >= 0 and q <= min(m, n)):
  208. raise ValueError(
  209. "q(={}) must be non-negative integer"
  210. " and not greater than min(m, n)={}".format(q, min(m, n))
  211. )
  212. if not (niter >= 0):
  213. raise ValueError("niter(={}) must be non-negative integer".format(niter))
  214. dtype = _utils.get_floating_dtype(A)
  215. if not center:
  216. return _svd_lowrank(A, q, niter=niter, M=None)
  217. if _utils.is_sparse(A):
  218. if len(A.shape) != 2:
  219. raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
  220. c = torch.sparse.sum(A, dim=(-2,)) / m
  221. # reshape c
  222. column_indices = c.indices()[0]
  223. indices = torch.zeros(
  224. 2,
  225. len(column_indices),
  226. dtype=column_indices.dtype,
  227. device=column_indices.device,
  228. )
  229. indices[0] = column_indices
  230. C_t = torch.sparse_coo_tensor(
  231. indices, c.values(), (n, 1), dtype=dtype, device=A.device
  232. )
  233. ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
  234. M = _utils.transpose(torch.sparse.mm(C_t, ones_m1_t))
  235. return _svd_lowrank(A, q, niter=niter, M=M)
  236. else:
  237. C = A.mean(dim=(-2,), keepdim=True)
  238. return _svd_lowrank(A - C, q, niter=niter, M=None)