# _gcrotmk.py
# Copyright (C) 2015, Pauli Virtanen <pav@iki.fi>
# Distributed under the same license as SciPy.
  3. import warnings
  4. import numpy as np
  5. from numpy.linalg import LinAlgError
  6. from scipy.linalg import (get_blas_funcs, qr, solve, svd, qr_insert, lstsq)
  7. from scipy.sparse.linalg._isolve.utils import make_system
  8. __all__ = ['gcrotmk']
  9. def _fgmres(matvec, v0, m, atol, lpsolve=None, rpsolve=None, cs=(), outer_v=(),
  10. prepend_outer_v=False):
  11. """
  12. FGMRES Arnoldi process, with optional projection or augmentation
  13. Parameters
  14. ----------
  15. matvec : callable
  16. Operation A*x
  17. v0 : ndarray
  18. Initial vector, normalized to nrm2(v0) == 1
  19. m : int
  20. Number of GMRES rounds
  21. atol : float
  22. Absolute tolerance for early exit
  23. lpsolve : callable
  24. Left preconditioner L
  25. rpsolve : callable
  26. Right preconditioner R
  27. cs : list of (ndarray, ndarray)
  28. Columns of matrices C and U in GCROT
  29. outer_v : list of ndarrays
  30. Augmentation vectors in LGMRES
  31. prepend_outer_v : bool, optional
  32. Whether augmentation vectors come before or after
  33. Krylov iterates
  34. Raises
  35. ------
  36. LinAlgError
  37. If nans encountered
  38. Returns
  39. -------
  40. Q, R : ndarray
  41. QR decomposition of the upper Hessenberg H=QR
  42. B : ndarray
  43. Projections corresponding to matrix C
  44. vs : list of ndarray
  45. Columns of matrix V
  46. zs : list of ndarray
  47. Columns of matrix Z
  48. y : ndarray
  49. Solution to ||H y - e_1||_2 = min!
  50. res : float
  51. The final (preconditioned) residual norm
  52. """
  53. if lpsolve is None:
  54. lpsolve = lambda x: x
  55. if rpsolve is None:
  56. rpsolve = lambda x: x
  57. axpy, dot, scal, nrm2 = get_blas_funcs(['axpy', 'dot', 'scal', 'nrm2'], (v0,))
  58. vs = [v0]
  59. zs = []
  60. y = None
  61. res = np.nan
  62. m = m + len(outer_v)
  63. # Orthogonal projection coefficients
  64. B = np.zeros((len(cs), m), dtype=v0.dtype)
  65. # H is stored in QR factorized form
  66. Q = np.ones((1, 1), dtype=v0.dtype)
  67. R = np.zeros((1, 0), dtype=v0.dtype)
  68. eps = np.finfo(v0.dtype).eps
  69. breakdown = False
  70. # FGMRES Arnoldi process
  71. for j in range(m):
  72. # L A Z = C B + V H
  73. if prepend_outer_v and j < len(outer_v):
  74. z, w = outer_v[j]
  75. elif prepend_outer_v and j == len(outer_v):
  76. z = rpsolve(v0)
  77. w = None
  78. elif not prepend_outer_v and j >= m - len(outer_v):
  79. z, w = outer_v[j - (m - len(outer_v))]
  80. else:
  81. z = rpsolve(vs[-1])
  82. w = None
  83. if w is None:
  84. w = lpsolve(matvec(z))
  85. else:
  86. # w is clobbered below
  87. w = w.copy()
  88. w_norm = nrm2(w)
  89. # GCROT projection: L A -> (1 - C C^H) L A
  90. # i.e. orthogonalize against C
  91. for i, c in enumerate(cs):
  92. alpha = dot(c, w)
  93. B[i,j] = alpha
  94. w = axpy(c, w, c.shape[0], -alpha) # w -= alpha*c
  95. # Orthogonalize against V
  96. hcur = np.zeros(j+2, dtype=Q.dtype)
  97. for i, v in enumerate(vs):
  98. alpha = dot(v, w)
  99. hcur[i] = alpha
  100. w = axpy(v, w, v.shape[0], -alpha) # w -= alpha*v
  101. hcur[i+1] = nrm2(w)
  102. with np.errstate(over='ignore', divide='ignore'):
  103. # Careful with denormals
  104. alpha = 1/hcur[-1]
  105. if np.isfinite(alpha):
  106. w = scal(alpha, w)
  107. if not (hcur[-1] > eps * w_norm):
  108. # w essentially in the span of previous vectors,
  109. # or we have nans. Bail out after updating the QR
  110. # solution.
  111. breakdown = True
  112. vs.append(w)
  113. zs.append(z)
  114. # Arnoldi LSQ problem
  115. # Add new column to H=Q@R, padding other columns with zeros
  116. Q2 = np.zeros((j+2, j+2), dtype=Q.dtype, order='F')
  117. Q2[:j+1,:j+1] = Q
  118. Q2[j+1,j+1] = 1
  119. R2 = np.zeros((j+2, j), dtype=R.dtype, order='F')
  120. R2[:j+1,:] = R
  121. Q, R = qr_insert(Q2, R2, hcur, j, which='col',
  122. overwrite_qru=True, check_finite=False)
  123. # Transformed least squares problem
  124. # || Q R y - inner_res_0 * e_1 ||_2 = min!
  125. # Since R = [R'; 0], solution is y = inner_res_0 (R')^{-1} (Q^H)[:j,0]
  126. # Residual is immediately known
  127. res = abs(Q[0,-1])
  128. # Check for termination
  129. if res < atol or breakdown:
  130. break
  131. if not np.isfinite(R[j,j]):
  132. # nans encountered, bail out
  133. raise LinAlgError()
  134. # -- Get the LSQ problem solution
  135. # The problem is triangular, but the condition number may be
  136. # bad (or in case of breakdown the last diagonal entry may be
  137. # zero), so use lstsq instead of trtrs.
  138. y, _, _, _, = lstsq(R[:j+1,:j+1], Q[0,:j+1].conj())
  139. B = B[:,:j+1]
  140. return Q, R, B, vs, zs, y, res
  141. def gcrotmk(A, b, x0=None, tol=1e-5, maxiter=1000, M=None, callback=None,
  142. m=20, k=None, CU=None, discard_C=False, truncate='oldest',
  143. atol=None):
  144. """
  145. Solve a matrix equation using flexible GCROT(m,k) algorithm.
  146. Parameters
  147. ----------
  148. A : {sparse matrix, ndarray, LinearOperator}
  149. The real or complex N-by-N matrix of the linear system.
  150. Alternatively, ``A`` can be a linear operator which can
  151. produce ``Ax`` using, e.g.,
  152. ``scipy.sparse.linalg.LinearOperator``.
  153. b : ndarray
  154. Right hand side of the linear system. Has shape (N,) or (N,1).
  155. x0 : ndarray
  156. Starting guess for the solution.
  157. tol, atol : float, optional
  158. Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
  159. The default for ``atol`` is `tol`.
  160. .. warning::
  161. The default value for `atol` will be changed in a future release.
  162. For future compatibility, specify `atol` explicitly.
  163. maxiter : int, optional
  164. Maximum number of iterations. Iteration will stop after maxiter
  165. steps even if the specified tolerance has not been achieved.
  166. M : {sparse matrix, ndarray, LinearOperator}, optional
  167. Preconditioner for A. The preconditioner should approximate the
  168. inverse of A. gcrotmk is a 'flexible' algorithm and the preconditioner
  169. can vary from iteration to iteration. Effective preconditioning
  170. dramatically improves the rate of convergence, which implies that
  171. fewer iterations are needed to reach a given error tolerance.
  172. callback : function, optional
  173. User-supplied function to call after each iteration. It is called
  174. as callback(xk), where xk is the current solution vector.
  175. m : int, optional
  176. Number of inner FGMRES iterations per each outer iteration.
  177. Default: 20
  178. k : int, optional
  179. Number of vectors to carry between inner FGMRES iterations.
  180. According to [2]_, good values are around m.
  181. Default: m
  182. CU : list of tuples, optional
  183. List of tuples ``(c, u)`` which contain the columns of the matrices
  184. C and U in the GCROT(m,k) algorithm. For details, see [2]_.
  185. The list given and vectors contained in it are modified in-place.
  186. If not given, start from empty matrices. The ``c`` elements in the
  187. tuples can be ``None``, in which case the vectors are recomputed
  188. via ``c = A u`` on start and orthogonalized as described in [3]_.
  189. discard_C : bool, optional
  190. Discard the C-vectors at the end. Useful if recycling Krylov subspaces
  191. for different linear systems.
  192. truncate : {'oldest', 'smallest'}, optional
  193. Truncation scheme to use. Drop: oldest vectors, or vectors with
  194. smallest singular values using the scheme discussed in [1,2].
  195. See [2]_ for detailed comparison.
  196. Default: 'oldest'
  197. Returns
  198. -------
  199. x : ndarray
  200. The solution found.
  201. info : int
  202. Provides convergence information:
  203. * 0 : successful exit
  204. * >0 : convergence to tolerance not achieved, number of iterations
  205. Examples
  206. --------
  207. >>> import numpy as np
  208. >>> from scipy.sparse import csc_matrix
  209. >>> from scipy.sparse.linalg import gcrotmk
  210. >>> R = np.random.randn(5, 5)
  211. >>> A = csc_matrix(R)
  212. >>> b = np.random.randn(5)
  213. >>> x, exit_code = gcrotmk(A, b)
  214. >>> print(exit_code)
  215. 0
  216. >>> np.allclose(A.dot(x), b)
  217. True
  218. References
  219. ----------
  220. .. [1] E. de Sturler, ''Truncation strategies for optimal Krylov subspace
  221. methods'', SIAM J. Numer. Anal. 36, 864 (1999).
  222. .. [2] J.E. Hicken and D.W. Zingg, ''A simplified and flexible variant
  223. of GCROT for solving nonsymmetric linear systems'',
  224. SIAM J. Sci. Comput. 32, 172 (2010).
  225. .. [3] M.L. Parks, E. de Sturler, G. Mackey, D.D. Johnson, S. Maiti,
  226. ''Recycling Krylov subspaces for sequences of linear systems'',
  227. SIAM J. Sci. Comput. 28, 1651 (2006).
  228. """
  229. A,M,x,b,postprocess = make_system(A,M,x0,b)
  230. if not np.isfinite(b).all():
  231. raise ValueError("RHS must contain only finite numbers")
  232. if truncate not in ('oldest', 'smallest'):
  233. raise ValueError("Invalid value for 'truncate': %r" % (truncate,))
  234. if atol is None:
  235. warnings.warn("scipy.sparse.linalg.gcrotmk called without specifying `atol`. "
  236. "The default value will change in the future. To preserve "
  237. "current behavior, set ``atol=tol``.",
  238. category=DeprecationWarning, stacklevel=2)
  239. atol = tol
  240. matvec = A.matvec
  241. psolve = M.matvec
  242. if CU is None:
  243. CU = []
  244. if k is None:
  245. k = m
  246. axpy, dot, scal = None, None, None
  247. if x0 is None:
  248. r = b.copy()
  249. else:
  250. r = b - matvec(x)
  251. axpy, dot, scal, nrm2 = get_blas_funcs(['axpy', 'dot', 'scal', 'nrm2'], (x, r))
  252. b_norm = nrm2(b)
  253. if b_norm == 0:
  254. x = b
  255. return (postprocess(x), 0)
  256. if discard_C:
  257. CU[:] = [(None, u) for c, u in CU]
  258. # Reorthogonalize old vectors
  259. if CU:
  260. # Sort already existing vectors to the front
  261. CU.sort(key=lambda cu: cu[0] is not None)
  262. # Fill-in missing ones
  263. C = np.empty((A.shape[0], len(CU)), dtype=r.dtype, order='F')
  264. us = []
  265. j = 0
  266. while CU:
  267. # More memory-efficient: throw away old vectors as we go
  268. c, u = CU.pop(0)
  269. if c is None:
  270. c = matvec(u)
  271. C[:,j] = c
  272. j += 1
  273. us.append(u)
  274. # Orthogonalize
  275. Q, R, P = qr(C, overwrite_a=True, mode='economic', pivoting=True)
  276. del C
  277. # C := Q
  278. cs = list(Q.T)
  279. # U := U P R^-1, back-substitution
  280. new_us = []
  281. for j in range(len(cs)):
  282. u = us[P[j]]
  283. for i in range(j):
  284. u = axpy(us[P[i]], u, u.shape[0], -R[i,j])
  285. if abs(R[j,j]) < 1e-12 * abs(R[0,0]):
  286. # discard rest of the vectors
  287. break
  288. u = scal(1.0/R[j,j], u)
  289. new_us.append(u)
  290. # Form the new CU lists
  291. CU[:] = list(zip(cs, new_us))[::-1]
  292. if CU:
  293. axpy, dot = get_blas_funcs(['axpy', 'dot'], (r,))
  294. # Solve first the projection operation with respect to the CU
  295. # vectors. This corresponds to modifying the initial guess to
  296. # be
  297. #
  298. # x' = x + U y
  299. # y = argmin_y || b - A (x + U y) ||^2
  300. #
  301. # The solution is y = C^H (b - A x)
  302. for c, u in CU:
  303. yc = dot(c, r)
  304. x = axpy(u, x, x.shape[0], yc)
  305. r = axpy(c, r, r.shape[0], -yc)
  306. # GCROT main iteration
  307. for j_outer in range(maxiter):
  308. # -- callback
  309. if callback is not None:
  310. callback(x)
  311. beta = nrm2(r)
  312. # -- check stopping condition
  313. beta_tol = max(atol, tol * b_norm)
  314. if beta <= beta_tol and (j_outer > 0 or CU):
  315. # recompute residual to avoid rounding error
  316. r = b - matvec(x)
  317. beta = nrm2(r)
  318. if beta <= beta_tol:
  319. j_outer = -1
  320. break
  321. ml = m + max(k - len(CU), 0)
  322. cs = [c for c, u in CU]
  323. try:
  324. Q, R, B, vs, zs, y, pres = _fgmres(matvec,
  325. r/beta,
  326. ml,
  327. rpsolve=psolve,
  328. atol=max(atol, tol*b_norm)/beta,
  329. cs=cs)
  330. y *= beta
  331. except LinAlgError:
  332. # Floating point over/underflow, non-finite result from
  333. # matmul etc. -- report failure.
  334. break
  335. #
  336. # At this point,
  337. #
  338. # [A U, A Z] = [C, V] G; G = [ I B ]
  339. # [ 0 H ]
  340. #
  341. # where [C, V] has orthonormal columns, and r = beta v_0. Moreover,
  342. #
  343. # || b - A (x + Z y + U q) ||_2 = || r - C B y - V H y - C q ||_2 = min!
  344. #
  345. # from which y = argmin_y || beta e_1 - H y ||_2, and q = -B y
  346. #
  347. #
  348. # GCROT(m,k) update
  349. #
  350. # Define new outer vectors
  351. # ux := (Z - U B) y
  352. ux = zs[0]*y[0]
  353. for z, yc in zip(zs[1:], y[1:]):
  354. ux = axpy(z, ux, ux.shape[0], yc) # ux += z*yc
  355. by = B.dot(y)
  356. for cu, byc in zip(CU, by):
  357. c, u = cu
  358. ux = axpy(u, ux, ux.shape[0], -byc) # ux -= u*byc
  359. # cx := V H y
  360. hy = Q.dot(R.dot(y))
  361. cx = vs[0] * hy[0]
  362. for v, hyc in zip(vs[1:], hy[1:]):
  363. cx = axpy(v, cx, cx.shape[0], hyc) # cx += v*hyc
  364. # Normalize cx, maintaining cx = A ux
  365. # This new cx is orthogonal to the previous C, by construction
  366. try:
  367. alpha = 1/nrm2(cx)
  368. if not np.isfinite(alpha):
  369. raise FloatingPointError()
  370. except (FloatingPointError, ZeroDivisionError):
  371. # Cannot update, so skip it
  372. continue
  373. cx = scal(alpha, cx)
  374. ux = scal(alpha, ux)
  375. # Update residual and solution
  376. gamma = dot(cx, r)
  377. r = axpy(cx, r, r.shape[0], -gamma) # r -= gamma*cx
  378. x = axpy(ux, x, x.shape[0], gamma) # x += gamma*ux
  379. # Truncate CU
  380. if truncate == 'oldest':
  381. while len(CU) >= k and CU:
  382. del CU[0]
  383. elif truncate == 'smallest':
  384. if len(CU) >= k and CU:
  385. # cf. [1,2]
  386. D = solve(R[:-1,:].T, B.T).T
  387. W, sigma, V = svd(D)
  388. # C := C W[:,:k-1], U := U W[:,:k-1]
  389. new_CU = []
  390. for j, w in enumerate(W[:,:k-1].T):
  391. c, u = CU[0]
  392. c = c * w[0]
  393. u = u * w[0]
  394. for cup, wp in zip(CU[1:], w[1:]):
  395. cp, up = cup
  396. c = axpy(cp, c, c.shape[0], wp)
  397. u = axpy(up, u, u.shape[0], wp)
  398. # Reorthogonalize at the same time; not necessary
  399. # in exact arithmetic, but floating point error
  400. # tends to accumulate here
  401. for cp, up in new_CU:
  402. alpha = dot(cp, c)
  403. c = axpy(cp, c, c.shape[0], -alpha)
  404. u = axpy(up, u, u.shape[0], -alpha)
  405. alpha = nrm2(c)
  406. c = scal(1.0/alpha, c)
  407. u = scal(1.0/alpha, u)
  408. new_CU.append((c, u))
  409. CU[:] = new_CU
  410. # Add new vector to CU
  411. CU.append((cx, ux))
  412. # Include the solution vector to the span
  413. CU.append((None, x.copy()))
  414. if discard_C:
  415. CU[:] = [(None, uz) for cz, uz in CU]
  416. return postprocess(x), j_outer + 1