_lobpcg.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175
  1. """Locally Optimal Block Preconditioned Conjugate Gradient methods.
  2. """
  3. # Author: Pearu Peterson
  4. # Created: February 2020
  5. from typing import Dict, Optional, Tuple
  6. import torch
  7. from torch import Tensor
  8. from . import _linalg_utils as _utils
  9. from .overrides import handle_torch_function, has_torch_function
# Public API of this module: only the `lobpcg` front-end is exported by
# `from ... import *`; everything else here is an implementation detail.
__all__ = ["lobpcg"]
  11. def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
  12. # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
  13. F = D.unsqueeze(-2) - D.unsqueeze(-1)
  14. F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
  15. F.pow_(-1)
  16. # A.grad = U (D.grad + (U^T U.grad * F)) U^T
  17. Ut = U.mT.contiguous()
  18. res = torch.matmul(
  19. U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
  20. )
  21. return res
  22. def _polynomial_coefficients_given_roots(roots):
  23. """
  24. Given the `roots` of a polynomial, find the polynomial's coefficients.
  25. If roots = (r_1, ..., r_n), then the method returns
  26. coefficients (a_0, a_1, ..., a_n (== 1)) so that
  27. p(x) = (x - r_1) * ... * (x - r_n)
  28. = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
  29. Note: for better performance requires writing a low-level kernel
  30. """
  31. poly_order = roots.shape[-1]
  32. poly_coeffs_shape = list(roots.shape)
  33. # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
  34. # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
  35. # but we insert one extra coefficient to enable better vectorization below
  36. poly_coeffs_shape[-1] += 2
  37. poly_coeffs = roots.new_zeros(poly_coeffs_shape)
  38. poly_coeffs[..., 0] = 1
  39. poly_coeffs[..., -1] = 1
  40. # perform the Horner's rule
  41. for i in range(1, poly_order + 1):
  42. # note that it is computationally hard to compute backward for this method,
  43. # because then given the coefficients it would require finding the roots and/or
  44. # calculating the sensitivity based on the Vieta's theorem.
  45. # So the code below tries to circumvent the explicit root finding by series
  46. # of operations on memory copies imitating the Horner's method.
  47. # The memory copies are required to construct nodes in the computational graph
  48. # by exploting the explicit (not in-place, separate node for each step)
  49. # recursion of the Horner's method.
  50. # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
  51. poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
  52. out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
  53. out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
  54. -1, poly_order - i + 1, i + 1
  55. )
  56. poly_coeffs = poly_coeffs_new
  57. return poly_coeffs.narrow(-1, 1, poly_order + 1)
  58. def _polynomial_value(poly, x, zero_power, transition):
  59. """
  60. A generic method for computing poly(x) using the Horner's rule.
  61. Args:
  62. poly (Tensor): the (possibly batched) 1D Tensor representing
  63. polynomial coefficients such that
  64. poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
  65. poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
  66. x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
  67. zero_power (Tensor): the representation of `x^0`. It is application-specific.
  68. transition (Callable): the function that accepts some intermediate result `int_val`,
  69. the `x` and a specific polynomial coefficient
  70. `poly[..., k]` for some iteration `k`.
  71. It basically performs one iteration of the Horner's rule
  72. defined as `x * int_val + poly[..., k] * zero_power`.
  73. Note that `zero_power` is not a parameter,
  74. because the step `+ poly[..., k] * zero_power` depends on `x`,
  75. whether it is a vector, a matrix, or something else, so this
  76. functionality is delegated to the user.
  77. """
  78. res = zero_power.clone()
  79. for k in range(poly.size(-1) - 2, -1, -1):
  80. res = transition(res, x, poly[..., k])
  81. return res
  82. def _matrix_polynomial_value(poly, x, zero_power=None):
  83. """
  84. Evaluates `poly(x)` for the (batched) matrix input `x`.
  85. Check out `_polynomial_value` function for more details.
  86. """
  87. # matrix-aware Horner's rule iteration
  88. def transition(curr_poly_val, x, poly_coeff):
  89. res = x.matmul(curr_poly_val)
  90. res.diagonal(dim1=-2, dim2=-1).add_(poly_coeff.unsqueeze(-1))
  91. return res
  92. if zero_power is None:
  93. zero_power = torch.eye(
  94. x.size(-1), x.size(-1), dtype=x.dtype, device=x.device
  95. ).view(*([1] * len(list(x.shape[:-2]))), x.size(-1), x.size(-1))
  96. return _polynomial_value(poly, x, zero_power, transition)
  97. def _vector_polynomial_value(poly, x, zero_power=None):
  98. """
  99. Evaluates `poly(x)` for the (batched) vector input `x`.
  100. Check out `_polynomial_value` function for more details.
  101. """
  102. # vector-aware Horner's rule iteration
  103. def transition(curr_poly_val, x, poly_coeff):
  104. res = torch.addcmul(poly_coeff.unsqueeze(-1), x, curr_poly_val)
  105. return res
  106. if zero_power is None:
  107. zero_power = x.new_ones(1).expand(x.shape)
  108. return _polynomial_value(poly, x, zero_power, transition)
  109. def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
  110. # compute a projection operator onto an orthogonal subspace spanned by the
  111. # columns of U defined as (I - UU^T)
  112. Ut = U.mT.contiguous()
  113. proj_U_ortho = -U.matmul(Ut)
  114. proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)
  115. # compute U_ortho, a basis for the orthogonal complement to the span(U),
  116. # by projecting a random [..., m, m - k] matrix onto the subspace spanned
  117. # by the columns of U.
  118. #
  119. # fix generator for determinism
  120. gen = torch.Generator(A.device)
  121. # orthogonal complement to the span(U)
  122. U_ortho = proj_U_ortho.matmul(
  123. torch.randn(
  124. (*A.shape[:-1], A.size(-1) - D.size(-1)),
  125. dtype=A.dtype,
  126. device=A.device,
  127. generator=gen,
  128. )
  129. )
  130. U_ortho_t = U_ortho.mT.contiguous()
  131. # compute the coefficients of the characteristic polynomial of the tensor D.
  132. # Note that D is diagonal, so the diagonal elements are exactly the roots
  133. # of the characteristic polynomial.
  134. chr_poly_D = _polynomial_coefficients_given_roots(D)
  135. # the code belows finds the explicit solution to the Sylvester equation
  136. # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
  137. # and incorporates it into the whole gradient stored in the `res` variable.
  138. #
  139. # Equivalent to the following naive implementation:
  140. # res = A.new_zeros(A.shape)
  141. # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
  142. # for k in range(1, chr_poly_D.size(-1)):
  143. # p_res.zero_()
  144. # for i in range(0, k):
  145. # p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
  146. # res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @ p_res @ U.t())
  147. #
  148. # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
  149. # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
  150. # and we need to compute g(U_grad, A, U, D)
  151. #
  152. # The naive implementation is based on the paper
  153. # Hu, Qingxi, and Daizhan Cheng.
  154. # "The polynomial solution to the Sylvester matrix equation."
  155. # Applied mathematics letters 19.9 (2006): 859-864.
  156. #
  157. # We can modify the computation of `p_res` from above in a more efficient way
  158. # p_res = U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
  159. # + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
  160. # + ...
  161. # + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
  162. # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
  163. U_grad_projected = U_grad
  164. series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
  165. for k in range(1, chr_poly_D.size(-1)):
  166. poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
  167. series_acc += U_grad_projected * poly_D.unsqueeze(-2)
  168. U_grad_projected = A.matmul(U_grad_projected)
  169. # compute chr_poly_D(A) which essentially is:
  170. #
  171. # chr_poly_D_at_A = A.new_zeros(A.shape)
  172. # for k in range(chr_poly_D.size(-1)):
  173. # chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
  174. #
  175. # Note, however, for better performance we use the Horner's rule
  176. chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)
  177. # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
  178. chr_poly_D_at_A_to_U_ortho = torch.matmul(
  179. U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
  180. )
  181. # we need to invert 'chr_poly_D_at_A_to_U_ortho`, for that we compute its
  182. # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
  183. # Cholesky decomposition requires the input to be positive-definite.
  184. # Note that `chr_poly_D_at_A_to_U_ortho` is positive-definite if
  185. # 1. `largest` == False, or
  186. # 2. `largest` == True and `k` is even
  187. # under the assumption that `A` has distinct eigenvalues.
  188. #
  189. # check if `chr_poly_D_at_A_to_U_ortho` is positive-definite or negative-definite
  190. chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (k % 2 == 1)) else +1
  191. chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
  192. chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
  193. )
  194. # compute the gradient part in span(U)
  195. res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
  196. # incorporate the Sylvester equation solution into the full gradient
  197. # it resides in span(U_ortho)
  198. res -= U_ortho.matmul(
  199. chr_poly_D_at_A_to_U_ortho_sign
  200. * torch.cholesky_solve(
  201. U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
  202. )
  203. ).matmul(Ut)
  204. return res
  205. def _symeig_backward(D_grad, U_grad, A, D, U, largest):
  206. # if `U` is square, then the columns of `U` is a complete eigenspace
  207. if U.size(-1) == U.size(-2):
  208. return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
  209. else:
  210. return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
  211. class LOBPCGAutogradFunction(torch.autograd.Function):
  212. @staticmethod
  213. def forward( # type: ignore[override]
  214. ctx,
  215. A: Tensor,
  216. k: Optional[int] = None,
  217. B: Optional[Tensor] = None,
  218. X: Optional[Tensor] = None,
  219. n: Optional[int] = None,
  220. iK: Optional[Tensor] = None,
  221. niter: Optional[int] = None,
  222. tol: Optional[float] = None,
  223. largest: Optional[bool] = None,
  224. method: Optional[str] = None,
  225. tracker: None = None,
  226. ortho_iparams: Optional[Dict[str, int]] = None,
  227. ortho_fparams: Optional[Dict[str, float]] = None,
  228. ortho_bparams: Optional[Dict[str, bool]] = None,
  229. ) -> Tuple[Tensor, Tensor]:
  230. # makes sure that input is contiguous for efficiency.
  231. # Note: autograd does not support dense gradients for sparse input yet.
  232. A = A.contiguous() if (not A.is_sparse) else A
  233. if B is not None:
  234. B = B.contiguous() if (not B.is_sparse) else B
  235. D, U = _lobpcg(
  236. A,
  237. k,
  238. B,
  239. X,
  240. n,
  241. iK,
  242. niter,
  243. tol,
  244. largest,
  245. method,
  246. tracker,
  247. ortho_iparams,
  248. ortho_fparams,
  249. ortho_bparams,
  250. )
  251. ctx.save_for_backward(A, B, D, U)
  252. ctx.largest = largest
  253. return D, U
  254. @staticmethod
  255. def backward(ctx, D_grad, U_grad):
  256. A_grad = B_grad = None
  257. grads = [None] * 14
  258. A, B, D, U = ctx.saved_tensors
  259. largest = ctx.largest
  260. # lobpcg.backward has some limitations. Checks for unsupported input
  261. if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
  262. raise ValueError(
  263. "lobpcg.backward does not support sparse input yet."
  264. "Note that lobpcg.forward does though."
  265. )
  266. if (
  267. A.dtype in (torch.complex64, torch.complex128)
  268. or B is not None
  269. and B.dtype in (torch.complex64, torch.complex128)
  270. ):
  271. raise ValueError(
  272. "lobpcg.backward does not support complex input yet."
  273. "Note that lobpcg.forward does though."
  274. )
  275. if B is not None:
  276. raise ValueError(
  277. "lobpcg.backward does not support backward with B != I yet."
  278. )
  279. if largest is None:
  280. largest = True
  281. # symeig backward
  282. if B is None:
  283. A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)
  284. # A has index 0
  285. grads[0] = A_grad
  286. # B has index 2
  287. grads[2] = B_grad
  288. return tuple(grads)
  289. def lobpcg(
  290. A: Tensor,
  291. k: Optional[int] = None,
  292. B: Optional[Tensor] = None,
  293. X: Optional[Tensor] = None,
  294. n: Optional[int] = None,
  295. iK: Optional[Tensor] = None,
  296. niter: Optional[int] = None,
  297. tol: Optional[float] = None,
  298. largest: Optional[bool] = None,
  299. method: Optional[str] = None,
  300. tracker: None = None,
  301. ortho_iparams: Optional[Dict[str, int]] = None,
  302. ortho_fparams: Optional[Dict[str, float]] = None,
  303. ortho_bparams: Optional[Dict[str, bool]] = None,
  304. ) -> Tuple[Tensor, Tensor]:
  305. """Find the k largest (or smallest) eigenvalues and the corresponding
  306. eigenvectors of a symmetric positive definite generalized
  307. eigenvalue problem using matrix-free LOBPCG methods.
  308. This function is a front-end to the following LOBPCG algorithms
  309. selectable via `method` argument:
  310. `method="basic"` - the LOBPCG method introduced by Andrew
  311. Knyazev, see [Knyazev2001]. A less robust method, may fail when
  312. Cholesky is applied to singular input.
  313. `method="ortho"` - the LOBPCG method with orthogonal basis
  314. selection [StathopoulosEtal2002]. A robust method.
  315. Supported inputs are dense, sparse, and batches of dense matrices.
  316. .. note:: In general, the basic method spends least time per
  317. iteration. However, the robust methods converge much faster and
  318. are more stable. So, the usage of the basic method is generally
  319. not recommended but there exist cases where the usage of the
  320. basic method may be preferred.
  321. .. warning:: The backward method does not support sparse and complex inputs.
  322. It works only when `B` is not provided (i.e. `B == None`).
  323. We are actively working on extensions, and the details of
  324. the algorithms are going to be published promptly.
  325. .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
  326. To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
  327. in first-order optimization routines, prior to running `lobpcg`
  328. we do the following symmetrization map: `A -> (A + A.t()) / 2`.
  329. The map is performed only when the `A` requires gradients.
  330. Args:
  331. A (Tensor): the input tensor of size :math:`(*, m, m)`
  332. B (Tensor, optional): the input tensor of size :math:`(*, m,
  333. m)`. When not specified, `B` is interpreted as
  334. identity matrix.
  335. X (tensor, optional): the input tensor of size :math:`(*, m, n)`
  336. where `k <= n <= m`. When specified, it is used as
  337. initial approximation of eigenvectors. X must be a
  338. dense tensor.
  339. iK (tensor, optional): the input tensor of size :math:`(*, m,
  340. m)`. When specified, it will be used as preconditioner.
  341. k (integer, optional): the number of requested
  342. eigenpairs. Default is the number of :math:`X`
  343. columns (when specified) or `1`.
  344. n (integer, optional): if :math:`X` is not specified then `n`
  345. specifies the size of the generated random
  346. approximation of eigenvectors. Default value for `n`
  347. is `k`. If :math:`X` is specified, the value of `n`
  348. (when specified) must be the number of :math:`X`
  349. columns.
  350. tol (float, optional): residual tolerance for stopping
  351. criterion. Default is `feps ** 0.5` where `feps` is
  352. smallest non-zero floating-point number of the given
  353. input tensor `A` data type.
  354. largest (bool, optional): when True, solve the eigenproblem for
  355. the largest eigenvalues. Otherwise, solve the
  356. eigenproblem for smallest eigenvalues. Default is
  357. `True`.
  358. method (str, optional): select LOBPCG method. See the
  359. description of the function above. Default is
  360. "ortho".
  361. niter (int, optional): maximum number of iterations. When
  362. reached, the iteration process is hard-stopped and
  363. the current approximation of eigenpairs is returned.
  364. For infinite iteration but until convergence criteria
  365. is met, use `-1`.
  366. tracker (callable, optional) : a function for tracing the
  367. iteration process. When specified, it is called at
  368. each iteration step with LOBPCG instance as an
  369. argument. The LOBPCG instance holds the full state of
  370. the iteration process in the following attributes:
  371. `iparams`, `fparams`, `bparams` - dictionaries of
  372. integer, float, and boolean valued input
  373. parameters, respectively
  374. `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
  375. of integer, float, boolean, and Tensor valued
  376. iteration variables, respectively.
  377. `A`, `B`, `iK` - input Tensor arguments.
  378. `E`, `X`, `S`, `R` - iteration Tensor variables.
  379. For instance:
  380. `ivars["istep"]` - the current iteration step
  381. `X` - the current approximation of eigenvectors
  382. `E` - the current approximation of eigenvalues
  383. `R` - the current residual
  384. `ivars["converged_count"]` - the current number of converged eigenpairs
  385. `tvars["rerr"]` - the current state of convergence criteria
  386. Note that when `tracker` stores Tensor objects from
  387. the LOBPCG instance, it must make copies of these.
  388. If `tracker` sets `bvars["force_stop"] = True`, the
  389. iteration process will be hard-stopped.
  390. ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
  391. various parameters to LOBPCG algorithm when using
  392. `method="ortho"`.
  393. Returns:
  394. E (Tensor): tensor of eigenvalues of size :math:`(*, k)`
  395. X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`
  396. References:
  397. [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
  398. Preconditioned Eigensolver: Locally Optimal Block Preconditioned
  399. Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
  400. 517-541. (25 pages)
  401. https://epubs.siam.org/doi/abs/10.1137/S1064827500366124
  402. [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
  403. Wu. (2002) A Block Orthogonalization Procedure with Constant
  404. Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
  405. 2165-2182. (18 pages)
  406. https://epubs.siam.org/doi/10.1137/S1064827500370883
  407. [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
  408. Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
  409. SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
  410. https://epubs.siam.org/doi/abs/10.1137/17M1129830
  411. """
  412. if not torch.jit.is_scripting():
  413. tensor_ops = (A, B, X, iK)
  414. if not set(map(type, tensor_ops)).issubset(
  415. (torch.Tensor, type(None))
  416. ) and has_torch_function(tensor_ops):
  417. return handle_torch_function(
  418. lobpcg,
  419. tensor_ops,
  420. A,
  421. k=k,
  422. B=B,
  423. X=X,
  424. n=n,
  425. iK=iK,
  426. niter=niter,
  427. tol=tol,
  428. largest=largest,
  429. method=method,
  430. tracker=tracker,
  431. ortho_iparams=ortho_iparams,
  432. ortho_fparams=ortho_fparams,
  433. ortho_bparams=ortho_bparams,
  434. )
  435. if not torch._jit_internal.is_scripting():
  436. if A.requires_grad or (B is not None and B.requires_grad):
  437. # While it is expected that `A` is symmetric,
  438. # the `A_grad` might be not. Therefore we perform the trick below,
  439. # so that `A_grad` becomes symmetric.
  440. # The symmetrization is important for first-order optimization methods,
  441. # so that (A - alpha * A_grad) is still a symmetric matrix.
  442. # Same holds for `B`.
  443. A_sym = (A + A.mT) / 2
  444. B_sym = (B + B.mT) / 2 if (B is not None) else None
  445. return LOBPCGAutogradFunction.apply(
  446. A_sym,
  447. k,
  448. B_sym,
  449. X,
  450. n,
  451. iK,
  452. niter,
  453. tol,
  454. largest,
  455. method,
  456. tracker,
  457. ortho_iparams,
  458. ortho_fparams,
  459. ortho_bparams,
  460. )
  461. else:
  462. if A.requires_grad or (B is not None and B.requires_grad):
  463. raise RuntimeError(
  464. "Script and require grads is not supported atm."
  465. "If you just want to do the forward, use .detach()"
  466. "on A and B before calling into lobpcg"
  467. )
  468. return _lobpcg(
  469. A,
  470. k,
  471. B,
  472. X,
  473. n,
  474. iK,
  475. niter,
  476. tol,
  477. largest,
  478. method,
  479. tracker,
  480. ortho_iparams,
  481. ortho_fparams,
  482. ortho_bparams,
  483. )
  484. def _lobpcg(
  485. A: Tensor,
  486. k: Optional[int] = None,
  487. B: Optional[Tensor] = None,
  488. X: Optional[Tensor] = None,
  489. n: Optional[int] = None,
  490. iK: Optional[Tensor] = None,
  491. niter: Optional[int] = None,
  492. tol: Optional[float] = None,
  493. largest: Optional[bool] = None,
  494. method: Optional[str] = None,
  495. tracker: None = None,
  496. ortho_iparams: Optional[Dict[str, int]] = None,
  497. ortho_fparams: Optional[Dict[str, float]] = None,
  498. ortho_bparams: Optional[Dict[str, bool]] = None,
  499. ) -> Tuple[Tensor, Tensor]:
  500. # A must be square:
  501. assert A.shape[-2] == A.shape[-1], A.shape
  502. if B is not None:
  503. # A and B must have the same shapes:
  504. assert A.shape == B.shape, (A.shape, B.shape)
  505. dtype = _utils.get_floating_dtype(A)
  506. device = A.device
  507. if tol is None:
  508. feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
  509. tol = feps**0.5
  510. m = A.shape[-1]
  511. k = (1 if X is None else X.shape[-1]) if k is None else k
  512. n = (k if n is None else n) if X is None else X.shape[-1]
  513. if m < 3 * n:
  514. raise ValueError(
  515. "LPBPCG algorithm is not applicable when the number of A rows (={})"
  516. " is smaller than 3 x the number of requested eigenpairs (={})".format(m, n)
  517. )
  518. method = "ortho" if method is None else method
  519. iparams = {
  520. "m": m,
  521. "n": n,
  522. "k": k,
  523. "niter": 1000 if niter is None else niter,
  524. }
  525. fparams = {
  526. "tol": tol,
  527. }
  528. bparams = {"largest": True if largest is None else largest}
  529. if method == "ortho":
  530. if ortho_iparams is not None:
  531. iparams.update(ortho_iparams)
  532. if ortho_fparams is not None:
  533. fparams.update(ortho_fparams)
  534. if ortho_bparams is not None:
  535. bparams.update(ortho_bparams)
  536. iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
  537. iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
  538. fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
  539. fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
  540. fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
  541. bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)
  542. if not torch.jit.is_scripting():
  543. LOBPCG.call_tracker = LOBPCG_call_tracker # type: ignore[assignment]
  544. if len(A.shape) > 2:
  545. N = int(torch.prod(torch.tensor(A.shape[:-2])))
  546. bA = A.reshape((N,) + A.shape[-2:])
  547. bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
  548. bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
  549. bE = torch.empty((N, k), dtype=dtype, device=device)
  550. bXret = torch.empty((N, m, k), dtype=dtype, device=device)
  551. for i in range(N):
  552. A_ = bA[i]
  553. B_ = bB[i] if bB is not None else None
  554. X_ = (
  555. torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
  556. )
  557. assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
  558. iparams["batch_index"] = i
  559. worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
  560. worker.run()
  561. bE[i] = worker.E[:k]
  562. bXret[i] = worker.X[:, :k]
  563. if not torch.jit.is_scripting():
  564. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[assignment]
  565. return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))
  566. X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
  567. assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))
  568. worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)
  569. worker.run()
  570. if not torch.jit.is_scripting():
  571. LOBPCG.call_tracker = LOBPCG_call_tracker_orig # type: ignore[assignment]
  572. return worker.E[:k], worker.X[:, :k]
  573. class LOBPCG:
  574. """Worker class of LOBPCG methods."""
  575. def __init__(
  576. self,
  577. A: Optional[Tensor],
  578. B: Optional[Tensor],
  579. X: Tensor,
  580. iK: Optional[Tensor],
  581. iparams: Dict[str, int],
  582. fparams: Dict[str, float],
  583. bparams: Dict[str, bool],
  584. method: str,
  585. tracker: None,
  586. ) -> None:
  587. # constant parameters
  588. self.A = A
  589. self.B = B
  590. self.iK = iK
  591. self.iparams = iparams
  592. self.fparams = fparams
  593. self.bparams = bparams
  594. self.method = method
  595. self.tracker = tracker
  596. m = iparams["m"]
  597. n = iparams["n"]
  598. # variable parameters
  599. self.X = X
  600. self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
  601. self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
  602. self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
  603. self.tvars: Dict[str, Tensor] = {}
  604. self.ivars: Dict[str, int] = {"istep": 0}
  605. self.fvars: Dict[str, float] = {"_": 0.0}
  606. self.bvars: Dict[str, bool] = {"_": False}
  607. def __str__(self):
  608. lines = ["LOPBCG:"]
  609. lines += [" iparams={}".format(self.iparams)]
  610. lines += [" fparams={}".format(self.fparams)]
  611. lines += [" bparams={}".format(self.bparams)]
  612. lines += [" ivars={}".format(self.ivars)]
  613. lines += [" fvars={}".format(self.fvars)]
  614. lines += [" bvars={}".format(self.bvars)]
  615. lines += [" tvars={}".format(self.tvars)]
  616. lines += [" A={}".format(self.A)]
  617. lines += [" B={}".format(self.B)]
  618. lines += [" iK={}".format(self.iK)]
  619. lines += [" X={}".format(self.X)]
  620. lines += [" E={}".format(self.E)]
  621. r = ""
  622. for line in lines:
  623. r += line + "\n"
  624. return r
  625. def update(self):
  626. """Set and update iteration variables."""
  627. if self.ivars["istep"] == 0:
  628. X_norm = float(torch.norm(self.X))
  629. iX_norm = X_norm**-1
  630. A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
  631. B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
  632. self.fvars["X_norm"] = X_norm
  633. self.fvars["A_norm"] = A_norm
  634. self.fvars["B_norm"] = B_norm
  635. self.ivars["iterations_left"] = self.iparams["niter"]
  636. self.ivars["converged_count"] = 0
  637. self.ivars["converged_end"] = 0
  638. if self.method == "ortho":
  639. self._update_ortho()
  640. else:
  641. self._update_basic()
  642. self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
  643. self.ivars["istep"] = self.ivars["istep"] + 1
  644. def update_residual(self):
  645. """Update residual R from A, B, X, E."""
  646. mm = _utils.matmul
  647. self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
  def update_converged_count(self):
      """Count converged eigenpairs with a backward-stable criterion.

      The relative residual of pair i is
      ``||R_i|| / (||X_i|| * (A_norm + E_i * B_norm))`` (column norms),
      see the discussion in Sec 4.3 of [DuerschEtal2018]. Only the
      leading run of converged pairs is counted, so converged eigenpairs
      stay strictly ordered. Users may redefine this method for custom
      convergence criteria.

      Side effects: stores the count in ``self.ivars["converged_count"]``
      and the relative residuals in ``self.tvars["rerr"]``.

      Returns the new count (an int); asserts that it did not decrease.
      """
      # (...) -> int
      previous = self.ivars["converged_count"]
      tolerance = self.fparams["tol"]
      X = self.X
      n_cols = X.shape[-1]
      residual_norms = torch.norm(self.R, 2, (0,))
      scale = torch.norm(X, 2, (0,)) * (
          self.fvars["A_norm"] + self.E[:n_cols] * self.fvars["B_norm"]
      )
      rerr = residual_norms * scale**-1
      converged = rerr < tolerance

      # Stop at the first unconverged pair: convergence of later pairs
      # is ignored to keep the eigenpair ordering strict.
      count = 0
      while count < converged.shape[0] and bool(converged[count]):
          count += 1

      assert count >= previous, (
          "the number of converged eigenpairs "
          "(was {}, got {}) cannot decrease".format(previous, count)
      )
      self.ivars["converged_count"] = count
      self.tvars["rerr"] = rerr
      return count
  def stop_iteration(self):
      """Return True when the LOBPCG iterations should stop.

      Iterations stop when a tracker forced it (a tracker can set
      ``worker.bvars['force_stop'] = True``), when the iteration budget
      ``ivars["iterations_left"]`` reaches zero, or when at least
      ``iparams["k"]`` eigenpairs have converged.
      """
      forced = self.bvars.get("force_stop", False)
      if forced:
          return forced
      if self.ivars["iterations_left"] == 0:
          return True
      return self.ivars["converged_count"] >= self.iparams["k"]
  def run(self):
      """Run LOBPCG iterations until ``stop_iteration()`` returns True.

      Every step calls ``update()`` and then, in Python mode only and
      when a tracker is set, ``call_tracker()``; the stop condition is
      checked after each step. Use this method as a template for
      implementing an LOBPCG iteration scheme with a custom tracker that
      is compatible with TorchScript.
      """
      while True:
          self.update()
          if not torch.jit.is_scripting() and self.tracker is not None:
              self.call_tracker()
          if self.stop_iteration():
              break
  @torch.jit.unused
  def call_tracker(self):
      """Hook for tracking the iteration process in Python mode.

      This default implementation does nothing. The method is marked
      ``torch.jit.unused``, so tracking is unavailable in TorchScript
      mode; specify ``tracker=None`` when JIT compiling functions that
      use lobpcg.
      """
      return None
  # Internal methods

  def _update_basic(self):
      """
      Update or initialize iteration variables when `method == "basic"`.

      Performs one Rayleigh-Ritz step and refreshes X, E, R, the converged
      count and the search basis ``self.S``. Column layout of ``self.S``
      as written below: ``[:, :n]`` holds X, ``[:, n:n+np]`` holds P
      (empty on the first step) and ``[:, n+np:ns]`` holds
      W = _utils.matmul(self.iK, R); ``ns`` is stored in
      ``ivars["converged_end"]``.
      """
      mm = torch.matmul
      ns = self.ivars["converged_end"]  # end of the filled columns of S
      nc = self.ivars["converged_count"]  # number of leading converged pairs
      n = self.iparams["n"]
      largest = self.bparams["largest"]

      if self.ivars["istep"] == 0:
          # First step: Rayleigh-Ritz on the initial basis X only.
          Ri = self._get_rayleigh_ritz_transform(self.X)
          M = _utils.qform(_utils.qform(self.A, self.X), Ri)
          E, Z = _utils.symeig(M, largest)
          # NOTE(review): in-place write into the tensor bound to self.X;
          # if that tensor is caller-owned it gets modified -- confirm in __init__.
          self.X[:] = mm(self.X, mm(Ri, Z))
          self.E[:] = E
          np = 0  # no P block on the first step
          self.update_residual()
          nc = self.update_converged_count()
          self.S[..., :n] = self.X
          # iK is presumably the (inverse) preconditioner -- confirm in caller.
          W = _utils.matmul(self.iK, self.R)
          self.ivars["converged_end"] = ns = n + np + W.shape[-1]
          self.S[:, n + np : ns] = W
      else:
          # Rayleigh-Ritz restricted to columns nc:ns of S, i.e. the
          # already-converged leading pairs are excluded.
          S_ = self.S[:, nc:ns]
          Ri = self._get_rayleigh_ritz_transform(S_)
          M = _utils.qform(_utils.qform(self.A, S_), Ri)
          E_, Z = _utils.symeig(M, largest)
          # X[:, :nc] and E[:nc] (converged pairs) are left untouched.
          self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
          self.E[nc:] = E_[: n - nc]
          # P is built from columns n:2n-nc of the projected eigenvectors Z.
          P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
          np = P.shape[-1]
          self.update_residual()
          nc = self.update_converged_count()
          self.S[..., :n] = self.X
          self.S[:, n : n + np] = P
          # Only residuals of the not-yet-converged pairs enter W.
          W = _utils.matmul(self.iK, self.R[:, nc:])
          self.ivars["converged_end"] = ns = n + np + W.shape[-1]
          self.S[:, n + np : ns] = W
  def _update_ortho(self):
      """
      Update or initialize iteration variables when `method == "ortho"`.

      Performs one Rayleigh-Ritz step and refreshes X, E, R, the converged
      count and the search basis ``self.S``. Column layout of ``self.S``
      as written below: ``[:, :n]`` holds X, ``[:, n:n+np]`` holds P
      (empty on the first step) and ``[:, n+np:ns]`` holds W, the
      residual block B-orthonormalized against the preceding columns by
      ``_get_ortho``; ``ns`` is stored in ``ivars["converged_end"]``.
      """
      mm = torch.matmul
      ns = self.ivars["converged_end"]  # end of the filled columns of S
      nc = self.ivars["converged_count"]  # number of leading converged pairs
      n = self.iparams["n"]
      largest = self.bparams["largest"]

      if self.ivars["istep"] == 0:
          # First step: Rayleigh-Ritz on the initial basis X only.
          Ri = self._get_rayleigh_ritz_transform(self.X)
          M = _utils.qform(_utils.qform(self.A, self.X), Ri)
          E, Z = _utils.symeig(M, largest)
          # Rebinds self.X (no in-place write), so the tensor previously
          # bound to self.X is left unmodified.
          self.X = mm(self.X, mm(Ri, Z))
          # Store the Ritz values before forming the residual: both
          # update_residual (R = A X - B X * E) and the convergence test
          # read self.E. Without this, step 0 used whatever self.E held
          # before, which gave a wrong residual, a wrong convergence count
          # and a wrong initial W block.
          self.E[:] = E
          self.update_residual()
          np = 0  # no P block on the first step
          nc = self.update_converged_count()
          self.S[:, :n] = self.X
          W = self._get_ortho(self.R, self.X)
          ns = self.ivars["converged_end"] = n + np + W.shape[-1]
          self.S[:, n + np : ns] = W
      else:
          # Rayleigh-Ritz restricted to columns nc:ns of S. Since S is kept
          # B-orthonormal by _get_ortho, no Ri transform is applied here.
          S_ = self.S[:, nc:ns]
          E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

          # Update E, X, P; X[:, :nc] and E[:nc] (converged pairs) are kept.
          self.X[:, nc:] = mm(S_, Z[:, : n - nc])
          self.E[nc:] = E_[: n - nc]
          P = mm(
              S_,
              mm(
                  Z[:, n - nc :],
                  _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])),
              ),
          )
          np = P.shape[-1]

          # check convergence
          self.update_residual()
          nc = self.update_converged_count()

          # update S; only residuals of unconverged pairs enter W
          self.S[:, :n] = self.X
          self.S[:, n : n + np] = P
          W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
          ns = self.ivars["converged_end"] = n + np + W.shape[-1]
          self.S[:, n + np : ns] = W
  794. def _get_rayleigh_ritz_transform(self, S):
  795. """Return a transformation matrix that is used in Rayleigh-Ritz
  796. procedure for reducing a general eigenvalue problem :math:`(S^TAS)
  797. C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
  798. S^TAS Ri) Z = Z E` where `C = Ri Z`.
  799. .. note:: In the original Rayleight-Ritz procedure in
  800. [DuerschEtal2018], the problem is formulated as follows::
  801. SAS = S^T A S
  802. SBS = S^T B S
  803. D = (<diagonal matrix of SBS>) ** -1/2
  804. R^T R = Cholesky(D SBS D)
  805. Ri = D R^-1
  806. solve symeig problem Ri^T SAS Ri Z = Theta Z
  807. C = Ri Z
  808. To reduce the number of matrix products (denoted by empty
  809. space between matrices), here we introduce element-wise
  810. products (denoted by symbol `*`) so that the Rayleight-Ritz
  811. procedure becomes::
  812. SAS = S^T A S
  813. SBS = S^T B S
  814. d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector
  815. dd = d d^T # this is 2-d matrix
  816. R^T R = Cholesky(dd * SBS)
  817. Ri = R^-1 * d # broadcasting
  818. solve symeig problem Ri^T SAS Ri Z = Theta Z
  819. C = Ri Z
  820. where `dd` is 2-d matrix that replaces matrix products `D M
  821. D` with one element-wise product `M * dd`; and `d` replaces
  822. matrix product `D M` with element-wise product `M *
  823. d`. Also, creating the diagonal matrix `D` is avoided.
  824. Args:
  825. S (Tensor): the matrix basis for the search subspace, size is
  826. :math:`(m, n)`.
  827. Returns:
  828. Ri (tensor): upper-triangular transformation matrix of size
  829. :math:`(n, n)`.
  830. """
  831. B = self.B
  832. mm = torch.matmul
  833. SBS = _utils.qform(B, S)
  834. d_row = SBS.diagonal(0, -2, -1) ** -0.5
  835. d_col = d_row.reshape(d_row.shape[0], 1)
  836. # TODO use torch.linalg.cholesky_solve once it is implemented
  837. R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
  838. return torch.linalg.solve_triangular(
  839. R, d_row.diag_embed(), upper=True, left=False
  840. )
  841. def _get_svqb(
  842. self, U: Tensor, drop: bool, tau: float # Tensor # bool # float
  843. ) -> Tensor:
  844. """Return B-orthonormal U.
  845. .. note:: When `drop` is `False` then `svqb` is based on the
  846. Algorithm 4 from [DuerschPhD2015] that is a slight
  847. modification of the corresponding algorithm
  848. introduced in [StathopolousWu2002].
  849. Args:
  850. U (Tensor) : initial approximation, size is (m, n)
  851. drop (bool) : when True, drop columns that
  852. contribution to the `span([U])` is small.
  853. tau (float) : positive tolerance
  854. Returns:
  855. U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
  856. is (m, n1), where `n1 = n` if `drop` is `False,
  857. otherwise `n1 <= n`.
  858. """
  859. if torch.numel(U) == 0:
  860. return U
  861. UBU = _utils.qform(self.B, U)
  862. d = UBU.diagonal(0, -2, -1)
  863. # Detect and drop exact zero columns from U. While the test
  864. # `abs(d) == 0` is unlikely to be True for random data, it is
  865. # possible to construct input data to lobpcg where it will be
  866. # True leading to a failure (notice the `d ** -0.5` operation
  867. # in the original algorithm). To prevent the failure, we drop
  868. # the exact zero columns here and then continue with the
  869. # original algorithm below.
  870. nz = torch.where(abs(d) != 0.0)
  871. assert len(nz) == 1, nz
  872. if len(nz[0]) < len(d):
  873. U = U[:, nz[0]]
  874. if torch.numel(U) == 0:
  875. return U
  876. UBU = _utils.qform(self.B, U)
  877. d = UBU.diagonal(0, -2, -1)
  878. nz = torch.where(abs(d) != 0.0)
  879. assert len(nz[0]) == len(d)
  880. # The original algorithm 4 from [DuerschPhD2015].
  881. d_col = (d**-0.5).reshape(d.shape[0], 1)
  882. DUBUD = (UBU * d_col) * _utils.transpose(d_col)
  883. E, Z = _utils.symeig(DUBUD)
  884. t = tau * abs(E).max()
  885. if drop:
  886. keep = torch.where(E > t)
  887. assert len(keep) == 1, keep
  888. E = E[keep[0]]
  889. Z = Z[:, keep[0]]
  890. d_col = d_col[keep[0]]
  891. else:
  892. E[(torch.where(E < t))[0]] = t
  893. return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5)
  894. def _get_ortho(self, U, V):
  895. """Return B-orthonormal U with columns are B-orthogonal to V.
  896. .. note:: When `bparams["ortho_use_drop"] == False` then
  897. `_get_ortho` is based on the Algorithm 3 from
  898. [DuerschPhD2015] that is a slight modification of
  899. the corresponding algorithm introduced in
  900. [StathopolousWu2002]. Otherwise, the method
  901. implements Algorithm 6 from [DuerschPhD2015]
  902. .. note:: If all U columns are B-collinear to V then the
  903. returned tensor U will be empty.
  904. Args:
  905. U (Tensor) : initial approximation, size is (m, n)
  906. V (Tensor) : B-orthogonal external basis, size is (m, k)
  907. Returns:
  908. U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
  909. such that :math:`V^T B U=0`, size is (m, n1),
  910. where `n1 = n` if `drop` is `False, otherwise
  911. `n1 <= n`.
  912. """
  913. mm = torch.matmul
  914. mm_B = _utils.matmul
  915. m = self.iparams["m"]
  916. tau_ortho = self.fparams["ortho_tol"]
  917. tau_drop = self.fparams["ortho_tol_drop"]
  918. tau_replace = self.fparams["ortho_tol_replace"]
  919. i_max = self.iparams["ortho_i_max"]
  920. j_max = self.iparams["ortho_j_max"]
  921. # when use_drop==True, enable dropping U columns that have
  922. # small contribution to the `span([U, V])`.
  923. use_drop = self.bparams["ortho_use_drop"]
  924. # clean up variables from the previous call
  925. for vkey in list(self.fvars.keys()):
  926. if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
  927. self.fvars.pop(vkey)
  928. self.ivars.pop("ortho_i", 0)
  929. self.ivars.pop("ortho_j", 0)
  930. BV_norm = torch.norm(mm_B(self.B, V))
  931. BU = mm_B(self.B, U)
  932. VBU = mm(_utils.transpose(V), BU)
  933. i = j = 0
  934. stats = ""
  935. for i in range(i_max):
  936. U = U - mm(V, VBU)
  937. drop = False
  938. tau_svqb = tau_drop
  939. for j in range(j_max):
  940. if use_drop:
  941. U = self._get_svqb(U, drop, tau_svqb)
  942. drop = True
  943. tau_svqb = tau_replace
  944. else:
  945. U = self._get_svqb(U, False, tau_replace)
  946. if torch.numel(U) == 0:
  947. # all initial U columns are B-collinear to V
  948. self.ivars["ortho_i"] = i
  949. self.ivars["ortho_j"] = j
  950. return U
  951. BU = mm_B(self.B, U)
  952. UBU = mm(_utils.transpose(U), BU)
  953. U_norm = torch.norm(U)
  954. BU_norm = torch.norm(BU)
  955. R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
  956. R_norm = torch.norm(R)
  957. # https://github.com/pytorch/pytorch/issues/33810 workaround:
  958. rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
  959. vkey = "ortho_UBUmI_rerr[{}, {}]".format(i, j)
  960. self.fvars[vkey] = rerr
  961. if rerr < tau_ortho:
  962. break
  963. VBU = mm(_utils.transpose(V), BU)
  964. VBU_norm = torch.norm(VBU)
  965. U_norm = torch.norm(U)
  966. rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
  967. vkey = "ortho_VBU_rerr[{}]".format(i)
  968. self.fvars[vkey] = rerr
  969. if rerr < tau_ortho:
  970. break
  971. if m < U.shape[-1] + V.shape[-1]:
  972. # TorchScript needs the class var to be assigned to a local to
  973. # do optional type refinement
  974. B = self.B
  975. assert B is not None
  976. raise ValueError(
  977. "Overdetermined shape of U:"
  978. " #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold".format(
  979. B.shape[-1], U.shape[-1], V.shape[-1]
  980. )
  981. )
  982. self.ivars["ortho_i"] = i
  983. self.ivars["ortho_j"] = j
  984. return U
  # Calling tracker is separated from LOBPCG definitions because
  # TorchScript does not support user-defined callback arguments:
  # keep a handle on the class's default (no-op) `call_tracker` hook.
  # NOTE(review): presumably swapped with LOBPCG_call_tracker below while a
  # Python-mode tracker is active and restored afterwards; the swap itself
  # is not visible in this chunk -- confirm at the call site.
  LOBPCG_call_tracker_orig = LOBPCG.call_tracker


  def LOBPCG_call_tracker(self):
      """Python-mode tracker hook: invoke ``self.tracker`` with the worker itself."""
      self.tracker(self)