12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175 |
- """Locally Optimal Block Preconditioned Conjugate Gradient methods.
- """
- # Author: Pearu Peterson
- # Created: February 2020
- from typing import Dict, Optional, Tuple
- import torch
- from torch import Tensor
- from . import _linalg_utils as _utils
- from .overrides import handle_torch_function, has_torch_function
# Public API of this module: only the `lobpcg` front-end is listed for
# `from ... import *`.
__all__ = ["lobpcg"]
- def _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U):
- # compute F, such that F_ij = (d_j - d_i)^{-1} for i != j, F_ii = 0
- F = D.unsqueeze(-2) - D.unsqueeze(-1)
- F.diagonal(dim1=-2, dim2=-1).fill_(float("inf"))
- F.pow_(-1)
- # A.grad = U (D.grad + (U^T U.grad * F)) U^T
- Ut = U.mT.contiguous()
- res = torch.matmul(
- U, torch.matmul(torch.diag_embed(D_grad) + torch.matmul(Ut, U_grad) * F, Ut)
- )
- return res
- def _polynomial_coefficients_given_roots(roots):
- """
- Given the `roots` of a polynomial, find the polynomial's coefficients.
- If roots = (r_1, ..., r_n), then the method returns
- coefficients (a_0, a_1, ..., a_n (== 1)) so that
- p(x) = (x - r_1) * ... * (x - r_n)
- = x^n + a_{n-1} * x^{n-1} + ... a_1 * x_1 + a_0
- Note: for better performance requires writing a low-level kernel
- """
- poly_order = roots.shape[-1]
- poly_coeffs_shape = list(roots.shape)
- # we assume p(x) = x^n + a_{n-1} * x^{n-1} + ... + a_1 * x + a_0,
- # so poly_coeffs = {a_0, ..., a_n, a_{n+1}(== 1)},
- # but we insert one extra coefficient to enable better vectorization below
- poly_coeffs_shape[-1] += 2
- poly_coeffs = roots.new_zeros(poly_coeffs_shape)
- poly_coeffs[..., 0] = 1
- poly_coeffs[..., -1] = 1
- # perform the Horner's rule
- for i in range(1, poly_order + 1):
- # note that it is computationally hard to compute backward for this method,
- # because then given the coefficients it would require finding the roots and/or
- # calculating the sensitivity based on the Vieta's theorem.
- # So the code below tries to circumvent the explicit root finding by series
- # of operations on memory copies imitating the Horner's method.
- # The memory copies are required to construct nodes in the computational graph
- # by exploting the explicit (not in-place, separate node for each step)
- # recursion of the Horner's method.
- # Needs more memory, O(... * k^2), but with only O(... * k^2) complexity.
- poly_coeffs_new = poly_coeffs.clone() if roots.requires_grad else poly_coeffs
- out = poly_coeffs_new.narrow(-1, poly_order - i, i + 1)
- out -= roots.narrow(-1, i - 1, 1) * poly_coeffs.narrow(
- -1, poly_order - i + 1, i + 1
- )
- poly_coeffs = poly_coeffs_new
- return poly_coeffs.narrow(-1, 1, poly_order + 1)
- def _polynomial_value(poly, x, zero_power, transition):
- """
- A generic method for computing poly(x) using the Horner's rule.
- Args:
- poly (Tensor): the (possibly batched) 1D Tensor representing
- polynomial coefficients such that
- poly[..., i] = (a_{i_0}, ..., a{i_n} (==1)), and
- poly(x) = poly[..., 0] * zero_power + ... + poly[..., n] * x^n
- x (Tensor): the value (possible batched) to evalate the polynomial `poly` at.
- zero_power (Tensor): the representation of `x^0`. It is application-specific.
- transition (Callable): the function that accepts some intermediate result `int_val`,
- the `x` and a specific polynomial coefficient
- `poly[..., k]` for some iteration `k`.
- It basically performs one iteration of the Horner's rule
- defined as `x * int_val + poly[..., k] * zero_power`.
- Note that `zero_power` is not a parameter,
- because the step `+ poly[..., k] * zero_power` depends on `x`,
- whether it is a vector, a matrix, or something else, so this
- functionality is delegated to the user.
- """
- res = zero_power.clone()
- for k in range(poly.size(-1) - 2, -1, -1):
- res = transition(res, x, poly[..., k])
- return res
def _matrix_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate the monic polynomial `poly` at the (batched) square matrix `x`.

    `zero_power` defaults to the identity matrix, viewed with leading
    singleton dimensions so that it broadcasts against the batch dimensions
    of `x`. See `_polynomial_value` for the meaning of the arguments.
    """

    def transition(acc, x, coeff):
        # One matrix Horner step: x @ acc + coeff * I.
        step = x.matmul(acc)
        step.diagonal(dim1=-2, dim2=-1).add_(coeff.unsqueeze(-1))
        return step

    if zero_power is None:
        size = x.size(-1)
        batch_ones = [1] * len(x.shape[:-2])
        zero_power = torch.eye(size, size, dtype=x.dtype, device=x.device).view(
            *batch_ones, size, size
        )
    return _polynomial_value(poly, x, zero_power, transition)
def _vector_polynomial_value(poly, x, zero_power=None):
    """
    Evaluate the monic polynomial `poly` elementwise at the (batched) vector `x`.

    `zero_power` defaults to a tensor of ones expanded to the shape of `x`.
    See `_polynomial_value` for the meaning of the arguments.
    """

    def transition(acc, x, coeff):
        # One elementwise Horner step: coeff + x * acc.
        return torch.addcmul(coeff.unsqueeze(-1), x, acc)

    if zero_power is None:
        zero_power = x.new_ones(1).expand(x.shape)
    return _polynomial_value(poly, x, zero_power, transition)
def _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest):
    """Backward of the symmetric eigendecomposition for a partial eigenbasis.

    Used when ``U`` holds only ``k < m`` eigenvectors of the ``m x m``
    symmetric matrix ``A``; ``D`` holds the matching eigenvalues. The result
    is the sum of two parts:

    * the span(U) part, computed by `_symeig_backward_complete_eigenspace`;
    * a part in the orthogonal complement of span(U), obtained from the
      explicit polynomial solution of a Sylvester equation (see below).

    Args:
        D_grad, U_grad: incoming gradients w.r.t. ``D`` and ``U``.
        A: the symmetric input matrix (possibly batched).
        D, U: the ``k`` computed eigenvalues and eigenvectors of ``A``.
        largest (bool): whether ``D`` are the largest (True) or smallest
            (False) eigenvalues. This fixes the sign of the matrix that is
            Cholesky-factorized below.

    Returns:
        The gradient w.r.t. ``A``.
    """
    # compute the projection operator (I - U U^T) onto the orthogonal
    # complement of span(U)
    Ut = U.mT.contiguous()
    proj_U_ortho = -U.matmul(Ut)
    proj_U_ortho.diagonal(dim1=-2, dim2=-1).add_(1)

    # compute U_ortho, a basis for the orthogonal complement of span(U), by
    # projecting a random [..., m, m - k] matrix onto that complement.
    #
    # A freshly created generator is used so that the random matrix is
    # reproducible and does not depend on (or advance) the global RNG state.
    gen = torch.Generator(A.device)

    # orthogonal complement to the span(U)
    U_ortho = proj_U_ortho.matmul(
        torch.randn(
            (*A.shape[:-1], A.size(-1) - D.size(-1)),
            dtype=A.dtype,
            device=A.device,
            generator=gen,
        )
    )
    U_ortho_t = U_ortho.mT.contiguous()

    # compute the coefficients of the characteristic polynomial of the tensor D.
    # Note that D is diagonal, so the diagonal elements are exactly the roots
    # of the characteristic polynomial.
    chr_poly_D = _polynomial_coefficients_given_roots(D)

    # the code below finds the explicit solution to the Sylvester equation
    # U_ortho^T A U_ortho dX - dX D = -U_ortho^T A U
    # and incorporates it into the whole gradient stored in the `res` variable.
    #
    # Equivalent to the following naive implementation:
    # res = A.new_zeros(A.shape)
    # p_res = A.new_zeros(*A.shape[:-1], D.size(-1))
    # for k in range(1, chr_poly_D.size(-1)):
    #   p_res.zero_()
    #   for i in range(0, k):
    #     p_res += (A.matrix_power(k - 1 - i) @ U_grad) * D.pow(i).unsqueeze(-2)
    #   res -= chr_poly_D[k] * (U_ortho @ poly_D_at_A.inverse() @ U_ortho_t @  p_res @ U.t())
    #
    # Note that dX is a differential, so the gradient contribution comes from the backward sensitivity
    # Tr(f(U_grad, D_grad, A, U, D)^T dX) = Tr(g(U_grad, A, U, D)^T dA) for some functions f and g,
    # and we need to compute g(U_grad, A, U, D)
    #
    # The naive implementation is based on the paper
    # Hu, Qingxi, and Daizhan Cheng.
    # "The polynomial solution to the Sylvester matrix equation."
    # Applied mathematics letters 19.9 (2006): 859-864.
    #
    # We can modify the computation of `p_res` from above in a more efficient way
    # p_res =   U_grad * (chr_poly_D[1] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k)).unsqueeze(-2)
    #       + A U_grad * (chr_poly_D[2] * D.pow(0) + ... + chr_poly_D[k] * D.pow(k - 1)).unsqueeze(-2)
    #       + ...
    #       + A.matrix_power(k - 1) U_grad * chr_poly_D[k]
    # Note that this saves us from redundant matrix products with A (elimination of matrix_power)
    U_grad_projected = U_grad
    series_acc = U_grad_projected.new_zeros(U_grad_projected.shape)
    for k in range(1, chr_poly_D.size(-1)):
        poly_D = _vector_polynomial_value(chr_poly_D[..., k:], D)
        series_acc += U_grad_projected * poly_D.unsqueeze(-2)
        U_grad_projected = A.matmul(U_grad_projected)

    # compute chr_poly_D(A) which essentially is:
    #
    # chr_poly_D_at_A = A.new_zeros(A.shape)
    # for k in range(chr_poly_D.size(-1)):
    #   chr_poly_D_at_A += chr_poly_D[k] * A.matrix_power(k)
    #
    # Note, however, for better performance we use the Horner's rule
    chr_poly_D_at_A = _matrix_polynomial_value(chr_poly_D, A)

    # compute the action of `chr_poly_D_at_A` restricted to U_ortho_t
    chr_poly_D_at_A_to_U_ortho = torch.matmul(
        U_ortho_t, torch.matmul(chr_poly_D_at_A, U_ortho)
    )
    # we need to invert `chr_poly_D_at_A_to_U_ortho`, for that we compute its
    # Cholesky decomposition and then use `torch.cholesky_solve` for better stability.
    # Cholesky decomposition requires the input to be positive-definite.
    # chr_poly_D(x) = prod_i (x - d_i). Assuming `A` has distinct eigenvalues and
    # `D` holds its `num_eigs` largest (resp. smallest) ones, every eigenvalue of
    # `A` in the complement lies below (resp. above) all d_i. Hence
    # `chr_poly_D_at_A_to_U_ortho` is positive-definite if
    # 1. `largest` == False, or
    # 2. `largest` == True and `num_eigs` is even,
    # and negative-definite otherwise.
    #
    # The number of eigenpairs is taken from `D` explicitly instead of relying on
    # the loop variable `k` leaking out of the series loop above.
    num_eigs = D.size(-1)
    chr_poly_D_at_A_to_U_ortho_sign = -1 if (largest and (num_eigs % 2 == 1)) else +1
    chr_poly_D_at_A_to_U_ortho_L = torch.linalg.cholesky(
        chr_poly_D_at_A_to_U_ortho_sign * chr_poly_D_at_A_to_U_ortho
    )

    # compute the gradient part in span(U)
    res = _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)

    # incorporate the Sylvester equation solution into the full gradient
    # it resides in span(U_ortho)
    res -= U_ortho.matmul(
        chr_poly_D_at_A_to_U_ortho_sign
        * torch.cholesky_solve(
            U_ortho_t.matmul(series_acc), chr_poly_D_at_A_to_U_ortho_L
        )
    ).matmul(Ut)

    return res
def _symeig_backward(D_grad, U_grad, A, D, U, largest):
    """Dispatch to the complete- or partial-eigenspace backward.

    A square ``U`` means every eigenvector of ``A`` is available, so the
    complete-eigenspace formula applies directly. Otherwise the
    partial-eigenspace formula is used, which additionally needs ``largest``.
    """
    is_complete = U.size(-1) == U.size(-2)
    if not is_complete:
        return _symeig_backward_partial_eigenspace(D_grad, U_grad, A, D, U, largest)
    return _symeig_backward_complete_eigenspace(D_grad, U_grad, A, D, U)
class LOBPCGAutogradFunction(torch.autograd.Function):
    """Autograd wrapper around `_lobpcg`.

    The forward pass runs `_lobpcg`. The backward pass produces a gradient
    for `A` only (via `_symeig_backward`). It rejects sparse `A`, sparse `B`
    when `B` needs a gradient, complex inputs, and any non-identity `B`.
    """

    @staticmethod
    def forward(  # type: ignore[override]
        ctx,
        A: Tensor,
        k: Optional[int] = None,
        B: Optional[Tensor] = None,
        X: Optional[Tensor] = None,
        n: Optional[int] = None,
        iK: Optional[Tensor] = None,
        niter: Optional[int] = None,
        tol: Optional[float] = None,
        largest: Optional[bool] = None,
        method: Optional[str] = None,
        tracker: None = None,
        ortho_iparams: Optional[Dict[str, int]] = None,
        ortho_fparams: Optional[Dict[str, float]] = None,
        ortho_bparams: Optional[Dict[str, bool]] = None,
    ) -> Tuple[Tensor, Tensor]:
        # makes sure that input is contiguous for efficiency.
        # Note: autograd does not support dense gradients for sparse input yet.
        A = A.contiguous() if (not A.is_sparse) else A
        if B is not None:
            B = B.contiguous() if (not B.is_sparse) else B

        D, U = _lobpcg(
            A,
            k,
            B,
            X,
            n,
            iK,
            niter,
            tol,
            largest,
            method,
            tracker,
            ortho_iparams,
            ortho_fparams,
            ortho_bparams,
        )

        ctx.save_for_backward(A, B, D, U)
        ctx.largest = largest

        return D, U

    @staticmethod
    def backward(ctx, D_grad, U_grad):
        """Return one gradient slot per forward input (14 in total).

        Only slot 0 (`A`) is filled. Slot 2 (`B`) is always None because a
        non-identity `B` is rejected.

        Raises:
            ValueError: for sparse `A`, for sparse `B` that needs a gradient,
                for complex `A`/`B`, or when `B` is given at all.
        """
        A_grad = B_grad = None
        grads = [None] * 14

        A, B, D, U = ctx.saved_tensors
        largest = ctx.largest
        complex_dtypes = (torch.complex64, torch.complex128)

        # lobpcg.backward has some limitations. Checks for unsupported input
        if A.is_sparse or (B is not None and B.is_sparse and ctx.needs_input_grad[2]):
            raise ValueError(
                "lobpcg.backward does not support sparse input yet. "
                "Note that lobpcg.forward does though."
            )
        if A.dtype in complex_dtypes or (B is not None and B.dtype in complex_dtypes):
            raise ValueError(
                "lobpcg.backward does not support complex input yet. "
                "Note that lobpcg.forward does though."
            )
        if B is not None:
            raise ValueError(
                "lobpcg.backward does not support backward with B != I yet."
            )

        # same default as `_lobpcg` uses for `largest`
        if largest is None:
            largest = True

        # B is None here, so this is the standard symmetric eigenproblem backward
        A_grad = _symeig_backward(D_grad, U_grad, A, D, U, largest)

        # A has index 0
        grads[0] = A_grad
        # B has index 2
        grads[2] = B_grad
        return tuple(grads)
def lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Find the k largest (or smallest) eigenvalues and the corresponding
    eigenvectors of a symmetric positive definite generalized
    eigenvalue problem using matrix-free LOBPCG methods.

    This function is a front-end to the following LOBPCG algorithms
    selectable via `method` argument:

      `method="basic"` - the LOBPCG method introduced by Andrew
      Knyazev, see [Knyazev2001]. A less robust method, may fail when
      Cholesky is applied to singular input.

      `method="ortho"` - the LOBPCG method with orthogonal basis
      selection [StathopoulosEtal2002]. A robust method.

    Supported inputs are dense, sparse, and batches of dense matrices.

    .. note:: In general, the basic method spends least time per
      iteration. However, the robust methods converge much faster and
      are more stable. So, the usage of the basic method is generally
      not recommended but there exist cases where the usage of the
      basic method may be preferred.

    .. warning:: The backward method does not support sparse and complex inputs.
      It works only when `B` is not provided (i.e. `B == None`).
      We are actively working on extensions, and the details of
      the algorithms are going to be published promptly.

    .. warning:: While it is assumed that `A` is symmetric, `A.grad` is not.
      To make sure that `A.grad` is symmetric, so that `A - t * A.grad` is symmetric
      in first-order optimization routines, prior to running `lobpcg`
      we do the following symmetrization map: `A -> (A + A.t()) / 2`.
      The map is performed only when `A` or `B` requires gradients;
      `B` (when given) is symmetrized the same way.

    Args:

      A (Tensor): the input tensor of size :math:`(*, m, m)`

      B (Tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When not specified, `B` is interpreted as
                  identity matrix.

      X (tensor, optional): the input tensor of size :math:`(*, m, n)`
                  where `k <= n <= m`. When specified, it is used as
                  initial approximation of eigenvectors. X must be a
                  dense tensor. Note that the algorithm additionally
                  requires `m >= 3 * n`.

      iK (tensor, optional): the input tensor of size :math:`(*, m,
                  m)`. When specified, it will be used as preconditioner.

      k (integer, optional): the number of requested
                  eigenpairs. Default is the number of :math:`X`
                  columns (when specified) or `1`.

      n (integer, optional): if :math:`X` is not specified then `n`
                  specifies the size of the generated random
                  approximation of eigenvectors. Default value for `n`
                  is `k`. If :math:`X` is specified, the value of `n`
                  (when specified) must be the number of :math:`X`
                  columns.

      tol (float, optional): residual tolerance for stopping
                 criterion. Default is `feps ** 0.5` where `feps` is
                 the machine epsilon of the floating-point data type
                 used for the input tensor `A`.

      largest (bool, optional): when True, solve the eigenproblem for
                 the largest eigenvalues. Otherwise, solve the
                 eigenproblem for smallest eigenvalues. Default is
                 `True`.

      method (str, optional): select LOBPCG method. See the
                 description of the function above. Default is
                 "ortho".

      niter (int, optional): maximum number of iterations. When
                 reached, the iteration process is hard-stopped and
                 the current approximation of eigenpairs is returned.
                 For infinite iteration but until convergence criteria
                 is met, use `-1`. Default is `1000`.

      tracker (callable, optional) : a function for tracing the
                 iteration process. When specified, it is called at
                 each iteration step with LOBPCG instance as an
                 argument. The LOBPCG instance holds the full state of
                 the iteration process in the following attributes:

                   `iparams`, `fparams`, `bparams` - dictionaries of
                   integer, float, and boolean valued input
                   parameters, respectively

                   `ivars`, `fvars`, `bvars`, `tvars` - dictionaries
                   of integer, float, boolean, and Tensor valued
                   iteration variables, respectively.

                   `A`, `B`, `iK` - input Tensor arguments.

                   `E`, `X`, `S`, `R` - iteration Tensor variables.

                 For instance:

                   `ivars["istep"]` - the current iteration step
                   `X` - the current approximation of eigenvectors
                   `E` - the current approximation of eigenvalues
                   `R` - the current residual
                   `ivars["converged_count"]` - the current number of converged eigenpairs
                   `tvars["rerr"]` - the current state of convergence criteria

                 Note that when `tracker` stores Tensor objects from
                 the LOBPCG instance, it must make copies of these.

                 If `tracker` sets `bvars["force_stop"] = True`, the
                 iteration process will be hard-stopped.

      ortho_iparams, ortho_fparams, ortho_bparams (dict, optional):
                 various parameters to LOBPCG algorithm when using
                 `method="ortho"`.

    Returns:

      E (Tensor): tensor of eigenvalues of size :math:`(*, k)`

      X (Tensor): tensor of eigenvectors of size :math:`(*, m, k)`

    References:

      [Knyazev2001] Andrew V. Knyazev. (2001) Toward the Optimal
      Preconditioned Eigensolver: Locally Optimal Block Preconditioned
      Conjugate Gradient Method. SIAM J. Sci. Comput., 23(2),
      517-541. (25 pages)
      https://epubs.siam.org/doi/abs/10.1137/S1064827500366124

      [StathopoulosEtal2002] Andreas Stathopoulos and Kesheng
      Wu. (2002) A Block Orthogonalization Procedure with Constant
      Synchronization Requirements. SIAM J. Sci. Comput., 23(6),
      2165-2182. (18 pages)
      https://epubs.siam.org/doi/10.1137/S1064827500370883

      [DuerschEtal2018] Jed A. Duersch, Meiyue Shao, Chao Yang, Ming
      Gu. (2018) A Robust and Efficient Implementation of LOBPCG.
      SIAM J. Sci. Comput., 40(5), C655-C676. (22 pages)
      https://epubs.siam.org/doi/abs/10.1137/17M1129830

    """

    if not torch.jit.is_scripting():
        # Defer to __torch_function__ overrides of tensor subclasses.
        tensor_ops = (A, B, X, iK)
        if not set(map(type, tensor_ops)).issubset(
            (torch.Tensor, type(None))
        ) and has_torch_function(tensor_ops):
            return handle_torch_function(
                lobpcg,
                tensor_ops,
                A,
                k=k,
                B=B,
                X=X,
                n=n,
                iK=iK,
                niter=niter,
                tol=tol,
                largest=largest,
                method=method,
                tracker=tracker,
                ortho_iparams=ortho_iparams,
                ortho_fparams=ortho_fparams,
                ortho_bparams=ortho_bparams,
            )

    if not torch._jit_internal.is_scripting():
        if A.requires_grad or (B is not None and B.requires_grad):
            # While it is expected that `A` is symmetric,
            # the `A_grad` might be not. Therefore we perform the trick below,
            # so that `A_grad` becomes symmetric.
            # The symmetrization is important for first-order optimization methods,
            # so that (A - alpha * A_grad) is still a symmetric matrix.
            # Same holds for `B`.
            A_sym = (A + A.mT) / 2
            B_sym = (B + B.mT) / 2 if (B is not None) else None

            return LOBPCGAutogradFunction.apply(
                A_sym,
                k,
                B_sym,
                X,
                n,
                iK,
                niter,
                tol,
                largest,
                method,
                tracker,
                ortho_iparams,
                ortho_fparams,
                ortho_bparams,
            )
    else:
        if A.requires_grad or (B is not None and B.requires_grad):
            raise RuntimeError(
                "Script and require grads is not supported atm. "
                "If you just want to do the forward, use .detach() "
                "on A and B before calling into lobpcg"
            )

    return _lobpcg(
        A,
        k,
        B,
        X,
        n,
        iK,
        niter,
        tol,
        largest,
        method,
        tracker,
        ortho_iparams,
        ortho_fparams,
        ortho_bparams,
    )
def _lobpcg(
    A: Tensor,
    k: Optional[int] = None,
    B: Optional[Tensor] = None,
    X: Optional[Tensor] = None,
    n: Optional[int] = None,
    iK: Optional[Tensor] = None,
    niter: Optional[int] = None,
    tol: Optional[float] = None,
    largest: Optional[bool] = None,
    method: Optional[str] = None,
    tracker: None = None,
    ortho_iparams: Optional[Dict[str, int]] = None,
    ortho_fparams: Optional[Dict[str, float]] = None,
    ortho_bparams: Optional[Dict[str, bool]] = None,
) -> Tuple[Tensor, Tensor]:
    """Implementation behind `lobpcg`.

    Autograd and __torch_function__ dispatch are handled by the caller.
    This function validates shapes, fills in default parameters, then runs
    one `LOBPCG` worker per matrix, looping over the flattened batch
    dimensions when `A` has more than two dimensions. It returns the first
    `k` eigenvalues and eigenvectors.

    Raises:
        ValueError: if ``m < 3 * n``, where ``m`` is the size of `A` and ``n``
            the number of columns of the initial approximation.
    """
    # A must be square:
    assert A.shape[-2] == A.shape[-1], A.shape
    if B is not None:
        # A and B must have the same shapes:
        assert A.shape == B.shape, (A.shape, B.shape)

    dtype = _utils.get_floating_dtype(A)
    device = A.device
    if tol is None:
        # machine epsilon of float32 / float64
        feps = {torch.float32: 1.2e-07, torch.float64: 2.23e-16}[dtype]
        tol = feps**0.5

    m = A.shape[-1]
    k = (1 if X is None else X.shape[-1]) if k is None else k
    n = (k if n is None else n) if X is None else X.shape[-1]

    if m < 3 * n:
        raise ValueError(
            "LOBPCG algorithm is not applicable when the number of A rows (={})"
            " is smaller than 3 x the number of requested eigenpairs (={})".format(m, n)
        )

    method = "ortho" if method is None else method

    iparams = {
        "m": m,
        "n": n,
        "k": k,
        "niter": 1000 if niter is None else niter,
    }

    fparams = {
        "tol": tol,
    }

    bparams = {"largest": True if largest is None else largest}

    if method == "ortho":
        if ortho_iparams is not None:
            iparams.update(ortho_iparams)
        if ortho_fparams is not None:
            fparams.update(ortho_fparams)
        if ortho_bparams is not None:
            bparams.update(ortho_bparams)
        iparams["ortho_i_max"] = iparams.get("ortho_i_max", 3)
        iparams["ortho_j_max"] = iparams.get("ortho_j_max", 3)
        fparams["ortho_tol"] = fparams.get("ortho_tol", tol)
        fparams["ortho_tol_drop"] = fparams.get("ortho_tol_drop", tol)
        fparams["ortho_tol_replace"] = fparams.get("ortho_tol_replace", tol)
        bparams["ortho_use_drop"] = bparams.get("ortho_use_drop", False)

    # In Python mode, install the tracker-calling hook on the worker class; it
    # is swapped back to the original before returning.
    # NOTE(review): the hook is not restored if a worker raises (try/finally is
    # presumably avoided here to keep this function TorchScript-compatible).
    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker  # type: ignore[assignment]

    if len(A.shape) > 2:
        # Batched input: flatten the batch dimensions and solve each problem
        # independently.
        N = int(torch.prod(torch.tensor(A.shape[:-2])))
        bA = A.reshape((N,) + A.shape[-2:])
        bB = B.reshape((N,) + A.shape[-2:]) if B is not None else None
        bX = X.reshape((N,) + X.shape[-2:]) if X is not None else None
        bE = torch.empty((N, k), dtype=dtype, device=device)
        bXret = torch.empty((N, m, k), dtype=dtype, device=device)

        for i in range(N):
            A_ = bA[i]
            B_ = bB[i] if bB is not None else None
            X_ = (
                torch.randn((m, n), dtype=dtype, device=device) if bX is None else bX[i]
            )
            assert len(X_.shape) == 2 and X_.shape == (m, n), (X_.shape, (m, n))
            iparams["batch_index"] = i
            worker = LOBPCG(A_, B_, X_, iK, iparams, fparams, bparams, method, tracker)
            worker.run()
            bE[i] = worker.E[:k]
            bXret[i] = worker.X[:, :k]

        if not torch.jit.is_scripting():
            LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[assignment]

        return bE.reshape(A.shape[:-2] + (k,)), bXret.reshape(A.shape[:-2] + (m, k))

    X = torch.randn((m, n), dtype=dtype, device=device) if X is None else X
    assert len(X.shape) == 2 and X.shape == (m, n), (X.shape, (m, n))

    worker = LOBPCG(A, B, X, iK, iparams, fparams, bparams, method, tracker)

    worker.run()

    if not torch.jit.is_scripting():
        LOBPCG.call_tracker = LOBPCG_call_tracker_orig  # type: ignore[assignment]

    return worker.E[:k], worker.X[:, :k]
- class LOBPCG:
- """Worker class of LOBPCG methods."""
- def __init__(
- self,
- A: Optional[Tensor],
- B: Optional[Tensor],
- X: Tensor,
- iK: Optional[Tensor],
- iparams: Dict[str, int],
- fparams: Dict[str, float],
- bparams: Dict[str, bool],
- method: str,
- tracker: None,
- ) -> None:
- # constant parameters
- self.A = A
- self.B = B
- self.iK = iK
- self.iparams = iparams
- self.fparams = fparams
- self.bparams = bparams
- self.method = method
- self.tracker = tracker
- m = iparams["m"]
- n = iparams["n"]
- # variable parameters
- self.X = X
- self.E = torch.zeros((n,), dtype=X.dtype, device=X.device)
- self.R = torch.zeros((m, n), dtype=X.dtype, device=X.device)
- self.S = torch.zeros((m, 3 * n), dtype=X.dtype, device=X.device)
- self.tvars: Dict[str, Tensor] = {}
- self.ivars: Dict[str, int] = {"istep": 0}
- self.fvars: Dict[str, float] = {"_": 0.0}
- self.bvars: Dict[str, bool] = {"_": False}
- def __str__(self):
- lines = ["LOPBCG:"]
- lines += [" iparams={}".format(self.iparams)]
- lines += [" fparams={}".format(self.fparams)]
- lines += [" bparams={}".format(self.bparams)]
- lines += [" ivars={}".format(self.ivars)]
- lines += [" fvars={}".format(self.fvars)]
- lines += [" bvars={}".format(self.bvars)]
- lines += [" tvars={}".format(self.tvars)]
- lines += [" A={}".format(self.A)]
- lines += [" B={}".format(self.B)]
- lines += [" iK={}".format(self.iK)]
- lines += [" X={}".format(self.X)]
- lines += [" E={}".format(self.E)]
- r = ""
- for line in lines:
- r += line + "\n"
- return r
- def update(self):
- """Set and update iteration variables."""
- if self.ivars["istep"] == 0:
- X_norm = float(torch.norm(self.X))
- iX_norm = X_norm**-1
- A_norm = float(torch.norm(_utils.matmul(self.A, self.X))) * iX_norm
- B_norm = float(torch.norm(_utils.matmul(self.B, self.X))) * iX_norm
- self.fvars["X_norm"] = X_norm
- self.fvars["A_norm"] = A_norm
- self.fvars["B_norm"] = B_norm
- self.ivars["iterations_left"] = self.iparams["niter"]
- self.ivars["converged_count"] = 0
- self.ivars["converged_end"] = 0
- if self.method == "ortho":
- self._update_ortho()
- else:
- self._update_basic()
- self.ivars["iterations_left"] = self.ivars["iterations_left"] - 1
- self.ivars["istep"] = self.ivars["istep"] + 1
- def update_residual(self):
- """Update residual R from A, B, X, E."""
- mm = _utils.matmul
- self.R = mm(self.A, self.X) - mm(self.B, self.X) * self.E
- def update_converged_count(self):
- """Determine the number of converged eigenpairs using backward stable
- convergence criterion, see discussion in Sec 4.3 of [DuerschEtal2018].
- Users may redefine this method for custom convergence criteria.
- """
- # (...) -> int
- prev_count = self.ivars["converged_count"]
- tol = self.fparams["tol"]
- A_norm = self.fvars["A_norm"]
- B_norm = self.fvars["B_norm"]
- E, X, R = self.E, self.X, self.R
- rerr = (
- torch.norm(R, 2, (0,))
- * (torch.norm(X, 2, (0,)) * (A_norm + E[: X.shape[-1]] * B_norm)) ** -1
- )
- converged = rerr < tol
- count = 0
- for b in converged:
- if not b:
- # ignore convergence of following pairs to ensure
- # strict ordering of eigenpairs
- break
- count += 1
- assert count >= prev_count, (
- "the number of converged eigenpairs "
- "(was {}, got {}) cannot decrease".format(prev_count, count)
- )
- self.ivars["converged_count"] = count
- self.tvars["rerr"] = rerr
- return count
- def stop_iteration(self):
- """Return True to stop iterations.
- Note that tracker (if defined) can force-stop iterations by
- setting ``worker.bvars['force_stop'] = True``.
- """
- return (
- self.bvars.get("force_stop", False)
- or self.ivars["iterations_left"] == 0
- or self.ivars["converged_count"] >= self.iparams["k"]
- )
- def run(self):
- """Run LOBPCG iterations.
- Use this method as a template for implementing LOBPCG
- iteration scheme with custom tracker that is compatible with
- TorchScript.
- """
- self.update()
- if not torch.jit.is_scripting() and self.tracker is not None:
- self.call_tracker()
- while not self.stop_iteration():
- self.update()
- if not torch.jit.is_scripting() and self.tracker is not None:
- self.call_tracker()
    @torch.jit.unused
    def call_tracker(self):
        """Interface for tracking iteration process in Python mode.

        Tracking the iteration process is disabled in TorchScript
        mode. In fact, one should specify tracker=None when JIT
        compiling functions using lobpcg.

        This default implementation does nothing. `run` calls this hook
        after every `update` when not scripting and ``self.tracker`` is not
        None. In Python mode `_lobpcg` temporarily replaces it with
        ``LOBPCG_call_tracker`` (defined elsewhere in this module) and then
        restores ``LOBPCG_call_tracker_orig``.
        """
        # do nothing when in TorchScript mode
        pass
    # Internal methods

    def _update_basic(self):
        """
        Update or initialize iteration variables when `method == "basic"`.

        First step (``istep == 0``): a Rayleigh-Ritz step on ``X`` (using the
        transform from `_get_rayleigh_ritz_transform`) overwrites ``X`` and
        ``E`` in place. The residual and converged count are then refreshed,
        and ``S`` is filled as ``[X, W]`` with ``W = iK @ R``.

        Later steps: Rayleigh-Ritz on the active columns ``S[:, nc:ns]``
        updates the unconverged columns ``X[:, nc:]`` and ``E[nc:]`` and
        yields the ``P`` block. After refreshing the residual and converged
        count, ``S`` is rebuilt as ``[X, P, W]`` with ``W = iK @ R[:, nc:]``.

        In both cases ``ivars["converged_end"]`` records the end column of
        the filled part of ``S``.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]  # end column of the filled part of S
        nc = self.ivars["converged_count"]  # number of leading converged pairs
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            self.X[:] = mm(self.X, mm(Ri, Z))
            self.E[:] = E
            np = 0  # no P block exists on the first step
            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X

            W = _utils.matmul(self.iK, self.R)
            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
        else:
            # Rayleigh-Ritz restricted to the not-yet-converged part of S
            S_ = self.S[:, nc:ns]
            Ri = self._get_rayleigh_ritz_transform(S_)
            M = _utils.qform(_utils.qform(self.A, S_), Ri)
            E_, Z = _utils.symeig(M, largest)
            self.X[:, nc:] = mm(S_, mm(Ri, Z[:, : n - nc]))
            self.E[nc:] = E_[: n - nc]
            # P is built from the next block of Ritz vectors after the
            # n - nc used for X
            P = mm(S_, mm(Ri, Z[:, n : 2 * n - nc]))
            np = P.shape[-1]

            self.update_residual()
            nc = self.update_converged_count()
            self.S[..., :n] = self.X
            self.S[:, n : n + np] = P
            # only residuals of unconverged pairs extend the search space
            W = _utils.matmul(self.iK, self.R[:, nc:])

            self.ivars["converged_end"] = ns = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
    def _update_ortho(self):
        """
        Update or initialize iteration variables when `method == "ortho"`.

        First step (``istep == 0``): a Rayleigh-Ritz step on ``X`` replaces
        ``X``. The residual and converged count are then computed, and ``S``
        is filled as ``[X, W]`` with ``W = self._get_ortho(R, X)``.

        Later steps: a Rayleigh-Ritz step on the active columns
        ``S[:, nc:ns]`` updates the unconverged columns ``X[:, nc:]`` and
        ``E[nc:]`` and yields the ``P`` block. After refreshing the residual
        and converged count, ``S`` is rebuilt as ``[X, P, W]`` with
        ``W = self._get_ortho(R[:, nc:], S[:, :n + np])``.

        In both cases ``ivars["converged_end"]`` records the end column of
        the filled part of ``S``.
        """
        mm = torch.matmul
        ns = self.ivars["converged_end"]  # end column of the filled part of S
        nc = self.ivars["converged_count"]  # number of leading converged pairs
        n = self.iparams["n"]
        largest = self.bparams["largest"]

        if self.ivars["istep"] == 0:
            Ri = self._get_rayleigh_ritz_transform(self.X)
            M = _utils.qform(_utils.qform(self.A, self.X), Ri)
            E, Z = _utils.symeig(M, largest)
            # NOTE(review): unlike `_update_basic`, the Ritz values `E` are not
            # copied into `self.E` here. The residual computed just below
            # therefore uses the current `self.E` (presumably still the zeros
            # allocated in `__init__`), which reduces R to A X. Confirm
            # whether this is intentional.
            self.X = mm(self.X, mm(Ri, Z))
            self.update_residual()
            np = 0  # no P block exists on the first step
            nc = self.update_converged_count()
            self.S[:, :n] = self.X

            W = self._get_ortho(self.R, self.X)
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W

        else:
            S_ = self.S[:, nc:ns]
            # Rayleigh-Ritz procedure
            E_, Z = _utils.symeig(_utils.qform(self.A, S_), largest)

            # Update E, X, P
            self.X[:, nc:] = mm(S_, Z[:, : n - nc])
            self.E[nc:] = E_[: n - nc]
            # P is formed from the trailing Ritz vectors Z[:, n - nc:], combined
            # through `_utils.basis` (its exact semantics are not visible here)
            P = mm(
                S_,
                mm(
                    Z[:, n - nc :],
                    _utils.basis(_utils.transpose(Z[: n - nc, n - nc :])),
                ),
            )
            np = P.shape[-1]

            # check convergence
            self.update_residual()
            nc = self.update_converged_count()

            # update S
            self.S[:, :n] = self.X
            self.S[:, n : n + np] = P
            # only residuals of unconverged pairs extend the search space
            W = self._get_ortho(self.R[:, nc:], self.S[:, : n + np])
            ns = self.ivars["converged_end"] = n + np + W.shape[-1]
            self.S[:, n + np : ns] = W
- def _get_rayleigh_ritz_transform(self, S):
- """Return a transformation matrix that is used in Rayleigh-Ritz
- procedure for reducing a general eigenvalue problem :math:`(S^TAS)
- C = (S^TBS) C E` to a standard eigenvalue problem :math: `(Ri^T
- S^TAS Ri) Z = Z E` where `C = Ri Z`.
- .. note:: In the original Rayleight-Ritz procedure in
- [DuerschEtal2018], the problem is formulated as follows::
- SAS = S^T A S
- SBS = S^T B S
- D = (<diagonal matrix of SBS>) ** -1/2
- R^T R = Cholesky(D SBS D)
- Ri = D R^-1
- solve symeig problem Ri^T SAS Ri Z = Theta Z
- C = Ri Z
- To reduce the number of matrix products (denoted by empty
- space between matrices), here we introduce element-wise
- products (denoted by symbol `*`) so that the Rayleight-Ritz
- procedure becomes::
- SAS = S^T A S
- SBS = S^T B S
- d = (<diagonal of SBS>) ** -1/2 # this is 1-d column vector
- dd = d d^T # this is 2-d matrix
- R^T R = Cholesky(dd * SBS)
- Ri = R^-1 * d # broadcasting
- solve symeig problem Ri^T SAS Ri Z = Theta Z
- C = Ri Z
- where `dd` is 2-d matrix that replaces matrix products `D M
- D` with one element-wise product `M * dd`; and `d` replaces
- matrix product `D M` with element-wise product `M *
- d`. Also, creating the diagonal matrix `D` is avoided.
- Args:
- S (Tensor): the matrix basis for the search subspace, size is
- :math:`(m, n)`.
- Returns:
- Ri (tensor): upper-triangular transformation matrix of size
- :math:`(n, n)`.
- """
- B = self.B
- mm = torch.matmul
- SBS = _utils.qform(B, S)
- d_row = SBS.diagonal(0, -2, -1) ** -0.5
- d_col = d_row.reshape(d_row.shape[0], 1)
- # TODO use torch.linalg.cholesky_solve once it is implemented
- R = torch.linalg.cholesky((SBS * d_row) * d_col, upper=True)
- return torch.linalg.solve_triangular(
- R, d_row.diag_embed(), upper=True, left=False
- )
- def _get_svqb(
- self, U: Tensor, drop: bool, tau: float # Tensor # bool # float
- ) -> Tensor:
- """Return B-orthonormal U.
- .. note:: When `drop` is `False` then `svqb` is based on the
- Algorithm 4 from [DuerschPhD2015] that is a slight
- modification of the corresponding algorithm
- introduced in [StathopolousWu2002].
- Args:
- U (Tensor) : initial approximation, size is (m, n)
- drop (bool) : when True, drop columns that
- contribution to the `span([U])` is small.
- tau (float) : positive tolerance
- Returns:
- U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`), size
- is (m, n1), where `n1 = n` if `drop` is `False,
- otherwise `n1 <= n`.
- """
- if torch.numel(U) == 0:
- return U
- UBU = _utils.qform(self.B, U)
- d = UBU.diagonal(0, -2, -1)
- # Detect and drop exact zero columns from U. While the test
- # `abs(d) == 0` is unlikely to be True for random data, it is
- # possible to construct input data to lobpcg where it will be
- # True leading to a failure (notice the `d ** -0.5` operation
- # in the original algorithm). To prevent the failure, we drop
- # the exact zero columns here and then continue with the
- # original algorithm below.
- nz = torch.where(abs(d) != 0.0)
- assert len(nz) == 1, nz
- if len(nz[0]) < len(d):
- U = U[:, nz[0]]
- if torch.numel(U) == 0:
- return U
- UBU = _utils.qform(self.B, U)
- d = UBU.diagonal(0, -2, -1)
- nz = torch.where(abs(d) != 0.0)
- assert len(nz[0]) == len(d)
- # The original algorithm 4 from [DuerschPhD2015].
- d_col = (d**-0.5).reshape(d.shape[0], 1)
- DUBUD = (UBU * d_col) * _utils.transpose(d_col)
- E, Z = _utils.symeig(DUBUD)
- t = tau * abs(E).max()
- if drop:
- keep = torch.where(E > t)
- assert len(keep) == 1, keep
- E = E[keep[0]]
- Z = Z[:, keep[0]]
- d_col = d_col[keep[0]]
- else:
- E[(torch.where(E < t))[0]] = t
- return torch.matmul(U * _utils.transpose(d_col), Z * E**-0.5)
- def _get_ortho(self, U, V):
- """Return B-orthonormal U with columns are B-orthogonal to V.
- .. note:: When `bparams["ortho_use_drop"] == False` then
- `_get_ortho` is based on the Algorithm 3 from
- [DuerschPhD2015] that is a slight modification of
- the corresponding algorithm introduced in
- [StathopolousWu2002]. Otherwise, the method
- implements Algorithm 6 from [DuerschPhD2015]
- .. note:: If all U columns are B-collinear to V then the
- returned tensor U will be empty.
- Args:
- U (Tensor) : initial approximation, size is (m, n)
- V (Tensor) : B-orthogonal external basis, size is (m, k)
- Returns:
- U (Tensor) : B-orthonormal columns (:math:`U^T B U = I`)
- such that :math:`V^T B U=0`, size is (m, n1),
- where `n1 = n` if `drop` is `False, otherwise
- `n1 <= n`.
- """
- mm = torch.matmul
- mm_B = _utils.matmul
- m = self.iparams["m"]
- tau_ortho = self.fparams["ortho_tol"]
- tau_drop = self.fparams["ortho_tol_drop"]
- tau_replace = self.fparams["ortho_tol_replace"]
- i_max = self.iparams["ortho_i_max"]
- j_max = self.iparams["ortho_j_max"]
- # when use_drop==True, enable dropping U columns that have
- # small contribution to the `span([U, V])`.
- use_drop = self.bparams["ortho_use_drop"]
- # clean up variables from the previous call
- for vkey in list(self.fvars.keys()):
- if vkey.startswith("ortho_") and vkey.endswith("_rerr"):
- self.fvars.pop(vkey)
- self.ivars.pop("ortho_i", 0)
- self.ivars.pop("ortho_j", 0)
- BV_norm = torch.norm(mm_B(self.B, V))
- BU = mm_B(self.B, U)
- VBU = mm(_utils.transpose(V), BU)
- i = j = 0
- stats = ""
- for i in range(i_max):
- U = U - mm(V, VBU)
- drop = False
- tau_svqb = tau_drop
- for j in range(j_max):
- if use_drop:
- U = self._get_svqb(U, drop, tau_svqb)
- drop = True
- tau_svqb = tau_replace
- else:
- U = self._get_svqb(U, False, tau_replace)
- if torch.numel(U) == 0:
- # all initial U columns are B-collinear to V
- self.ivars["ortho_i"] = i
- self.ivars["ortho_j"] = j
- return U
- BU = mm_B(self.B, U)
- UBU = mm(_utils.transpose(U), BU)
- U_norm = torch.norm(U)
- BU_norm = torch.norm(BU)
- R = UBU - torch.eye(UBU.shape[-1], device=UBU.device, dtype=UBU.dtype)
- R_norm = torch.norm(R)
- # https://github.com/pytorch/pytorch/issues/33810 workaround:
- rerr = float(R_norm) * float(BU_norm * U_norm) ** -1
- vkey = "ortho_UBUmI_rerr[{}, {}]".format(i, j)
- self.fvars[vkey] = rerr
- if rerr < tau_ortho:
- break
- VBU = mm(_utils.transpose(V), BU)
- VBU_norm = torch.norm(VBU)
- U_norm = torch.norm(U)
- rerr = float(VBU_norm) * float(BV_norm * U_norm) ** -1
- vkey = "ortho_VBU_rerr[{}]".format(i)
- self.fvars[vkey] = rerr
- if rerr < tau_ortho:
- break
- if m < U.shape[-1] + V.shape[-1]:
- # TorchScript needs the class var to be assigned to a local to
- # do optional type refinement
- B = self.B
- assert B is not None
- raise ValueError(
- "Overdetermined shape of U:"
- " #B-cols(={}) >= #U-cols(={}) + #V-cols(={}) must hold".format(
- B.shape[-1], U.shape[-1], V.shape[-1]
- )
- )
- self.ivars["ortho_i"] = i
- self.ivars["ortho_j"] = j
- return U
# Tracker invocation lives outside the LOBPCG class body because
# TorchScript cannot handle user-defined callback arguments.
# Keep a handle on the in-class (no-op) call_tracker method:
LOBPCG_call_tracker_orig = LOBPCG.call_tracker


def LOBPCG_call_tracker(self):
    """Hand the solver instance to its user-supplied ``tracker``."""
    callback = self.tracker
    callback(self)
|