test_gcrotmk.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #!/usr/bin/env python
  2. """Tests for the linalg._isolve.gcrotmk module
  3. """
  4. from numpy.testing import (assert_, assert_allclose, assert_equal,
  5. suppress_warnings)
  6. import numpy as np
  7. from numpy import zeros, array, allclose
  8. from scipy.linalg import norm
  9. from scipy.sparse import csr_matrix, eye, rand
  10. from scipy.sparse.linalg._interface import LinearOperator
  11. from scipy.sparse.linalg import splu
  12. from scipy.sparse.linalg._isolve import gcrotmk, gmres
  13. Am = csr_matrix(array([[-2,1,0,0,0,9],
  14. [1,-2,1,0,5,0],
  15. [0,1,-2,1,0,0],
  16. [0,0,1,-2,1,0],
  17. [0,3,0,1,-2,1],
  18. [1,0,0,0,1,-2]]))
  19. b = array([1,2,3,4,5,6])
  20. count = [0]
  21. def matvec(v):
  22. count[0] += 1
  23. return Am@v
  24. A = LinearOperator(matvec=matvec, shape=Am.shape, dtype=Am.dtype)
  25. def do_solve(**kw):
  26. count[0] = 0
  27. with suppress_warnings() as sup:
  28. sup.filter(DeprecationWarning, ".*called without specifying.*")
  29. x0, flag = gcrotmk(A, b, x0=zeros(A.shape[0]), tol=1e-14, **kw)
  30. count_0 = count[0]
  31. assert_(allclose(A@x0, b, rtol=1e-12, atol=1e-12), norm(A@x0-b))
  32. return x0, count_0
  33. class TestGCROTMK:
  34. def test_preconditioner(self):
  35. # Check that preconditioning works
  36. pc = splu(Am.tocsc())
  37. M = LinearOperator(matvec=pc.solve, shape=A.shape, dtype=A.dtype)
  38. x0, count_0 = do_solve()
  39. x1, count_1 = do_solve(M=M)
  40. assert_equal(count_1, 3)
  41. assert_(count_1 < count_0/2)
  42. assert_(allclose(x1, x0, rtol=1e-14))
  43. def test_arnoldi(self):
  44. np.random.seed(1)
  45. A = eye(2000) + rand(2000, 2000, density=5e-4)
  46. b = np.random.rand(2000)
  47. # The inner arnoldi should be equivalent to gmres
  48. with suppress_warnings() as sup:
  49. sup.filter(DeprecationWarning, ".*called without specifying.*")
  50. x0, flag0 = gcrotmk(A, b, x0=zeros(A.shape[0]), m=15, k=0, maxiter=1)
  51. x1, flag1 = gmres(A, b, x0=zeros(A.shape[0]), restart=15, maxiter=1)
  52. assert_equal(flag0, 1)
  53. assert_equal(flag1, 1)
  54. assert np.linalg.norm(A.dot(x0) - b) > 1e-3
  55. assert_allclose(x0, x1)
  56. def test_cornercase(self):
  57. np.random.seed(1234)
  58. # Rounding error may prevent convergence with tol=0 --- ensure
  59. # that the return values in this case are correct, and no
  60. # exceptions are raised
  61. for n in [3, 5, 10, 100]:
  62. A = 2*eye(n)
  63. with suppress_warnings() as sup:
  64. sup.filter(DeprecationWarning, ".*called without specifying.*")
  65. b = np.ones(n)
  66. x, info = gcrotmk(A, b, maxiter=10)
  67. assert_equal(info, 0)
  68. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  69. x, info = gcrotmk(A, b, tol=0, maxiter=10)
  70. if info == 0:
  71. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  72. b = np.random.rand(n)
  73. x, info = gcrotmk(A, b, maxiter=10)
  74. assert_equal(info, 0)
  75. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  76. x, info = gcrotmk(A, b, tol=0, maxiter=10)
  77. if info == 0:
  78. assert_allclose(A.dot(x) - b, 0, atol=1e-14)
  79. def test_nans(self):
  80. A = eye(3, format='lil')
  81. A[1,1] = np.nan
  82. b = np.ones(3)
  83. with suppress_warnings() as sup:
  84. sup.filter(DeprecationWarning, ".*called without specifying.*")
  85. x, info = gcrotmk(A, b, tol=0, maxiter=10)
  86. assert_equal(info, 1)
  87. def test_truncate(self):
  88. np.random.seed(1234)
  89. A = np.random.rand(30, 30) + np.eye(30)
  90. b = np.random.rand(30)
  91. for truncate in ['oldest', 'smallest']:
  92. with suppress_warnings() as sup:
  93. sup.filter(DeprecationWarning, ".*called without specifying.*")
  94. x, info = gcrotmk(A, b, m=10, k=10, truncate=truncate, tol=1e-4,
  95. maxiter=200)
  96. assert_equal(info, 0)
  97. assert_allclose(A.dot(x) - b, 0, atol=1e-3)
  98. def test_CU(self):
  99. for discard_C in (True, False):
  100. # Check that C,U behave as expected
  101. CU = []
  102. x0, count_0 = do_solve(CU=CU, discard_C=discard_C)
  103. assert_(len(CU) > 0)
  104. assert_(len(CU) <= 6)
  105. if discard_C:
  106. for c, u in CU:
  107. assert_(c is None)
  108. # should converge immediately
  109. x1, count_1 = do_solve(CU=CU, discard_C=discard_C)
  110. if discard_C:
  111. assert_equal(count_1, 2 + len(CU))
  112. else:
  113. assert_equal(count_1, 3)
  114. assert_(count_1 <= count_0/2)
  115. assert_allclose(x1, x0, atol=1e-14)
  116. def test_denormals(self):
  117. # Check that no warnings are emitted if the matrix contains
  118. # numbers for which 1/x has no float representation, and that
  119. # the solver behaves properly.
  120. A = np.array([[1, 2], [3, 4]], dtype=float)
  121. A *= 100 * np.nextafter(0, 1)
  122. b = np.array([1, 1])
  123. with suppress_warnings() as sup:
  124. sup.filter(DeprecationWarning, ".*called without specifying.*")
  125. xp, info = gcrotmk(A, b)
  126. if info == 0:
  127. assert_allclose(A.dot(xp), b)