test_minres.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import numpy as np
  2. from numpy.testing import assert_equal, assert_allclose, assert_
  3. from scipy.sparse.linalg._isolve import minres
  4. from pytest import raises as assert_raises
  5. from .test_iterative import assert_normclose
  6. def get_sample_problem():
  7. # A random 10 x 10 symmetric matrix
  8. np.random.seed(1234)
  9. matrix = np.random.rand(10, 10)
  10. matrix = matrix + matrix.T
  11. # A random vector of length 10
  12. vector = np.random.rand(10)
  13. return matrix, vector
  14. def test_singular():
  15. A, b = get_sample_problem()
  16. A[0, ] = 0
  17. b[0] = 0
  18. xp, info = minres(A, b)
  19. assert_equal(info, 0)
  20. assert_normclose(A.dot(xp), b, tol=1e-5)
  21. def test_x0_is_used_by():
  22. A, b = get_sample_problem()
  23. # Random x0 to feed minres
  24. np.random.seed(12345)
  25. x0 = np.random.rand(10)
  26. trace = []
  27. def trace_iterates(xk):
  28. trace.append(xk)
  29. minres(A, b, x0=x0, callback=trace_iterates)
  30. trace_with_x0 = trace
  31. trace = []
  32. minres(A, b, callback=trace_iterates)
  33. assert_(not np.array_equal(trace_with_x0[0], trace[0]))
  34. def test_shift():
  35. A, b = get_sample_problem()
  36. shift = 0.5
  37. shifted_A = A - shift * np.eye(10)
  38. x1, info1 = minres(A, b, shift=shift)
  39. x2, info2 = minres(shifted_A, b)
  40. assert_equal(info1, 0)
  41. assert_allclose(x1, x2, rtol=1e-5)
  42. def test_asymmetric_fail():
  43. """Asymmetric matrix should raise `ValueError` when check=True"""
  44. A, b = get_sample_problem()
  45. A[1, 2] = 1
  46. A[2, 1] = 2
  47. with assert_raises(ValueError):
  48. xp, info = minres(A, b, check=True)
  49. def test_minres_non_default_x0():
  50. np.random.seed(1234)
  51. tol = 10**(-6)
  52. a = np.random.randn(5, 5)
  53. a = np.dot(a, a.T)
  54. b = np.random.randn(5)
  55. c = np.random.randn(5)
  56. x = minres(a, b, x0=c, tol=tol)[0]
  57. assert_normclose(a.dot(x), b, tol=tol)
  58. def test_minres_precond_non_default_x0():
  59. np.random.seed(12345)
  60. tol = 10**(-6)
  61. a = np.random.randn(5, 5)
  62. a = np.dot(a, a.T)
  63. b = np.random.randn(5)
  64. c = np.random.randn(5)
  65. m = np.random.randn(5, 5)
  66. m = np.dot(m, m.T)
  67. x = minres(a, b, M=m, x0=c, tol=tol)[0]
  68. assert_normclose(a.dot(x), b, tol=tol)
  69. def test_minres_precond_exact_x0():
  70. np.random.seed(1234)
  71. tol = 10**(-6)
  72. a = np.eye(10)
  73. b = np.ones(10)
  74. c = np.ones(10)
  75. m = np.random.randn(10, 10)
  76. m = np.dot(m, m.T)
  77. x = minres(a, b, M=m, x0=c, tol=tol)[0]
  78. assert_normclose(a.dot(x), b, tol=tol)