test_iterative.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. """ Test functions for the sparse.linalg._isolve module
  2. """
  3. import itertools
  4. import platform
  5. import sys
  6. import numpy as np
  7. from numpy.testing import (assert_equal, assert_array_equal,
  8. assert_, assert_allclose, suppress_warnings)
  9. import pytest
  10. from pytest import raises as assert_raises
  11. from numpy import zeros, arange, array, ones, eye, iscomplexobj
  12. from scipy.linalg import norm
  13. from scipy.sparse import spdiags, csr_matrix, SparseEfficiencyWarning, kronsum
  14. from scipy.sparse.linalg import LinearOperator, aslinearoperator
  15. from scipy.sparse.linalg._isolve import cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk, tfqmr
# TODO check that methods preserve shape and type
# TODO test both preconditioner methods
  18. class Case:
  19. def __init__(self, name, A, b=None, skip=None, nonconvergence=None):
  20. self.name = name
  21. self.A = A
  22. if b is None:
  23. self.b = arange(A.shape[0], dtype=float)
  24. else:
  25. self.b = b
  26. if skip is None:
  27. self.skip = []
  28. else:
  29. self.skip = skip
  30. if nonconvergence is None:
  31. self.nonconvergence = []
  32. else:
  33. self.nonconvergence = nonconvergence
  34. def __repr__(self):
  35. return "<%s>" % self.name
  36. class IterativeParams:
  37. def __init__(self):
  38. # list of tuples (solver, symmetric, positive_definite )
  39. solvers = [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk, tfqmr]
  40. sym_solvers = [minres, cg]
  41. posdef_solvers = [cg]
  42. real_solvers = [minres]
  43. self.solvers = solvers
  44. # list of tuples (A, symmetric, positive_definite )
  45. self.cases = []
  46. # Symmetric and Positive Definite
  47. N = 40
  48. data = ones((3,N))
  49. data[0,:] = 2
  50. data[1,:] = -1
  51. data[2,:] = -1
  52. Poisson1D = spdiags(data, [0,-1,1], N, N, format='csr')
  53. self.Poisson1D = Case("poisson1d", Poisson1D)
  54. self.cases.append(Case("poisson1d", Poisson1D))
  55. # note: minres fails for single precision
  56. self.cases.append(Case("poisson1d", Poisson1D.astype('f'),
  57. skip=[minres]))
  58. # Symmetric and Negative Definite
  59. self.cases.append(Case("neg-poisson1d", -Poisson1D,
  60. skip=posdef_solvers))
  61. # note: minres fails for single precision
  62. self.cases.append(Case("neg-poisson1d", (-Poisson1D).astype('f'),
  63. skip=posdef_solvers + [minres]))
  64. # 2-dimensional Poisson equations
  65. Poisson2D = kronsum(Poisson1D, Poisson1D)
  66. self.Poisson2D = Case("poisson2d", Poisson2D)
  67. # note: minres fails for 2-d poisson problem, it will be fixed in the future PR
  68. self.cases.append(Case("poisson2d", Poisson2D, skip=[minres]))
  69. # note: minres fails for single precision
  70. self.cases.append(Case("poisson2d", Poisson2D.astype('f'),
  71. skip=[minres]))
  72. # Symmetric and Indefinite
  73. data = array([[6, -5, 2, 7, -1, 10, 4, -3, -8, 9]],dtype='d')
  74. RandDiag = spdiags(data, [0], 10, 10, format='csr')
  75. self.cases.append(Case("rand-diag", RandDiag, skip=posdef_solvers))
  76. self.cases.append(Case("rand-diag", RandDiag.astype('f'),
  77. skip=posdef_solvers))
  78. # Random real-valued
  79. np.random.seed(1234)
  80. data = np.random.rand(4, 4)
  81. self.cases.append(Case("rand", data, skip=posdef_solvers+sym_solvers))
  82. self.cases.append(Case("rand", data.astype('f'),
  83. skip=posdef_solvers+sym_solvers))
  84. # Random symmetric real-valued
  85. np.random.seed(1234)
  86. data = np.random.rand(4, 4)
  87. data = data + data.T
  88. self.cases.append(Case("rand-sym", data, skip=posdef_solvers))
  89. self.cases.append(Case("rand-sym", data.astype('f'),
  90. skip=posdef_solvers))
  91. # Random pos-def symmetric real
  92. np.random.seed(1234)
  93. data = np.random.rand(9, 9)
  94. data = np.dot(data.conj(), data.T)
  95. self.cases.append(Case("rand-sym-pd", data))
  96. # note: minres fails for single precision
  97. self.cases.append(Case("rand-sym-pd", data.astype('f'),
  98. skip=[minres]))
  99. # Random complex-valued
  100. np.random.seed(1234)
  101. data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
  102. self.cases.append(Case("rand-cmplx", data,
  103. skip=posdef_solvers+sym_solvers+real_solvers))
  104. self.cases.append(Case("rand-cmplx", data.astype('F'),
  105. skip=posdef_solvers+sym_solvers+real_solvers))
  106. # Random hermitian complex-valued
  107. np.random.seed(1234)
  108. data = np.random.rand(4, 4) + 1j*np.random.rand(4, 4)
  109. data = data + data.T.conj()
  110. self.cases.append(Case("rand-cmplx-herm", data,
  111. skip=posdef_solvers+real_solvers))
  112. self.cases.append(Case("rand-cmplx-herm", data.astype('F'),
  113. skip=posdef_solvers+real_solvers))
  114. # Random pos-def hermitian complex-valued
  115. np.random.seed(1234)
  116. data = np.random.rand(9, 9) + 1j*np.random.rand(9, 9)
  117. data = np.dot(data.conj(), data.T)
  118. self.cases.append(Case("rand-cmplx-sym-pd", data, skip=real_solvers))
  119. self.cases.append(Case("rand-cmplx-sym-pd", data.astype('F'),
  120. skip=real_solvers))
  121. # Non-symmetric and Positive Definite
  122. #
  123. # cgs, qmr, bicg and tfqmr fail to converge on this one
  124. # -- algorithmic limitation apparently
  125. data = ones((2,10))
  126. data[0,:] = 2
  127. data[1,:] = -1
  128. A = spdiags(data, [0,-1], 10, 10, format='csr')
  129. self.cases.append(Case("nonsymposdef", A,
  130. skip=sym_solvers+[cgs, qmr, bicg, tfqmr]))
  131. self.cases.append(Case("nonsymposdef", A.astype('F'),
  132. skip=sym_solvers+[cgs, qmr, bicg, tfqmr]))
  133. # Symmetric, non-pd, hitting cgs/bicg/bicgstab/qmr/tfqmr breakdown
  134. A = np.array([[0, 0, 0, 0, 0, 1, -1, -0, -0, -0, -0],
  135. [0, 0, 0, 0, 0, 2, -0, -1, -0, -0, -0],
  136. [0, 0, 0, 0, 0, 2, -0, -0, -1, -0, -0],
  137. [0, 0, 0, 0, 0, 2, -0, -0, -0, -1, -0],
  138. [0, 0, 0, 0, 0, 1, -0, -0, -0, -0, -1],
  139. [1, 2, 2, 2, 1, 0, -0, -0, -0, -0, -0],
  140. [-1, 0, 0, 0, 0, 0, -1, -0, -0, -0, -0],
  141. [0, -1, 0, 0, 0, 0, -0, -1, -0, -0, -0],
  142. [0, 0, -1, 0, 0, 0, -0, -0, -1, -0, -0],
  143. [0, 0, 0, -1, 0, 0, -0, -0, -0, -1, -0],
  144. [0, 0, 0, 0, -1, 0, -0, -0, -0, -0, -1]], dtype=float)
  145. b = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], dtype=float)
  146. assert (A == A.T).all()
  147. self.cases.append(Case("sym-nonpd", A, b,
  148. skip=posdef_solvers,
  149. nonconvergence=[cgs,bicg,bicgstab,qmr,tfqmr]))
  150. params = IterativeParams()
  151. def check_maxiter(solver, case):
  152. A = case.A
  153. tol = 1e-12
  154. b = case.b
  155. x0 = 0*b
  156. residuals = []
  157. def callback(x):
  158. residuals.append(norm(b - case.A*x))
  159. x, info = solver(A, b, x0=x0, tol=tol, maxiter=1, callback=callback)
  160. assert_equal(len(residuals), 1)
  161. assert_equal(info, 1)
  162. def test_maxiter():
  163. for case in params.cases:
  164. for solver in params.solvers:
  165. if solver in case.skip + case.nonconvergence:
  166. continue
  167. with suppress_warnings() as sup:
  168. sup.filter(DeprecationWarning, ".*called without specifying.*")
  169. check_maxiter(solver, case)
  170. def assert_normclose(a, b, tol=1e-8):
  171. residual = norm(a - b)
  172. tolerance = tol * norm(b)
  173. msg = f"residual ({residual}) not smaller than tolerance ({tolerance})"
  174. assert_(residual < tolerance, msg=msg)
  175. def check_convergence(solver, case):
  176. A = case.A
  177. if A.dtype.char in "dD":
  178. tol = 1e-8
  179. else:
  180. tol = 1e-2
  181. b = case.b
  182. x0 = 0*b
  183. x, info = solver(A, b, x0=x0, tol=tol)
  184. assert_array_equal(x0, 0*b) # ensure that x0 is not overwritten
  185. if solver not in case.nonconvergence:
  186. assert_equal(info,0)
  187. assert_normclose(A.dot(x), b, tol=tol)
  188. else:
  189. assert_(info != 0)
  190. assert_(np.linalg.norm(A.dot(x) - b) <= np.linalg.norm(b))
  191. def test_convergence():
  192. for solver in params.solvers:
  193. for case in params.cases:
  194. if solver in case.skip:
  195. continue
  196. with suppress_warnings() as sup:
  197. sup.filter(DeprecationWarning, ".*called without specifying.*")
  198. check_convergence(solver, case)
  199. def check_precond_dummy(solver, case):
  200. tol = 1e-8
  201. def identity(b,which=None):
  202. """trivial preconditioner"""
  203. return b
  204. A = case.A
  205. M,N = A.shape
  206. # Ensure the diagonal elements of A are non-zero before calculating
  207. # 1.0/A.diagonal()
  208. diagOfA = A.diagonal()
  209. if np.count_nonzero(diagOfA) == len(diagOfA):
  210. spdiags([1.0/diagOfA], [0], M, N)
  211. b = case.b
  212. x0 = 0*b
  213. precond = LinearOperator(A.shape, identity, rmatvec=identity)
  214. if solver is qmr:
  215. x, info = solver(A, b, M1=precond, M2=precond, x0=x0, tol=tol)
  216. else:
  217. x, info = solver(A, b, M=precond, x0=x0, tol=tol)
  218. assert_equal(info,0)
  219. assert_normclose(A.dot(x), b, tol)
  220. A = aslinearoperator(A)
  221. A.psolve = identity
  222. A.rpsolve = identity
  223. x, info = solver(A, b, x0=x0, tol=tol)
  224. assert_equal(info,0)
  225. assert_normclose(A@x, b, tol=tol)
  226. def test_precond_dummy():
  227. for case in params.cases:
  228. for solver in params.solvers:
  229. if solver in case.skip + case.nonconvergence:
  230. continue
  231. with suppress_warnings() as sup:
  232. sup.filter(DeprecationWarning, ".*called without specifying.*")
  233. check_precond_dummy(solver, case)
  234. def check_precond_inverse(solver, case):
  235. tol = 1e-8
  236. def inverse(b,which=None):
  237. """inverse preconditioner"""
  238. A = case.A
  239. if not isinstance(A, np.ndarray):
  240. A = A.toarray()
  241. return np.linalg.solve(A, b)
  242. def rinverse(b,which=None):
  243. """inverse preconditioner"""
  244. A = case.A
  245. if not isinstance(A, np.ndarray):
  246. A = A.toarray()
  247. return np.linalg.solve(A.T, b)
  248. matvec_count = [0]
  249. def matvec(b):
  250. matvec_count[0] += 1
  251. return case.A.dot(b)
  252. def rmatvec(b):
  253. matvec_count[0] += 1
  254. return case.A.T.dot(b)
  255. b = case.b
  256. x0 = 0*b
  257. A = LinearOperator(case.A.shape, matvec, rmatvec=rmatvec)
  258. precond = LinearOperator(case.A.shape, inverse, rmatvec=rinverse)
  259. # Solve with preconditioner
  260. matvec_count = [0]
  261. x, info = solver(A, b, M=precond, x0=x0, tol=tol)
  262. assert_equal(info, 0)
  263. assert_normclose(case.A.dot(x), b, tol)
  264. # Solution should be nearly instant
  265. assert_(matvec_count[0] <= 3, repr(matvec_count))
  266. @pytest.mark.parametrize("case", [params.Poisson1D, params.Poisson2D])
  267. def test_precond_inverse(case):
  268. for solver in params.solvers:
  269. if solver in case.skip:
  270. continue
  271. if solver is qmr:
  272. continue
  273. with suppress_warnings() as sup:
  274. sup.filter(DeprecationWarning, ".*called without specifying.*")
  275. check_precond_inverse(solver, case)
  276. def test_reentrancy():
  277. non_reentrant = [cg, cgs, bicg, bicgstab, gmres, qmr]
  278. reentrant = [lgmres, minres, gcrotmk, tfqmr]
  279. for solver in reentrant + non_reentrant:
  280. with suppress_warnings() as sup:
  281. sup.filter(DeprecationWarning, ".*called without specifying.*")
  282. _check_reentrancy(solver, solver in reentrant)
  283. def _check_reentrancy(solver, is_reentrant):
  284. def matvec(x):
  285. A = np.array([[1.0, 0, 0], [0, 2.0, 0], [0, 0, 3.0]])
  286. y, info = solver(A, x)
  287. assert_equal(info, 0)
  288. return y
  289. b = np.array([1, 1./2, 1./3])
  290. op = LinearOperator((3, 3), matvec=matvec, rmatvec=matvec,
  291. dtype=b.dtype)
  292. if not is_reentrant:
  293. assert_raises(RuntimeError, solver, op, b)
  294. else:
  295. y, info = solver(op, b)
  296. assert_equal(info, 0)
  297. assert_allclose(y, [1, 1, 1])
  298. @pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr, lgmres, gcrotmk])
  299. def test_atol(solver):
  300. # TODO: minres. It didn't historically use absolute tolerances, so
  301. # fixing it is less urgent.
  302. np.random.seed(1234)
  303. A = np.random.rand(10, 10)
  304. A = A.dot(A.T) + 10 * np.eye(10)
  305. b = 1e3 * np.random.rand(10)
  306. b_norm = np.linalg.norm(b)
  307. tols = np.r_[0, np.logspace(np.log10(1e-10), np.log10(1e2), 7), np.inf]
  308. # Check effect of badly scaled preconditioners
  309. M0 = np.random.randn(10, 10)
  310. M0 = M0.dot(M0.T)
  311. Ms = [None, 1e-6 * M0, 1e6 * M0]
  312. for M, tol, atol in itertools.product(Ms, tols, tols):
  313. if tol == 0 and atol == 0:
  314. continue
  315. if solver is qmr:
  316. if M is not None:
  317. M = aslinearoperator(M)
  318. M2 = aslinearoperator(np.eye(10))
  319. else:
  320. M2 = None
  321. x, info = solver(A, b, M1=M, M2=M2, tol=tol, atol=atol)
  322. else:
  323. x, info = solver(A, b, M=M, tol=tol, atol=atol)
  324. assert_equal(info, 0)
  325. residual = A.dot(x) - b
  326. err = np.linalg.norm(residual)
  327. atol2 = tol * b_norm
  328. # Added 1.00025 fudge factor because of `err` exceeding `atol` just
  329. # very slightly on s390x (see gh-17839)
  330. assert_(err <= 1.00025 * max(atol, atol2))
  331. @pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk, tfqmr])
  332. def test_zero_rhs(solver):
  333. np.random.seed(1234)
  334. A = np.random.rand(10, 10)
  335. A = A.dot(A.T) + 10 * np.eye(10)
  336. b = np.zeros(10)
  337. tols = np.r_[np.logspace(np.log10(1e-10), np.log10(1e2), 7)]
  338. for tol in tols:
  339. with suppress_warnings() as sup:
  340. sup.filter(DeprecationWarning, ".*called without specifying.*")
  341. x, info = solver(A, b, tol=tol)
  342. assert_equal(info, 0)
  343. assert_allclose(x, 0, atol=1e-15)
  344. x, info = solver(A, b, tol=tol, x0=ones(10))
  345. assert_equal(info, 0)
  346. assert_allclose(x, 0, atol=tol)
  347. if solver is not minres:
  348. x, info = solver(A, b, tol=tol, atol=0, x0=ones(10))
  349. if info == 0:
  350. assert_allclose(x, 0)
  351. x, info = solver(A, b, tol=tol, atol=tol)
  352. assert_equal(info, 0)
  353. assert_allclose(x, 0, atol=1e-300)
  354. x, info = solver(A, b, tol=tol, atol=0)
  355. assert_equal(info, 0)
  356. assert_allclose(x, 0, atol=1e-300)
  357. @pytest.mark.parametrize("solver", [
  358. pytest.param(gmres, marks=pytest.mark.xfail(platform.machine() == 'aarch64'
  359. and sys.version_info[1] == 9,
  360. reason="gh-13019")),
  361. qmr,
  362. pytest.param(lgmres, marks=pytest.mark.xfail(
  363. platform.machine() not in ['x86_64' 'x86', 'aarch64', 'arm64'],
  364. reason="fails on at least ppc64le, ppc64 and riscv64, see gh-17839")
  365. ),
  366. pytest.param(cgs, marks=pytest.mark.xfail),
  367. pytest.param(bicg, marks=pytest.mark.xfail),
  368. pytest.param(bicgstab, marks=pytest.mark.xfail),
  369. pytest.param(gcrotmk, marks=pytest.mark.xfail),
  370. pytest.param(tfqmr, marks=pytest.mark.xfail)])
  371. def test_maxiter_worsening(solver):
  372. # Check error does not grow (boundlessly) with increasing maxiter.
  373. # This can occur due to the solvers hitting close to breakdown,
  374. # which they should detect and halt as necessary.
  375. # cf. gh-9100
  376. # Singular matrix, rhs numerically not in range
  377. A = np.array([[-0.1112795288033378, 0, 0, 0.16127952880333685],
  378. [0, -0.13627952880333782+6.283185307179586j, 0, 0],
  379. [0, 0, -0.13627952880333782-6.283185307179586j, 0],
  380. [0.1112795288033368, 0j, 0j, -0.16127952880333785]])
  381. v = np.ones(4)
  382. best_error = np.inf
  383. tol = 7 if platform.machine() == 'aarch64' else 5
  384. for maxiter in range(1, 20):
  385. x, info = solver(A, v, maxiter=maxiter, tol=1e-8, atol=0)
  386. if info == 0:
  387. assert_(np.linalg.norm(A.dot(x) - v) <= 1e-8*np.linalg.norm(v))
  388. error = np.linalg.norm(A.dot(x) - v)
  389. best_error = min(best_error, error)
  390. # Check with slack
  391. assert_(error <= tol*best_error)
  392. @pytest.mark.parametrize("solver", [cg, cgs, bicg, bicgstab, gmres, qmr, minres, lgmres, gcrotmk, tfqmr])
  393. def test_x0_working(solver):
  394. # Easy problem
  395. np.random.seed(1)
  396. n = 10
  397. A = np.random.rand(n, n)
  398. A = A.dot(A.T)
  399. b = np.random.rand(n)
  400. x0 = np.random.rand(n)
  401. if solver is minres:
  402. kw = dict(tol=1e-6)
  403. else:
  404. kw = dict(atol=0, tol=1e-6)
  405. x, info = solver(A, b, **kw)
  406. assert_equal(info, 0)
  407. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-6*np.linalg.norm(b))
  408. x, info = solver(A, b, x0=x0, **kw)
  409. assert_equal(info, 0)
  410. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-6*np.linalg.norm(b))
  411. @pytest.mark.parametrize('solver', [cg, cgs, bicg, bicgstab, gmres, qmr,
  412. minres, lgmres, gcrotmk])
  413. def test_x0_equals_Mb(solver):
  414. for case in params.cases:
  415. if solver in case.skip:
  416. continue
  417. with suppress_warnings() as sup:
  418. sup.filter(DeprecationWarning, ".*called without specifying.*")
  419. A = case.A
  420. b = case.b
  421. x0 = 'Mb'
  422. tol = 1e-8
  423. x, info = solver(A, b, x0=x0, tol=tol)
  424. assert_array_equal(x0, 'Mb') # ensure that x0 is not overwritten
  425. assert_equal(info, 0)
  426. assert_normclose(A.dot(x), b, tol=tol)
  427. @pytest.mark.parametrize(('solver', 'solverstring'), [(tfqmr, 'TFQMR')])
  428. def test_show(solver, solverstring, capsys):
  429. def cb(x):
  430. count[0] += 1
  431. for i in [0, 20]:
  432. case = params.cases[i]
  433. A = case.A
  434. b = case.b
  435. count = [0]
  436. x, info = solver(A, b, callback=cb, show=True)
  437. out, err = capsys.readouterr()
  438. if i == 20: # Asymmetric and Positive Definite
  439. assert_equal(out, f"{solverstring}: Linear solve not converged "
  440. f"due to reach MAXIT iterations {count[0]}\n")
  441. else: # 1-D Poisson equations
  442. assert_equal(out, f"{solverstring}: Linear solve converged due to "
  443. f"reach TOL iterations {count[0]}\n")
  444. assert_equal(err, '')
  445. #------------------------------------------------------------------------------
  446. class TestQMR:
  447. def test_leftright_precond(self):
  448. """Check that QMR works with left and right preconditioners"""
  449. from scipy.sparse.linalg._dsolve import splu
  450. from scipy.sparse.linalg._interface import LinearOperator
  451. n = 100
  452. dat = ones(n)
  453. A = spdiags([-2*dat, 4*dat, -dat], [-1,0,1],n,n)
  454. b = arange(n,dtype='d')
  455. L = spdiags([-dat/2, dat], [-1,0], n, n)
  456. U = spdiags([4*dat, -dat], [0,1], n, n)
  457. with suppress_warnings() as sup:
  458. sup.filter(SparseEfficiencyWarning,
  459. "splu converted its input to CSC format")
  460. L_solver = splu(L)
  461. U_solver = splu(U)
  462. def L_solve(b):
  463. return L_solver.solve(b)
  464. def U_solve(b):
  465. return U_solver.solve(b)
  466. def LT_solve(b):
  467. return L_solver.solve(b,'T')
  468. def UT_solve(b):
  469. return U_solver.solve(b,'T')
  470. M1 = LinearOperator((n,n), matvec=L_solve, rmatvec=LT_solve)
  471. M2 = LinearOperator((n,n), matvec=U_solve, rmatvec=UT_solve)
  472. with suppress_warnings() as sup:
  473. sup.filter(DeprecationWarning, ".*called without specifying.*")
  474. x,info = qmr(A, b, tol=1e-8, maxiter=15, M1=M1, M2=M2)
  475. assert_equal(info,0)
  476. assert_normclose(A@x, b, tol=1e-8)
  477. class TestGMRES:
  478. def test_basic(self):
  479. A = np.vander(np.arange(10) + 1)[:, ::-1]
  480. b = np.zeros(10)
  481. b[0] = 1
  482. with suppress_warnings() as sup:
  483. sup.filter(DeprecationWarning, ".*called without specifying.*")
  484. x_gm, err = gmres(A, b, restart=5, maxiter=1)
  485. assert_allclose(x_gm[0], 0.359, rtol=1e-2)
  486. def test_callback(self):
  487. def store_residual(r, rvec):
  488. rvec[rvec.nonzero()[0].max()+1] = r
  489. # Define, A,b
  490. A = csr_matrix(array([[-2,1,0,0,0,0],[1,-2,1,0,0,0],[0,1,-2,1,0,0],[0,0,1,-2,1,0],[0,0,0,1,-2,1],[0,0,0,0,1,-2]]))
  491. b = ones((A.shape[0],))
  492. maxiter = 1
  493. rvec = zeros(maxiter+1)
  494. rvec[0] = 1.0
  495. callback = lambda r:store_residual(r, rvec)
  496. with suppress_warnings() as sup:
  497. sup.filter(DeprecationWarning, ".*called without specifying.*")
  498. x,flag = gmres(A, b, x0=zeros(A.shape[0]), tol=1e-16, maxiter=maxiter, callback=callback)
  499. # Expected output from SciPy 1.0.0
  500. assert_allclose(rvec, array([1.0, 0.81649658092772603]), rtol=1e-10)
  501. # Test preconditioned callback
  502. M = 1e-3 * np.eye(A.shape[0])
  503. rvec = zeros(maxiter+1)
  504. rvec[0] = 1.0
  505. with suppress_warnings() as sup:
  506. sup.filter(DeprecationWarning, ".*called without specifying.*")
  507. x, flag = gmres(A, b, M=M, tol=1e-16, maxiter=maxiter, callback=callback)
  508. # Expected output from SciPy 1.0.0 (callback has preconditioned residual!)
  509. assert_allclose(rvec, array([1.0, 1e-3 * 0.81649658092772603]), rtol=1e-10)
  510. def test_abi(self):
  511. # Check we don't segfault on gmres with complex argument
  512. A = eye(2)
  513. b = ones(2)
  514. with suppress_warnings() as sup:
  515. sup.filter(DeprecationWarning, ".*called without specifying.*")
  516. r_x, r_info = gmres(A, b)
  517. r_x = r_x.astype(complex)
  518. x, info = gmres(A.astype(complex), b.astype(complex))
  519. assert_(iscomplexobj(x))
  520. assert_allclose(r_x, x)
  521. assert_(r_info == info)
  522. def test_atol_legacy(self):
  523. with suppress_warnings() as sup:
  524. sup.filter(DeprecationWarning, ".*called without specifying.*")
  525. # Check the strange legacy behavior: the tolerance is interpreted
  526. # as atol, but only for the initial residual
  527. A = eye(2)
  528. b = 1e-6 * ones(2)
  529. x, info = gmres(A, b, tol=1e-5)
  530. assert_array_equal(x, np.zeros(2))
  531. A = eye(2)
  532. b = ones(2)
  533. x, info = gmres(A, b, tol=1e-5)
  534. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-5*np.linalg.norm(b))
  535. assert_allclose(x, b, atol=0, rtol=1e-8)
  536. rndm = np.random.RandomState(12345)
  537. A = rndm.rand(30, 30)
  538. b = 1e-6 * ones(30)
  539. x, info = gmres(A, b, tol=1e-7, restart=20)
  540. assert_(np.linalg.norm(A.dot(x) - b) > 1e-7)
  541. A = eye(2)
  542. b = 1e-10 * ones(2)
  543. x, info = gmres(A, b, tol=1e-8, atol=0)
  544. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-8*np.linalg.norm(b))
  545. def test_defective_precond_breakdown(self):
  546. # Breakdown due to defective preconditioner
  547. M = np.eye(3)
  548. M[2,2] = 0
  549. b = np.array([0, 1, 1])
  550. x = np.array([1, 0, 0])
  551. A = np.diag([2, 3, 4])
  552. x, info = gmres(A, b, x0=x, M=M, tol=1e-15, atol=0)
  553. # Should not return nans, nor terminate with false success
  554. assert_(not np.isnan(x).any())
  555. if info == 0:
  556. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-15*np.linalg.norm(b))
  557. # The solution should be OK outside null space of M
  558. assert_allclose(M.dot(A.dot(x)), M.dot(b))
  559. def test_defective_matrix_breakdown(self):
  560. # Breakdown due to defective matrix
  561. A = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]])
  562. b = np.array([1, 0, 1])
  563. x, info = gmres(A, b, tol=1e-8, atol=0)
  564. # Should not return nans, nor terminate with false success
  565. assert_(not np.isnan(x).any())
  566. if info == 0:
  567. assert_(np.linalg.norm(A.dot(x) - b) <= 1e-8*np.linalg.norm(b))
  568. # The solution should be OK outside null space of A
  569. assert_allclose(A.dot(A.dot(x)), A.dot(b))
  570. def test_callback_type(self):
  571. # The legacy callback type changes meaning of 'maxiter'
  572. np.random.seed(1)
  573. A = np.random.rand(20, 20)
  574. b = np.random.rand(20)
  575. cb_count = [0]
  576. def pr_norm_cb(r):
  577. cb_count[0] += 1
  578. assert_(isinstance(r, float))
  579. def x_cb(x):
  580. cb_count[0] += 1
  581. assert_(isinstance(x, np.ndarray))
  582. with suppress_warnings() as sup:
  583. sup.filter(DeprecationWarning, ".*called without specifying.*")
  584. # 2 iterations is not enough to solve the problem
  585. cb_count = [0]
  586. x, info = gmres(A, b, tol=1e-6, atol=0, callback=pr_norm_cb, maxiter=2, restart=50)
  587. assert info == 2
  588. assert cb_count[0] == 2
  589. # With `callback_type` specified, no warning should be raised
  590. cb_count = [0]
  591. x, info = gmres(A, b, tol=1e-6, atol=0, callback=pr_norm_cb, maxiter=2, restart=50,
  592. callback_type='legacy')
  593. assert info == 2
  594. assert cb_count[0] == 2
  595. # 2 restart cycles is enough to solve the problem
  596. cb_count = [0]
  597. x, info = gmres(A, b, tol=1e-6, atol=0, callback=pr_norm_cb, maxiter=2, restart=50,
  598. callback_type='pr_norm')
  599. assert info == 0
  600. assert cb_count[0] > 2
  601. # 2 restart cycles is enough to solve the problem
  602. cb_count = [0]
  603. x, info = gmres(A, b, tol=1e-6, atol=0, callback=x_cb, maxiter=2, restart=50,
  604. callback_type='x')
  605. assert info == 0
  606. assert cb_count[0] == 2
  607. def test_callback_x_monotonic(self):
  608. # Check that callback_type='x' gives monotonic norm decrease
  609. np.random.seed(1)
  610. A = np.random.rand(20, 20) + np.eye(20)
  611. b = np.random.rand(20)
  612. prev_r = [np.inf]
  613. count = [0]
  614. def x_cb(x):
  615. r = np.linalg.norm(A.dot(x) - b)
  616. assert r <= prev_r[0]
  617. prev_r[0] = r
  618. count[0] += 1
  619. x, info = gmres(A, b, tol=1e-6, atol=0, callback=x_cb, maxiter=20, restart=10,
  620. callback_type='x')
  621. assert info == 20
  622. assert count[0] == 21
  623. x_cb(x)
  624. def test_restrt_dep(self):
  625. with pytest.warns(
  626. DeprecationWarning,
  627. match="'gmres' keyword argument 'restrt'"
  628. ):
  629. gmres(np.array([1]), np.array([1]), restrt=10)