test_interpolative.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. #******************************************************************************
  2. # Copyright (C) 2013 Kenneth L. Ho
  3. # Redistribution and use in source and binary forms, with or without
  4. # modification, are permitted provided that the following conditions are met:
  5. #
  6. # Redistributions of source code must retain the above copyright notice, this
  7. # list of conditions and the following disclaimer. Redistributions in binary
  8. # form must reproduce the above copyright notice, this list of conditions and
  9. # the following disclaimer in the documentation and/or other materials
  10. # provided with the distribution.
  11. #
  12. # None of the names of the copyright holders may be used to endorse or
  13. # promote products derived from this software without specific prior written
  14. # permission.
  15. #
  16. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  17. # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  18. # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  19. # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  20. # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  21. # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  22. # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  23. # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  24. # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  25. # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  26. # POSSIBILITY OF SUCH DAMAGE.
  27. #******************************************************************************
  28. import scipy.linalg.interpolative as pymatrixid
  29. import numpy as np
  30. from scipy.linalg import hilbert, svdvals, norm
  31. from scipy.sparse.linalg import aslinearoperator
  32. from scipy.linalg.interpolative import interp_decomp
  33. from numpy.testing import (assert_, assert_allclose, assert_equal,
  34. assert_array_equal)
  35. import pytest
  36. from pytest import raises as assert_raises
  37. import sys
  38. _IS_32BIT = (sys.maxsize < 2**32)
  39. @pytest.fixture()
  40. def eps():
  41. yield 1e-12
  42. @pytest.fixture(params=[np.float64, np.complex128])
  43. def A(request):
  44. # construct Hilbert matrix
  45. # set parameters
  46. n = 300
  47. yield hilbert(n).astype(request.param)
  48. @pytest.fixture()
  49. def L(A):
  50. yield aslinearoperator(A)
  51. @pytest.fixture()
  52. def rank(A, eps):
  53. S = np.linalg.svd(A, compute_uv=False)
  54. try:
  55. rank = np.nonzero(S < eps)[0][0]
  56. except IndexError:
  57. rank = A.shape[0]
  58. return rank
  59. class TestInterpolativeDecomposition:
  60. @pytest.mark.parametrize(
  61. "rand,lin_op",
  62. [(False, False), (True, False), (True, True)])
  63. def test_real_id_fixed_precision(self, A, L, eps, rand, lin_op):
  64. if _IS_32BIT and A.dtype == np.complex_ and rand:
  65. pytest.xfail("bug in external fortran code")
  66. # Test ID routines on a Hilbert matrix.
  67. A_or_L = A if not lin_op else L
  68. k, idx, proj = pymatrixid.interp_decomp(A_or_L, eps, rand=rand)
  69. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  70. assert_allclose(A, B, rtol=eps, atol=1e-08)
  71. @pytest.mark.parametrize(
  72. "rand,lin_op",
  73. [(False, False), (True, False), (True, True)])
  74. def test_real_id_fixed_rank(self, A, L, eps, rank, rand, lin_op):
  75. if _IS_32BIT and A.dtype == np.complex_ and rand:
  76. pytest.xfail("bug in external fortran code")
  77. k = rank
  78. A_or_L = A if not lin_op else L
  79. idx, proj = pymatrixid.interp_decomp(A_or_L, k, rand=rand)
  80. B = pymatrixid.reconstruct_matrix_from_id(A[:, idx[:k]], idx, proj)
  81. assert_allclose(A, B, rtol=eps, atol=1e-08)
  82. @pytest.mark.parametrize("rand,lin_op", [(False, False)])
  83. def test_real_id_skel_and_interp_matrices(
  84. self, A, L, eps, rank, rand, lin_op):
  85. k = rank
  86. A_or_L = A if not lin_op else L
  87. idx, proj = pymatrixid.interp_decomp(A_or_L, k, rand=rand)
  88. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  89. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  90. assert_allclose(B, A[:, idx[:k]], rtol=eps, atol=1e-08)
  91. assert_allclose(B @ P, A, rtol=eps, atol=1e-08)
  92. @pytest.mark.parametrize(
  93. "rand,lin_op",
  94. [(False, False), (True, False), (True, True)])
  95. def test_svd_fixed_precison(self, A, L, eps, rand, lin_op):
  96. if _IS_32BIT and A.dtype == np.complex_ and rand:
  97. pytest.xfail("bug in external fortran code")
  98. A_or_L = A if not lin_op else L
  99. U, S, V = pymatrixid.svd(A_or_L, eps, rand=rand)
  100. B = U * S @ V.T.conj()
  101. assert_allclose(A, B, rtol=eps, atol=1e-08)
  102. @pytest.mark.parametrize(
  103. "rand,lin_op",
  104. [(False, False), (True, False), (True, True)])
  105. def test_svd_fixed_rank(self, A, L, eps, rank, rand, lin_op):
  106. if _IS_32BIT and A.dtype == np.complex_ and rand:
  107. pytest.xfail("bug in external fortran code")
  108. k = rank
  109. A_or_L = A if not lin_op else L
  110. U, S, V = pymatrixid.svd(A_or_L, k, rand=rand)
  111. B = U * S @ V.T.conj()
  112. assert_allclose(A, B, rtol=eps, atol=1e-08)
  113. def test_id_to_svd(self, A, eps, rank):
  114. k = rank
  115. idx, proj = pymatrixid.interp_decomp(A, k, rand=False)
  116. U, S, V = pymatrixid.id_to_svd(A[:, idx[:k]], idx, proj)
  117. B = U * S @ V.T.conj()
  118. assert_allclose(A, B, rtol=eps, atol=1e-08)
  119. def test_estimate_spectral_norm(self, A):
  120. s = svdvals(A)
  121. norm_2_est = pymatrixid.estimate_spectral_norm(A)
  122. assert_allclose(norm_2_est, s[0], rtol=1e-6, atol=1e-8)
  123. def test_estimate_spectral_norm_diff(self, A):
  124. B = A.copy()
  125. B[:, 0] *= 1.2
  126. s = svdvals(A - B)
  127. norm_2_est = pymatrixid.estimate_spectral_norm_diff(A, B)
  128. assert_allclose(norm_2_est, s[0], rtol=1e-6, atol=1e-8)
  129. def test_rank_estimates_array(self, A):
  130. B = np.array([[1, 1, 0], [0, 0, 1], [0, 0, 1]], dtype=A.dtype)
  131. for M in [A, B]:
  132. rank_tol = 1e-9
  133. rank_np = np.linalg.matrix_rank(M, norm(M, 2) * rank_tol)
  134. rank_est = pymatrixid.estimate_rank(M, rank_tol)
  135. assert_(rank_est >= rank_np)
  136. assert_(rank_est <= rank_np + 10)
  137. def test_rank_estimates_lin_op(self, A):
  138. B = np.array([[1, 1, 0], [0, 0, 1], [0, 0, 1]], dtype=A.dtype)
  139. for M in [A, B]:
  140. ML = aslinearoperator(M)
  141. rank_tol = 1e-9
  142. rank_np = np.linalg.matrix_rank(M, norm(M, 2) * rank_tol)
  143. rank_est = pymatrixid.estimate_rank(ML, rank_tol)
  144. assert_(rank_est >= rank_np - 4)
  145. assert_(rank_est <= rank_np + 4)
  146. def test_rand(self):
  147. pymatrixid.seed('default')
  148. assert_allclose(pymatrixid.rand(2), [0.8932059, 0.64500803],
  149. rtol=1e-4, atol=1e-8)
  150. pymatrixid.seed(1234)
  151. x1 = pymatrixid.rand(2)
  152. assert_allclose(x1, [0.7513823, 0.06861718], rtol=1e-4, atol=1e-8)
  153. np.random.seed(1234)
  154. pymatrixid.seed()
  155. x2 = pymatrixid.rand(2)
  156. np.random.seed(1234)
  157. pymatrixid.seed(np.random.rand(55))
  158. x3 = pymatrixid.rand(2)
  159. assert_allclose(x1, x2)
  160. assert_allclose(x1, x3)
  161. def test_badcall(self):
  162. A = hilbert(5).astype(np.float32)
  163. with assert_raises(ValueError):
  164. pymatrixid.interp_decomp(A, 1e-6, rand=False)
  165. def test_rank_too_large(self):
  166. # svd(array, k) should not segfault
  167. a = np.ones((4, 3))
  168. with assert_raises(ValueError):
  169. pymatrixid.svd(a, 4)
  170. def test_full_rank(self):
  171. eps = 1.0e-12
  172. # fixed precision
  173. A = np.random.rand(16, 8)
  174. k, idx, proj = pymatrixid.interp_decomp(A, eps)
  175. assert_equal(k, A.shape[1])
  176. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  177. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  178. assert_allclose(A, B @ P)
  179. # fixed rank
  180. idx, proj = pymatrixid.interp_decomp(A, k)
  181. P = pymatrixid.reconstruct_interp_matrix(idx, proj)
  182. B = pymatrixid.reconstruct_skel_matrix(A, k, idx)
  183. assert_allclose(A, B @ P)
  184. @pytest.mark.parametrize("dtype", [np.float_, np.complex_])
  185. @pytest.mark.parametrize("rand", [True, False])
  186. @pytest.mark.parametrize("eps", [1, 0.1])
  187. def test_bug_9793(self, dtype, rand, eps):
  188. if _IS_32BIT and dtype == np.complex_ and rand:
  189. pytest.xfail("bug in external fortran code")
  190. A = np.array([[-1, -1, -1, 0, 0, 0],
  191. [0, 0, 0, 1, 1, 1],
  192. [1, 0, 0, 1, 0, 0],
  193. [0, 1, 0, 0, 1, 0],
  194. [0, 0, 1, 0, 0, 1]],
  195. dtype=dtype, order="C")
  196. B = A.copy()
  197. interp_decomp(A.T, eps, rand=rand)
  198. assert_array_equal(A, B)