test_propack.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. import os
  2. import pytest
  3. import sys
  4. import numpy as np
  5. from numpy.testing import assert_allclose
  6. from pytest import raises as assert_raises
  7. from scipy.sparse.linalg._svdp import _svdp
  8. from scipy.sparse import csr_matrix, csc_matrix
  9. # dtype_flavour to tolerance
  10. TOLS = {
  11. np.float32: 1e-4,
  12. np.float64: 1e-8,
  13. np.complex64: 1e-4,
  14. np.complex128: 1e-8,
  15. }
  16. def is_complex_type(dtype):
  17. return np.dtype(dtype).kind == "c"
  18. def is_32bit():
  19. return sys.maxsize <= 2**32 # (usually 2**31-1 on 32-bit)
  20. def is_windows():
  21. return 'win32' in sys.platform
  22. _dtypes = []
  23. for dtype_flavour in TOLS.keys():
  24. marks = []
  25. if is_complex_type(dtype_flavour):
  26. if is_32bit():
  27. # PROPACK has issues w/ complex on 32-bit; see gh-14433
  28. marks = [pytest.mark.skip]
  29. elif is_windows() and np.dtype(dtype_flavour).itemsize == 16:
  30. # windows crashes for complex128 (so don't xfail); see gh-15108
  31. marks = [pytest.mark.skip]
  32. else:
  33. marks = [pytest.mark.slow] # type: ignore[list-item]
  34. _dtypes.append(pytest.param(dtype_flavour, marks=marks,
  35. id=dtype_flavour.__name__))
  36. _dtypes = tuple(_dtypes) # type: ignore[assignment]
  37. def generate_matrix(constructor, n, m, f,
  38. dtype=float, rseed=0, **kwargs):
  39. """Generate a random sparse matrix"""
  40. rng = np.random.RandomState(rseed)
  41. if is_complex_type(dtype):
  42. M = (- 5 + 10 * rng.rand(n, m)
  43. - 5j + 10j * rng.rand(n, m)).astype(dtype)
  44. else:
  45. M = (-5 + 10 * rng.rand(n, m)).astype(dtype)
  46. M[M.real > 10 * f - 5] = 0
  47. return constructor(M, **kwargs)
  48. def assert_orthogonal(u1, u2, rtol, atol):
  49. """Check that the first k rows of u1 and u2 are orthogonal"""
  50. A = abs(np.dot(u1.conj().T, u2))
  51. assert_allclose(A, np.eye(u1.shape[1], u2.shape[1]), rtol=rtol, atol=atol)
  52. def check_svdp(n, m, constructor, dtype, k, irl_mode, which, f=0.8):
  53. tol = TOLS[dtype]
  54. M = generate_matrix(np.asarray, n, m, f, dtype)
  55. Msp = constructor(M)
  56. u1, sigma1, vt1 = np.linalg.svd(M, full_matrices=False)
  57. u2, sigma2, vt2, _ = _svdp(Msp, k=k, which=which, irl_mode=irl_mode,
  58. tol=tol)
  59. # check the which
  60. if which.upper() == 'SM':
  61. u1 = np.roll(u1, k, 1)
  62. vt1 = np.roll(vt1, k, 0)
  63. sigma1 = np.roll(sigma1, k)
  64. # check that singular values agree
  65. assert_allclose(sigma1[:k], sigma2, rtol=tol, atol=tol)
  66. # check that singular vectors are orthogonal
  67. assert_orthogonal(u1, u2, rtol=tol, atol=tol)
  68. assert_orthogonal(vt1.T, vt2.T, rtol=tol, atol=tol)
  69. @pytest.mark.parametrize('ctor', (np.array, csr_matrix, csc_matrix))
  70. @pytest.mark.parametrize('dtype', _dtypes)
  71. @pytest.mark.parametrize('irl', (True, False))
  72. @pytest.mark.parametrize('which', ('LM', 'SM'))
  73. def test_svdp(ctor, dtype, irl, which):
  74. np.random.seed(0)
  75. n, m, k = 10, 20, 3
  76. if which == 'SM' and not irl:
  77. message = "`which`='SM' requires irl_mode=True"
  78. with assert_raises(ValueError, match=message):
  79. check_svdp(n, m, ctor, dtype, k, irl, which)
  80. else:
  81. if is_32bit() and is_complex_type(dtype):
  82. message = 'PROPACK complex-valued SVD methods not available '
  83. with assert_raises(TypeError, match=message):
  84. check_svdp(n, m, ctor, dtype, k, irl, which)
  85. else:
  86. check_svdp(n, m, ctor, dtype, k, irl, which)
  87. @pytest.mark.parametrize('dtype', _dtypes)
  88. @pytest.mark.parametrize('irl', (False, True))
  89. @pytest.mark.timeout(120) # True, complex64 > 60 s: prerel deps cov 64bit blas
  90. def test_examples(dtype, irl):
  91. # Note: atol for complex64 bumped from 1e-4 to 1e-3 due to test failures
  92. # with BLIS, Netlib, and MKL+AVX512 - see
  93. # https://github.com/conda-forge/scipy-feedstock/pull/198#issuecomment-999180432
  94. atol = {
  95. np.float32: 1.3e-4,
  96. np.float64: 1e-9,
  97. np.complex64: 1e-3,
  98. np.complex128: 1e-9,
  99. }[dtype]
  100. path_prefix = os.path.dirname(__file__)
  101. # Test matrices from `illc1850.coord` and `mhd1280b.cua` distributed with
  102. # PROPACK 2.1: http://sun.stanford.edu/~rmunk/PROPACK/
  103. relative_path = "propack_test_data.npz"
  104. filename = os.path.join(path_prefix, relative_path)
  105. data = np.load(filename, allow_pickle=True)
  106. if is_complex_type(dtype):
  107. A = data['A_complex'].item().astype(dtype)
  108. else:
  109. A = data['A_real'].item().astype(dtype)
  110. k = 200
  111. u, s, vh, _ = _svdp(A, k, irl_mode=irl, random_state=0)
  112. # complex example matrix has many repeated singular values, so check only
  113. # beginning non-repeated singular vectors to avoid permutations
  114. sv_check = 27 if is_complex_type(dtype) else k
  115. u = u[:, :sv_check]
  116. vh = vh[:sv_check, :]
  117. s = s[:sv_check]
  118. # Check orthogonality of singular vectors
  119. assert_allclose(np.eye(u.shape[1]), u.conj().T @ u, atol=atol)
  120. assert_allclose(np.eye(vh.shape[0]), vh @ vh.conj().T, atol=atol)
  121. # Ensure the norm of the difference between the np.linalg.svd and
  122. # PROPACK reconstructed matrices is small
  123. u3, s3, vh3 = np.linalg.svd(A.todense())
  124. u3 = u3[:, :sv_check]
  125. s3 = s3[:sv_check]
  126. vh3 = vh3[:sv_check, :]
  127. A3 = u3 @ np.diag(s3) @ vh3
  128. recon = u @ np.diag(s) @ vh
  129. assert_allclose(np.linalg.norm(A3 - recon), 0, atol=atol)
  130. @pytest.mark.parametrize('shifts', (None, -10, 0, 1, 10, 70))
  131. @pytest.mark.parametrize('dtype', _dtypes[:2])
  132. def test_shifts(shifts, dtype):
  133. np.random.seed(0)
  134. n, k = 70, 10
  135. A = np.random.random((n, n))
  136. if shifts is not None and ((shifts < 0) or (k > min(n-1-shifts, n))):
  137. with pytest.raises(ValueError):
  138. _svdp(A, k, shifts=shifts, kmax=5*k, irl_mode=True)
  139. else:
  140. _svdp(A, k, shifts=shifts, kmax=5*k, irl_mode=True)
  141. @pytest.mark.slow
  142. @pytest.mark.xfail()
  143. def test_shifts_accuracy():
  144. np.random.seed(0)
  145. n, k = 70, 10
  146. A = np.random.random((n, n)).astype(np.double)
  147. u1, s1, vt1, _ = _svdp(A, k, shifts=None, which='SM', irl_mode=True)
  148. u2, s2, vt2, _ = _svdp(A, k, shifts=32, which='SM', irl_mode=True)
  149. # shifts <= 32 doesn't agree with shifts > 32
  150. # Does agree when which='LM' instead of 'SM'
  151. assert_allclose(s1, s2)