minres.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. from numpy import inner, zeros, inf, finfo
  2. from numpy.linalg import norm
  3. from math import sqrt
  4. from .utils import make_system
  5. __all__ = ['minres']
  6. def minres(A, b, x0=None, shift=0.0, tol=1e-5, maxiter=None,
  7. M=None, callback=None, show=False, check=False):
  8. """
  9. Use MINimum RESidual iteration to solve Ax=b
  10. MINRES minimizes norm(Ax - b) for a real symmetric matrix A. Unlike
  11. the Conjugate Gradient method, A can be indefinite or singular.
  12. If shift != 0 then the method solves (A - shift*I)x = b
  13. Parameters
  14. ----------
  15. A : {sparse matrix, ndarray, LinearOperator}
  16. The real symmetric N-by-N matrix of the linear system
  17. Alternatively, ``A`` can be a linear operator which can
  18. produce ``Ax`` using, e.g.,
  19. ``scipy.sparse.linalg.LinearOperator``.
  20. b : ndarray
  21. Right hand side of the linear system. Has shape (N,) or (N,1).
  22. Returns
  23. -------
  24. x : ndarray
  25. The converged solution.
  26. info : integer
  27. Provides convergence information:
  28. 0 : successful exit
  29. >0 : convergence to tolerance not achieved, number of iterations
  30. <0 : illegal input or breakdown
  31. Other Parameters
  32. ----------------
  33. x0 : ndarray
  34. Starting guess for the solution.
  35. shift : float
  36. Value to apply to the system ``(A - shift * I)x = b``. Default is 0.
  37. tol : float
  38. Tolerance to achieve. The algorithm terminates when the relative
  39. residual is below `tol`.
  40. maxiter : integer
  41. Maximum number of iterations. Iteration will stop after maxiter
  42. steps even if the specified tolerance has not been achieved.
  43. M : {sparse matrix, ndarray, LinearOperator}
  44. Preconditioner for A. The preconditioner should approximate the
  45. inverse of A. Effective preconditioning dramatically improves the
  46. rate of convergence, which implies that fewer iterations are needed
  47. to reach a given error tolerance.
  48. callback : function
  49. User-supplied function to call after each iteration. It is called
  50. as callback(xk), where xk is the current solution vector.
  51. show : bool
  52. If ``True``, print out a summary and metrics related to the solution
  53. during iterations. Default is ``False``.
  54. check : bool
  55. If ``True``, run additional input validation to check that `A` and
  56. `M` (if specified) are symmetric. Default is ``False``.
  57. Examples
  58. --------
  59. >>> import numpy as np
  60. >>> from scipy.sparse import csc_matrix
  61. >>> from scipy.sparse.linalg import minres
  62. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  63. >>> A = A + A.T
  64. >>> b = np.array([2, 4, -1], dtype=float)
  65. >>> x, exitCode = minres(A, b)
  66. >>> print(exitCode) # 0 indicates successful convergence
  67. 0
  68. >>> np.allclose(A.dot(x), b)
  69. True
  70. References
  71. ----------
  72. Solution of sparse indefinite systems of linear equations,
  73. C. C. Paige and M. A. Saunders (1975),
  74. SIAM J. Numer. Anal. 12(4), pp. 617-629.
  75. https://web.stanford.edu/group/SOL/software/minres/
  76. This file is a translation of the following MATLAB implementation:
  77. https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip
  78. """
  79. A, M, x, b, postprocess = make_system(A, M, x0, b)
  80. matvec = A.matvec
  81. psolve = M.matvec
  82. first = 'Enter minres. '
  83. last = 'Exit minres. '
  84. n = A.shape[0]
  85. if maxiter is None:
  86. maxiter = 5 * n
  87. msg = [' beta2 = 0. If M = I, b and x are eigenvectors ', # -1
  88. ' beta1 = 0. The exact solution is x0 ', # 0
  89. ' A solution to Ax = b was found, given rtol ', # 1
  90. ' A least-squares solution was found, given rtol ', # 2
  91. ' Reasonable accuracy achieved, given eps ', # 3
  92. ' x has converged to an eigenvector ', # 4
  93. ' acond has exceeded 0.1/eps ', # 5
  94. ' The iteration limit was reached ', # 6
  95. ' A does not define a symmetric matrix ', # 7
  96. ' M does not define a symmetric matrix ', # 8
  97. ' M does not define a pos-def preconditioner '] # 9
  98. if show:
  99. print(first + 'Solution of symmetric Ax = b')
  100. print(first + 'n = %3g shift = %23.14e' % (n,shift))
  101. print(first + 'itnlim = %3g rtol = %11.2e' % (maxiter,tol))
  102. print()
  103. istop = 0
  104. itn = 0
  105. Anorm = 0
  106. Acond = 0
  107. rnorm = 0
  108. ynorm = 0
  109. xtype = x.dtype
  110. eps = finfo(xtype).eps
  111. # Set up y and v for the first Lanczos vector v1.
  112. # y = beta1 P' v1, where P = C**(-1).
  113. # v is really P' v1.
  114. if x0 is None:
  115. r1 = b.copy()
  116. else:
  117. r1 = b - A@x
  118. y = psolve(r1)
  119. beta1 = inner(r1, y)
  120. if beta1 < 0:
  121. raise ValueError('indefinite preconditioner')
  122. elif beta1 == 0:
  123. return (postprocess(x), 0)
  124. bnorm = norm(b)
  125. if bnorm == 0:
  126. x = b
  127. return (postprocess(x), 0)
  128. beta1 = sqrt(beta1)
  129. if check:
  130. # are these too strict?
  131. # see if A is symmetric
  132. w = matvec(y)
  133. r2 = matvec(w)
  134. s = inner(w,w)
  135. t = inner(y,r2)
  136. z = abs(s - t)
  137. epsa = (s + eps) * eps**(1.0/3.0)
  138. if z > epsa:
  139. raise ValueError('non-symmetric matrix')
  140. # see if M is symmetric
  141. r2 = psolve(y)
  142. s = inner(y,y)
  143. t = inner(r1,r2)
  144. z = abs(s - t)
  145. epsa = (s + eps) * eps**(1.0/3.0)
  146. if z > epsa:
  147. raise ValueError('non-symmetric preconditioner')
  148. # Initialize other quantities
  149. oldb = 0
  150. beta = beta1
  151. dbar = 0
  152. epsln = 0
  153. qrnorm = beta1
  154. phibar = beta1
  155. rhs1 = beta1
  156. rhs2 = 0
  157. tnorm2 = 0
  158. gmax = 0
  159. gmin = finfo(xtype).max
  160. cs = -1
  161. sn = 0
  162. w = zeros(n, dtype=xtype)
  163. w2 = zeros(n, dtype=xtype)
  164. r2 = r1
  165. if show:
  166. print()
  167. print()
  168. print(' Itn x(1) Compatible LS norm(A) cond(A) gbar/|A|')
  169. while itn < maxiter:
  170. itn += 1
  171. s = 1.0/beta
  172. v = s*y
  173. y = matvec(v)
  174. y = y - shift * v
  175. if itn >= 2:
  176. y = y - (beta/oldb)*r1
  177. alfa = inner(v,y)
  178. y = y - (alfa/beta)*r2
  179. r1 = r2
  180. r2 = y
  181. y = psolve(r2)
  182. oldb = beta
  183. beta = inner(r2,y)
  184. if beta < 0:
  185. raise ValueError('non-symmetric matrix')
  186. beta = sqrt(beta)
  187. tnorm2 += alfa**2 + oldb**2 + beta**2
  188. if itn == 1:
  189. if beta/beta1 <= 10*eps:
  190. istop = -1 # Terminate later
  191. # Apply previous rotation Qk-1 to get
  192. # [deltak epslnk+1] = [cs sn][dbark 0 ]
  193. # [gbar k dbar k+1] [sn -cs][alfak betak+1].
  194. oldeps = epsln
  195. delta = cs * dbar + sn * alfa # delta1 = 0 deltak
  196. gbar = sn * dbar - cs * alfa # gbar 1 = alfa1 gbar k
  197. epsln = sn * beta # epsln2 = 0 epslnk+1
  198. dbar = - cs * beta # dbar 2 = beta2 dbar k+1
  199. root = norm([gbar, dbar])
  200. Arnorm = phibar * root
  201. # Compute the next plane rotation Qk
  202. gamma = norm([gbar, beta]) # gammak
  203. gamma = max(gamma, eps)
  204. cs = gbar / gamma # ck
  205. sn = beta / gamma # sk
  206. phi = cs * phibar # phik
  207. phibar = sn * phibar # phibark+1
  208. # Update x.
  209. denom = 1.0/gamma
  210. w1 = w2
  211. w2 = w
  212. w = (v - oldeps*w1 - delta*w2) * denom
  213. x = x + phi*w
  214. # Go round again.
  215. gmax = max(gmax, gamma)
  216. gmin = min(gmin, gamma)
  217. z = rhs1 / gamma
  218. rhs1 = rhs2 - delta*z
  219. rhs2 = - epsln*z
  220. # Estimate various norms and test for convergence.
  221. Anorm = sqrt(tnorm2)
  222. ynorm = norm(x)
  223. epsa = Anorm * eps
  224. epsx = Anorm * ynorm * eps
  225. epsr = Anorm * ynorm * tol
  226. diag = gbar
  227. if diag == 0:
  228. diag = epsa
  229. qrnorm = phibar
  230. rnorm = qrnorm
  231. if ynorm == 0 or Anorm == 0:
  232. test1 = inf
  233. else:
  234. test1 = rnorm / (Anorm*ynorm) # ||r|| / (||A|| ||x||)
  235. if Anorm == 0:
  236. test2 = inf
  237. else:
  238. test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
  239. # Estimate cond(A).
  240. # In this version we look at the diagonals of R in the
  241. # factorization of the lower Hessenberg matrix, Q @ H = R,
  242. # where H is the tridiagonal matrix from Lanczos with one
  243. # extra row, beta(k+1) e_k^T.
  244. Acond = gmax/gmin
  245. # See if any of the stopping criteria are satisfied.
  246. # In rare cases, istop is already -1 from above (Abar = const*I).
  247. if istop == 0:
  248. t1 = 1 + test1 # These tests work if tol < eps
  249. t2 = 1 + test2
  250. if t2 <= 1:
  251. istop = 2
  252. if t1 <= 1:
  253. istop = 1
  254. if itn >= maxiter:
  255. istop = 6
  256. if Acond >= 0.1/eps:
  257. istop = 4
  258. if epsx >= beta1:
  259. istop = 3
  260. # if rnorm <= epsx : istop = 2
  261. # if rnorm <= epsr : istop = 1
  262. if test2 <= tol:
  263. istop = 2
  264. if test1 <= tol:
  265. istop = 1
  266. # See if it is time to print something.
  267. prnt = False
  268. if n <= 40:
  269. prnt = True
  270. if itn <= 10:
  271. prnt = True
  272. if itn >= maxiter-10:
  273. prnt = True
  274. if itn % 10 == 0:
  275. prnt = True
  276. if qrnorm <= 10*epsx:
  277. prnt = True
  278. if qrnorm <= 10*epsr:
  279. prnt = True
  280. if Acond <= 1e-2/eps:
  281. prnt = True
  282. if istop != 0:
  283. prnt = True
  284. if show and prnt:
  285. str1 = '%6g %12.5e %10.3e' % (itn, x[0], test1)
  286. str2 = ' %10.3e' % (test2,)
  287. str3 = ' %8.1e %8.1e %8.1e' % (Anorm, Acond, gbar/Anorm)
  288. print(str1 + str2 + str3)
  289. if itn % 10 == 0:
  290. print()
  291. if callback is not None:
  292. callback(x)
  293. if istop != 0:
  294. break # TODO check this
  295. if show:
  296. print()
  297. print(last + ' istop = %3g itn =%5g' % (istop,itn))
  298. print(last + ' Anorm = %12.4e Acond = %12.4e' % (Anorm,Acond))
  299. print(last + ' rnorm = %12.4e ynorm = %12.4e' % (rnorm,ynorm))
  300. print(last + ' Arnorm = %12.4e' % (Arnorm,))
  301. print(last + msg[istop+1])
  302. if istop == 6:
  303. info = maxiter
  304. else:
  305. info = 0
  306. return (postprocess(x),info)
  307. if __name__ == '__main__':
  308. from numpy import arange
  309. from scipy.sparse import spdiags
  310. n = 10
  311. residuals = []
  312. def cb(x):
  313. residuals.append(norm(b - A@x))
  314. # A = poisson((10,),format='csr')
  315. A = spdiags([arange(1,n+1,dtype=float)], [0], n, n, format='csr')
  316. M = spdiags([1.0/arange(1,n+1,dtype=float)], [0], n, n, format='csr')
  317. A.psolve = M.matvec
  318. b = zeros(A.shape[0])
  319. x = minres(A,b,tol=1e-12,maxiter=None,callback=cb)
  320. # x = cg(A,b,x0=b,tol=1e-12,maxiter=None,callback=cb)[0]