iterative.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881
  1. """Iterative methods for solving linear systems"""
  2. __all__ = ['bicg','bicgstab','cg','cgs','gmres','qmr']
  3. import warnings
  4. from textwrap import dedent
  5. import numpy as np
  6. from . import _iterative
  7. from scipy.sparse.linalg._interface import LinearOperator
  8. from .utils import make_system
  9. from scipy._lib._util import _aligned_zeros
  10. from scipy._lib._threadsafety import non_reentrant
  11. _type_conv = {'f':'s', 'd':'d', 'F':'c', 'D':'z'}
  12. # Part of the docstring common to all iterative solvers
  13. common_doc1 = \
  14. """
  15. Parameters
  16. ----------
  17. A : {sparse matrix, ndarray, LinearOperator}"""
  18. common_doc2 = \
  19. """b : ndarray
  20. Right hand side of the linear system. Has shape (N,) or (N,1).
  21. Returns
  22. -------
  23. x : ndarray
  24. The converged solution.
  25. info : integer
  26. Provides convergence information:
  27. 0 : successful exit
  28. >0 : convergence to tolerance not achieved, number of iterations
  29. <0 : illegal input or breakdown
  30. Other Parameters
  31. ----------------
  32. x0 : ndarray
  33. Starting guess for the solution.
  34. tol, atol : float, optional
  35. Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
  36. The default for ``atol`` is ``'legacy'``, which emulates
  37. a different legacy behavior.
  38. .. warning::
  39. The default value for `atol` will be changed in a future release.
  40. For future compatibility, specify `atol` explicitly.
  41. maxiter : integer
  42. Maximum number of iterations. Iteration will stop after maxiter
  43. steps even if the specified tolerance has not been achieved.
  44. M : {sparse matrix, ndarray, LinearOperator}
  45. Preconditioner for A. The preconditioner should approximate the
  46. inverse of A. Effective preconditioning dramatically improves the
  47. rate of convergence, which implies that fewer iterations are needed
  48. to reach a given error tolerance.
  49. callback : function
  50. User-supplied function to call after each iteration. It is called
  51. as callback(xk), where xk is the current solution vector.
  52. """
  53. def _stoptest(residual, atol):
  54. """
  55. Successful termination condition for the solvers.
  56. """
  57. resid = np.linalg.norm(residual)
  58. if resid <= atol:
  59. return resid, 1
  60. else:
  61. return resid, 0
  62. def _get_atol(tol, atol, bnrm2, get_residual, routine_name):
  63. """
  64. Parse arguments for absolute tolerance in termination condition.
  65. Parameters
  66. ----------
  67. tol, atol : object
  68. The arguments passed into the solver routine by user.
  69. bnrm2 : float
  70. 2-norm of the rhs vector.
  71. get_residual : callable
  72. Callable ``get_residual()`` that returns the initial value of
  73. the residual.
  74. routine_name : str
  75. Name of the routine.
  76. """
  77. if atol is None:
  78. warnings.warn("scipy.sparse.linalg.{name} called without specifying `atol`. "
  79. "The default value will be changed in a future release. "
  80. "For compatibility, specify a value for `atol` explicitly, e.g., "
  81. "``{name}(..., atol=0)``, or to retain the old behavior "
  82. "``{name}(..., atol='legacy')``".format(name=routine_name),
  83. category=DeprecationWarning, stacklevel=4)
  84. atol = 'legacy'
  85. tol = float(tol)
  86. if atol == 'legacy':
  87. # emulate old legacy behavior
  88. resid = get_residual()
  89. if resid <= tol:
  90. return 'exit'
  91. if bnrm2 == 0:
  92. return tol
  93. else:
  94. return tol * float(bnrm2)
  95. else:
  96. return max(float(atol), tol * float(bnrm2))
  97. def set_docstring(header, Ainfo, footer='', atol_default='0'):
  98. def combine(fn):
  99. fn.__doc__ = '\n'.join((header, common_doc1,
  100. ' ' + Ainfo.replace('\n', '\n '),
  101. common_doc2, dedent(footer)))
  102. return fn
  103. return combine
  104. @set_docstring('Use BIConjugate Gradient iteration to solve ``Ax = b``.',
  105. 'The real or complex N-by-N matrix of the linear system.\n'
  106. 'Alternatively, ``A`` can be a linear operator which can\n'
  107. 'produce ``Ax`` and ``A^T x`` using, e.g.,\n'
  108. '``scipy.sparse.linalg.LinearOperator``.',
  109. footer="""\
  110. Examples
  111. --------
  112. >>> import numpy as np
  113. >>> from scipy.sparse import csc_matrix
  114. >>> from scipy.sparse.linalg import bicg
  115. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  116. >>> b = np.array([2, 4, -1], dtype=float)
  117. >>> x, exitCode = bicg(A, b)
  118. >>> print(exitCode) # 0 indicates successful convergence
  119. 0
  120. >>> np.allclose(A.dot(x), b)
  121. True
  122. """
  123. )
  124. @non_reentrant()
  125. def bicg(A, b, x0=None, tol=1e-5, maxiter=None, M=None, callback=None, atol=None):
  126. A,M,x,b,postprocess = make_system(A, M, x0, b)
  127. n = len(b)
  128. if maxiter is None:
  129. maxiter = n*10
  130. matvec, rmatvec = A.matvec, A.rmatvec
  131. psolve, rpsolve = M.matvec, M.rmatvec
  132. ltr = _type_conv[x.dtype.char]
  133. revcom = getattr(_iterative, ltr + 'bicgrevcom')
  134. get_residual = lambda: np.linalg.norm(matvec(x) - b)
  135. atol = _get_atol(tol, atol, np.linalg.norm(b), get_residual, 'bicg')
  136. if atol == 'exit':
  137. return postprocess(x), 0
  138. resid = atol
  139. ndx1 = 1
  140. ndx2 = -1
  141. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  142. work = _aligned_zeros(6*n,dtype=x.dtype)
  143. ijob = 1
  144. info = 0
  145. ftflag = True
  146. iter_ = maxiter
  147. while True:
  148. olditer = iter_
  149. x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  150. revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
  151. if callback is not None and iter_ > olditer:
  152. callback(x)
  153. slice1 = slice(ndx1-1, ndx1-1+n)
  154. slice2 = slice(ndx2-1, ndx2-1+n)
  155. if (ijob == -1):
  156. if callback is not None:
  157. callback(x)
  158. break
  159. elif (ijob == 1):
  160. work[slice2] *= sclr2
  161. work[slice2] += sclr1*matvec(work[slice1])
  162. elif (ijob == 2):
  163. work[slice2] *= sclr2
  164. work[slice2] += sclr1*rmatvec(work[slice1])
  165. elif (ijob == 3):
  166. work[slice1] = psolve(work[slice2])
  167. elif (ijob == 4):
  168. work[slice1] = rpsolve(work[slice2])
  169. elif (ijob == 5):
  170. work[slice2] *= sclr2
  171. work[slice2] += sclr1*matvec(x)
  172. elif (ijob == 6):
  173. if ftflag:
  174. info = -1
  175. ftflag = False
  176. resid, info = _stoptest(work[slice1], atol)
  177. ijob = 2
  178. if info > 0 and iter_ == maxiter and not (resid <= atol):
  179. # info isn't set appropriately otherwise
  180. info = iter_
  181. return postprocess(x), info
  182. @set_docstring('Use BIConjugate Gradient STABilized iteration to solve '
  183. '``Ax = b``.',
  184. 'The real or complex N-by-N matrix of the linear system.\n'
  185. 'Alternatively, ``A`` can be a linear operator which can\n'
  186. 'produce ``Ax`` using, e.g.,\n'
  187. '``scipy.sparse.linalg.LinearOperator``.',
  188. footer="""\
  189. Examples
  190. --------
  191. >>> import numpy as np
  192. >>> from scipy.sparse import csc_matrix
  193. >>> from scipy.sparse.linalg import bicgstab
  194. >>> R = np.array([[4, 2, 0, 1],
  195. ... [3, 0, 0, 2],
  196. ... [0, 1, 1, 1],
  197. ... [0, 2, 1, 0]])
  198. >>> A = csc_matrix(R)
  199. >>> b = np.array([-1, -0.5, -1, 2])
  200. >>> x, exit_code = bicgstab(A, b)
  201. >>> print(exit_code) # 0 indicates successful convergence
  202. 0
  203. >>> np.allclose(A.dot(x), b)
  204. True
  205. """)
  206. @non_reentrant()
  207. def bicgstab(A, b, x0=None, tol=1e-5, maxiter=None, M=None, callback=None, atol=None):
  208. A, M, x, b, postprocess = make_system(A, M, x0, b)
  209. n = len(b)
  210. if maxiter is None:
  211. maxiter = n*10
  212. matvec = A.matvec
  213. psolve = M.matvec
  214. ltr = _type_conv[x.dtype.char]
  215. revcom = getattr(_iterative, ltr + 'bicgstabrevcom')
  216. get_residual = lambda: np.linalg.norm(matvec(x) - b)
  217. atol = _get_atol(tol, atol, np.linalg.norm(b), get_residual, 'bicgstab')
  218. if atol == 'exit':
  219. return postprocess(x), 0
  220. resid = atol
  221. ndx1 = 1
  222. ndx2 = -1
  223. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  224. work = _aligned_zeros(7*n,dtype=x.dtype)
  225. ijob = 1
  226. info = 0
  227. ftflag = True
  228. iter_ = maxiter
  229. while True:
  230. olditer = iter_
  231. x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  232. revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
  233. if callback is not None and iter_ > olditer:
  234. callback(x)
  235. slice1 = slice(ndx1-1, ndx1-1+n)
  236. slice2 = slice(ndx2-1, ndx2-1+n)
  237. if (ijob == -1):
  238. if callback is not None:
  239. callback(x)
  240. break
  241. elif (ijob == 1):
  242. work[slice2] *= sclr2
  243. work[slice2] += sclr1*matvec(work[slice1])
  244. elif (ijob == 2):
  245. work[slice1] = psolve(work[slice2])
  246. elif (ijob == 3):
  247. work[slice2] *= sclr2
  248. work[slice2] += sclr1*matvec(x)
  249. elif (ijob == 4):
  250. if ftflag:
  251. info = -1
  252. ftflag = False
  253. resid, info = _stoptest(work[slice1], atol)
  254. ijob = 2
  255. if info > 0 and iter_ == maxiter and not (resid <= atol):
  256. # info isn't set appropriately otherwise
  257. info = iter_
  258. return postprocess(x), info
  259. @set_docstring('Use Conjugate Gradient iteration to solve ``Ax = b``.',
  260. 'The real or complex N-by-N matrix of the linear system.\n'
  261. '``A`` must represent a hermitian, positive definite matrix.\n'
  262. 'Alternatively, ``A`` can be a linear operator which can\n'
  263. 'produce ``Ax`` using, e.g.,\n'
  264. '``scipy.sparse.linalg.LinearOperator``.',
  265. footer="""\
  266. Examples
  267. --------
  268. >>> import numpy as np
  269. >>> from scipy.sparse import csc_matrix
  270. >>> from scipy.sparse.linalg import cg
  271. >>> P = np.array([[4, 0, 1, 0],
  272. ... [0, 5, 0, 0],
  273. ... [1, 0, 3, 2],
  274. ... [0, 0, 2, 4]])
  275. >>> A = csc_matrix(P)
  276. >>> b = np.array([-1, -0.5, -1, 2])
  277. >>> x, exit_code = cg(A, b)
  278. >>> print(exit_code) # 0 indicates successful convergence
  279. 0
  280. >>> np.allclose(A.dot(x), b)
  281. True
  282. """)
  283. @non_reentrant()
  284. def cg(A, b, x0=None, tol=1e-5, maxiter=None, M=None, callback=None, atol=None):
  285. A, M, x, b, postprocess = make_system(A, M, x0, b)
  286. n = len(b)
  287. if maxiter is None:
  288. maxiter = n*10
  289. matvec = A.matvec
  290. psolve = M.matvec
  291. ltr = _type_conv[x.dtype.char]
  292. revcom = getattr(_iterative, ltr + 'cgrevcom')
  293. get_residual = lambda: np.linalg.norm(matvec(x) - b)
  294. atol = _get_atol(tol, atol, np.linalg.norm(b), get_residual, 'cg')
  295. if atol == 'exit':
  296. return postprocess(x), 0
  297. resid = atol
  298. ndx1 = 1
  299. ndx2 = -1
  300. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  301. work = _aligned_zeros(4*n,dtype=x.dtype)
  302. ijob = 1
  303. info = 0
  304. ftflag = True
  305. iter_ = maxiter
  306. while True:
  307. olditer = iter_
  308. x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  309. revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
  310. if callback is not None and iter_ > olditer:
  311. callback(x)
  312. slice1 = slice(ndx1-1, ndx1-1+n)
  313. slice2 = slice(ndx2-1, ndx2-1+n)
  314. if (ijob == -1):
  315. if callback is not None:
  316. callback(x)
  317. break
  318. elif (ijob == 1):
  319. work[slice2] *= sclr2
  320. work[slice2] += sclr1*matvec(work[slice1])
  321. elif (ijob == 2):
  322. work[slice1] = psolve(work[slice2])
  323. elif (ijob == 3):
  324. work[slice2] *= sclr2
  325. work[slice2] += sclr1*matvec(x)
  326. elif (ijob == 4):
  327. if ftflag:
  328. info = -1
  329. ftflag = False
  330. resid, info = _stoptest(work[slice1], atol)
  331. if info == 1 and iter_ > 1:
  332. # recompute residual and recheck, to avoid
  333. # accumulating rounding error
  334. work[slice1] = b - matvec(x)
  335. resid, info = _stoptest(work[slice1], atol)
  336. ijob = 2
  337. if info > 0 and iter_ == maxiter and not (resid <= atol):
  338. # info isn't set appropriately otherwise
  339. info = iter_
  340. return postprocess(x), info
  341. @set_docstring('Use Conjugate Gradient Squared iteration to solve ``Ax = b``.',
  342. 'The real-valued N-by-N matrix of the linear system.\n'
  343. 'Alternatively, ``A`` can be a linear operator which can\n'
  344. 'produce ``Ax`` using, e.g.,\n'
  345. '``scipy.sparse.linalg.LinearOperator``.',
  346. footer="""\
  347. Examples
  348. --------
  349. >>> import numpy as np
  350. >>> from scipy.sparse import csc_matrix
  351. >>> from scipy.sparse.linalg import cgs
  352. >>> R = np.array([[4, 2, 0, 1],
  353. ... [3, 0, 0, 2],
  354. ... [0, 1, 1, 1],
  355. ... [0, 2, 1, 0]])
  356. >>> A = csc_matrix(R)
  357. >>> b = np.array([-1, -0.5, -1, 2])
  358. >>> x, exit_code = cgs(A, b)
  359. >>> print(exit_code) # 0 indicates successful convergence
  360. 0
  361. >>> np.allclose(A.dot(x), b)
  362. True
  363. """
  364. )
  365. @non_reentrant()
  366. def cgs(A, b, x0=None, tol=1e-5, maxiter=None, M=None, callback=None, atol=None):
  367. A, M, x, b, postprocess = make_system(A, M, x0, b)
  368. n = len(b)
  369. if maxiter is None:
  370. maxiter = n*10
  371. matvec = A.matvec
  372. psolve = M.matvec
  373. ltr = _type_conv[x.dtype.char]
  374. revcom = getattr(_iterative, ltr + 'cgsrevcom')
  375. get_residual = lambda: np.linalg.norm(matvec(x) - b)
  376. atol = _get_atol(tol, atol, np.linalg.norm(b), get_residual, 'cgs')
  377. if atol == 'exit':
  378. return postprocess(x), 0
  379. resid = atol
  380. ndx1 = 1
  381. ndx2 = -1
  382. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  383. work = _aligned_zeros(7*n,dtype=x.dtype)
  384. ijob = 1
  385. info = 0
  386. ftflag = True
  387. iter_ = maxiter
  388. while True:
  389. olditer = iter_
  390. x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  391. revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
  392. if callback is not None and iter_ > olditer:
  393. callback(x)
  394. slice1 = slice(ndx1-1, ndx1-1+n)
  395. slice2 = slice(ndx2-1, ndx2-1+n)
  396. if (ijob == -1):
  397. if callback is not None:
  398. callback(x)
  399. break
  400. elif (ijob == 1):
  401. work[slice2] *= sclr2
  402. work[slice2] += sclr1*matvec(work[slice1])
  403. elif (ijob == 2):
  404. work[slice1] = psolve(work[slice2])
  405. elif (ijob == 3):
  406. work[slice2] *= sclr2
  407. work[slice2] += sclr1*matvec(x)
  408. elif (ijob == 4):
  409. if ftflag:
  410. info = -1
  411. ftflag = False
  412. resid, info = _stoptest(work[slice1], atol)
  413. if info == 1 and iter_ > 1:
  414. # recompute residual and recheck, to avoid
  415. # accumulating rounding error
  416. work[slice1] = b - matvec(x)
  417. resid, info = _stoptest(work[slice1], atol)
  418. ijob = 2
  419. if info == -10:
  420. # termination due to breakdown: check for convergence
  421. resid, ok = _stoptest(b - matvec(x), atol)
  422. if ok:
  423. info = 0
  424. if info > 0 and iter_ == maxiter and not (resid <= atol):
  425. # info isn't set appropriately otherwise
  426. info = iter_
  427. return postprocess(x), info
  428. @non_reentrant()
  429. def gmres(A, b, x0=None, tol=1e-5, restart=None, maxiter=None, M=None, callback=None,
  430. restrt=None, atol=None, callback_type=None):
  431. """
  432. Use Generalized Minimal RESidual iteration to solve ``Ax = b``.
  433. Parameters
  434. ----------
  435. A : {sparse matrix, ndarray, LinearOperator}
  436. The real or complex N-by-N matrix of the linear system.
  437. Alternatively, ``A`` can be a linear operator which can
  438. produce ``Ax`` using, e.g.,
  439. ``scipy.sparse.linalg.LinearOperator``.
  440. b : ndarray
  441. Right hand side of the linear system. Has shape (N,) or (N,1).
  442. Returns
  443. -------
  444. x : ndarray
  445. The converged solution.
  446. info : int
  447. Provides convergence information:
  448. * 0 : successful exit
  449. * >0 : convergence to tolerance not achieved, number of iterations
  450. * <0 : illegal input or breakdown
  451. Other parameters
  452. ----------------
  453. x0 : ndarray
  454. Starting guess for the solution (a vector of zeros by default).
  455. tol, atol : float, optional
  456. Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
  457. The default for ``atol`` is ``'legacy'``, which emulates
  458. a different legacy behavior.
  459. .. warning::
  460. The default value for `atol` will be changed in a future release.
  461. For future compatibility, specify `atol` explicitly.
  462. restart : int, optional
  463. Number of iterations between restarts. Larger values increase
  464. iteration cost, but may be necessary for convergence.
  465. Default is 20.
  466. maxiter : int, optional
  467. Maximum number of iterations (restart cycles). Iteration will stop
  468. after maxiter steps even if the specified tolerance has not been
  469. achieved.
  470. M : {sparse matrix, ndarray, LinearOperator}
  471. Inverse of the preconditioner of A. M should approximate the
  472. inverse of A and be easy to solve for (see Notes). Effective
  473. preconditioning dramatically improves the rate of convergence,
  474. which implies that fewer iterations are needed to reach a given
  475. error tolerance. By default, no preconditioner is used.
  476. callback : function
  477. User-supplied function to call after each iteration. It is called
  478. as `callback(args)`, where `args` are selected by `callback_type`.
  479. callback_type : {'x', 'pr_norm', 'legacy'}, optional
  480. Callback function argument requested:
  481. - ``x``: current iterate (ndarray), called on every restart
  482. - ``pr_norm``: relative (preconditioned) residual norm (float),
  483. called on every inner iteration
  484. - ``legacy`` (default): same as ``pr_norm``, but also changes the
  485. meaning of 'maxiter' to count inner iterations instead of restart
  486. cycles.
  487. restrt : int, optional, deprecated
  488. .. deprecated:: 0.11.0
  489. `gmres` keyword argument `restrt` is deprecated infavour of
  490. `restart` and will be removed in SciPy 1.12.0.
  491. See Also
  492. --------
  493. LinearOperator
  494. Notes
  495. -----
  496. A preconditioner, P, is chosen such that P is close to A but easy to solve
  497. for. The preconditioner parameter required by this routine is
  498. ``M = P^-1``. The inverse should preferably not be calculated
  499. explicitly. Rather, use the following template to produce M::
  500. # Construct a linear operator that computes P^-1 @ x.
  501. import scipy.sparse.linalg as spla
  502. M_x = lambda x: spla.spsolve(P, x)
  503. M = spla.LinearOperator((n, n), M_x)
  504. Examples
  505. --------
  506. >>> import numpy as np
  507. >>> from scipy.sparse import csc_matrix
  508. >>> from scipy.sparse.linalg import gmres
  509. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  510. >>> b = np.array([2, 4, -1], dtype=float)
  511. >>> x, exitCode = gmres(A, b)
  512. >>> print(exitCode) # 0 indicates successful convergence
  513. 0
  514. >>> np.allclose(A.dot(x), b)
  515. True
  516. """
  517. # Change 'restrt' keyword to 'restart'
  518. if restrt is None:
  519. restrt = restart
  520. elif restart is not None:
  521. raise ValueError("Cannot specify both restart and restrt keywords. "
  522. "Preferably use 'restart' only.")
  523. else:
  524. msg = ("'gmres' keyword argument 'restrt' is deprecated infavour of "
  525. "'restart' and will be removed in SciPy 1.12.0.")
  526. warnings.warn(msg, DeprecationWarning, stacklevel=2)
  527. if callback is not None and callback_type is None:
  528. # Warn about 'callback_type' semantic changes.
  529. # Probably should be removed only in far future, Scipy 2.0 or so.
  530. warnings.warn("scipy.sparse.linalg.gmres called without specifying `callback_type`. "
  531. "The default value will be changed in a future release. "
  532. "For compatibility, specify a value for `callback_type` explicitly, e.g., "
  533. "``{name}(..., callback_type='pr_norm')``, or to retain the old behavior "
  534. "``{name}(..., callback_type='legacy')``",
  535. category=DeprecationWarning, stacklevel=3)
  536. if callback_type is None:
  537. callback_type = 'legacy'
  538. if callback_type not in ('x', 'pr_norm', 'legacy'):
  539. raise ValueError("Unknown callback_type: {!r}".format(callback_type))
  540. if callback is None:
  541. callback_type = 'none'
  542. A, M, x, b,postprocess = make_system(A, M, x0, b)
  543. n = len(b)
  544. if maxiter is None:
  545. maxiter = n*10
  546. if restrt is None:
  547. restrt = 20
  548. restrt = min(restrt, n)
  549. matvec = A.matvec
  550. psolve = M.matvec
  551. ltr = _type_conv[x.dtype.char]
  552. revcom = getattr(_iterative, ltr + 'gmresrevcom')
  553. bnrm2 = np.linalg.norm(b)
  554. Mb_nrm2 = np.linalg.norm(psolve(b))
  555. get_residual = lambda: np.linalg.norm(matvec(x) - b)
  556. atol = _get_atol(tol, atol, bnrm2, get_residual, 'gmres')
  557. if atol == 'exit':
  558. return postprocess(x), 0
  559. if bnrm2 == 0:
  560. return postprocess(b), 0
  561. # Tolerance passed to GMRESREVCOM applies to the inner iteration
  562. # and deals with the left-preconditioned residual.
  563. ptol_max_factor = 1.0
  564. ptol = Mb_nrm2 * min(ptol_max_factor, atol / bnrm2)
  565. resid = np.nan
  566. presid = np.nan
  567. ndx1 = 1
  568. ndx2 = -1
  569. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  570. work = _aligned_zeros((6+restrt)*n,dtype=x.dtype)
  571. work2 = _aligned_zeros((restrt+1)*(2*restrt+2),dtype=x.dtype)
  572. ijob = 1
  573. info = 0
  574. ftflag = True
  575. iter_ = maxiter
  576. old_ijob = ijob
  577. first_pass = True
  578. resid_ready = False
  579. iter_num = 1
  580. while True:
  581. olditer = iter_
  582. x, iter_, presid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  583. revcom(b, x, restrt, work, work2, iter_, presid, info, ndx1, ndx2, ijob, ptol)
  584. if callback_type == 'x' and iter_ != olditer:
  585. callback(x)
  586. slice1 = slice(ndx1-1, ndx1-1+n)
  587. slice2 = slice(ndx2-1, ndx2-1+n)
  588. if (ijob == -1): # gmres success, update last residual
  589. if callback_type in ('pr_norm', 'legacy'):
  590. if resid_ready:
  591. callback(presid / bnrm2)
  592. elif callback_type == 'x':
  593. callback(x)
  594. break
  595. elif (ijob == 1):
  596. work[slice2] *= sclr2
  597. work[slice2] += sclr1*matvec(x)
  598. elif (ijob == 2):
  599. work[slice1] = psolve(work[slice2])
  600. if not first_pass and old_ijob == 3:
  601. resid_ready = True
  602. first_pass = False
  603. elif (ijob == 3):
  604. work[slice2] *= sclr2
  605. work[slice2] += sclr1*matvec(work[slice1])
  606. if resid_ready:
  607. if callback_type in ('pr_norm', 'legacy'):
  608. callback(presid / bnrm2)
  609. resid_ready = False
  610. iter_num = iter_num+1
  611. elif (ijob == 4):
  612. if ftflag:
  613. info = -1
  614. ftflag = False
  615. resid, info = _stoptest(work[slice1], atol)
  616. # Inner loop tolerance control
  617. if info or presid > ptol:
  618. ptol_max_factor = min(1.0, 1.5 * ptol_max_factor)
  619. else:
  620. # Inner loop tolerance OK, but outer loop not.
  621. ptol_max_factor = max(1e-16, 0.25 * ptol_max_factor)
  622. if resid != 0:
  623. ptol = presid * min(ptol_max_factor, atol / resid)
  624. else:
  625. ptol = presid * ptol_max_factor
  626. old_ijob = ijob
  627. ijob = 2
  628. if callback_type == 'legacy':
  629. # Legacy behavior
  630. if iter_num > maxiter:
  631. info = maxiter
  632. break
  633. if info >= 0 and not (resid <= atol):
  634. # info isn't set appropriately otherwise
  635. info = maxiter
  636. return postprocess(x), info
  637. @non_reentrant()
  638. def qmr(A, b, x0=None, tol=1e-5, maxiter=None, M1=None, M2=None, callback=None,
  639. atol=None):
  640. """Use Quasi-Minimal Residual iteration to solve ``Ax = b``.
  641. Parameters
  642. ----------
  643. A : {sparse matrix, ndarray, LinearOperator}
  644. The real-valued N-by-N matrix of the linear system.
  645. Alternatively, ``A`` can be a linear operator which can
  646. produce ``Ax`` and ``A^T x`` using, e.g.,
  647. ``scipy.sparse.linalg.LinearOperator``.
  648. b : ndarray
  649. Right hand side of the linear system. Has shape (N,) or (N,1).
  650. Returns
  651. -------
  652. x : ndarray
  653. The converged solution.
  654. info : integer
  655. Provides convergence information:
  656. 0 : successful exit
  657. >0 : convergence to tolerance not achieved, number of iterations
  658. <0 : illegal input or breakdown
  659. Other Parameters
  660. ----------------
  661. x0 : ndarray
  662. Starting guess for the solution.
  663. tol, atol : float, optional
  664. Tolerances for convergence, ``norm(residual) <= max(tol*norm(b), atol)``.
  665. The default for ``atol`` is ``'legacy'``, which emulates
  666. a different legacy behavior.
  667. .. warning::
  668. The default value for `atol` will be changed in a future release.
  669. For future compatibility, specify `atol` explicitly.
  670. maxiter : integer
  671. Maximum number of iterations. Iteration will stop after maxiter
  672. steps even if the specified tolerance has not been achieved.
  673. M1 : {sparse matrix, ndarray, LinearOperator}
  674. Left preconditioner for A.
  675. M2 : {sparse matrix, ndarray, LinearOperator}
  676. Right preconditioner for A. Used together with the left
  677. preconditioner M1. The matrix M1@A@M2 should have better
  678. conditioned than A alone.
  679. callback : function
  680. User-supplied function to call after each iteration. It is called
  681. as callback(xk), where xk is the current solution vector.
  682. See Also
  683. --------
  684. LinearOperator
  685. Examples
  686. --------
  687. >>> import numpy as np
  688. >>> from scipy.sparse import csc_matrix
  689. >>> from scipy.sparse.linalg import qmr
  690. >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
  691. >>> b = np.array([2, 4, -1], dtype=float)
  692. >>> x, exitCode = qmr(A, b)
  693. >>> print(exitCode) # 0 indicates successful convergence
  694. 0
  695. >>> np.allclose(A.dot(x), b)
  696. True
  697. """
  698. A_ = A
  699. A, M, x, b, postprocess = make_system(A, None, x0, b)
  700. if M1 is None and M2 is None:
  701. if hasattr(A_,'psolve'):
  702. def left_psolve(b):
  703. return A_.psolve(b,'left')
  704. def right_psolve(b):
  705. return A_.psolve(b,'right')
  706. def left_rpsolve(b):
  707. return A_.rpsolve(b,'left')
  708. def right_rpsolve(b):
  709. return A_.rpsolve(b,'right')
  710. M1 = LinearOperator(A.shape, matvec=left_psolve, rmatvec=left_rpsolve)
  711. M2 = LinearOperator(A.shape, matvec=right_psolve, rmatvec=right_rpsolve)
  712. else:
  713. def id(b):
  714. return b
  715. M1 = LinearOperator(A.shape, matvec=id, rmatvec=id)
  716. M2 = LinearOperator(A.shape, matvec=id, rmatvec=id)
  717. n = len(b)
  718. if maxiter is None:
  719. maxiter = n*10
  720. ltr = _type_conv[x.dtype.char]
  721. revcom = getattr(_iterative, ltr + 'qmrrevcom')
  722. get_residual = lambda: np.linalg.norm(A.matvec(x) - b)
  723. atol = _get_atol(tol, atol, np.linalg.norm(b), get_residual, 'qmr')
  724. if atol == 'exit':
  725. return postprocess(x), 0
  726. resid = atol
  727. ndx1 = 1
  728. ndx2 = -1
  729. # Use _aligned_zeros to work around a f2py bug in Numpy 1.9.1
  730. work = _aligned_zeros(11*n,x.dtype)
  731. ijob = 1
  732. info = 0
  733. ftflag = True
  734. iter_ = maxiter
  735. while True:
  736. olditer = iter_
  737. x, iter_, resid, info, ndx1, ndx2, sclr1, sclr2, ijob = \
  738. revcom(b, x, work, iter_, resid, info, ndx1, ndx2, ijob)
  739. if callback is not None and iter_ > olditer:
  740. callback(x)
  741. slice1 = slice(ndx1-1, ndx1-1+n)
  742. slice2 = slice(ndx2-1, ndx2-1+n)
  743. if (ijob == -1):
  744. if callback is not None:
  745. callback(x)
  746. break
  747. elif (ijob == 1):
  748. work[slice2] *= sclr2
  749. work[slice2] += sclr1*A.matvec(work[slice1])
  750. elif (ijob == 2):
  751. work[slice2] *= sclr2
  752. work[slice2] += sclr1*A.rmatvec(work[slice1])
  753. elif (ijob == 3):
  754. work[slice1] = M1.matvec(work[slice2])
  755. elif (ijob == 4):
  756. work[slice1] = M2.matvec(work[slice2])
  757. elif (ijob == 5):
  758. work[slice1] = M1.rmatvec(work[slice2])
  759. elif (ijob == 6):
  760. work[slice1] = M2.rmatvec(work[slice2])
  761. elif (ijob == 7):
  762. work[slice2] *= sclr2
  763. work[slice2] += sclr1*A.matvec(x)
  764. elif (ijob == 8):
  765. if ftflag:
  766. info = -1
  767. ftflag = False
  768. resid, info = _stoptest(work[slice1], atol)
  769. ijob = 2
  770. if info > 0 and iter_ == maxiter and not (resid <= atol):
  771. # info isn't set appropriately otherwise
  772. info = iter_
  773. return postprocess(x), info