test_interface.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449
  1. """Test functions for the sparse.linalg._interface module
  2. """
  3. from functools import partial
  4. from itertools import product
  5. import operator
  6. from pytest import raises as assert_raises, warns
  7. from numpy.testing import assert_, assert_equal
  8. import numpy as np
  9. import scipy.sparse as sparse
  10. import scipy.sparse.linalg._interface as interface
  11. from scipy.sparse._sputils import matrix
  12. class TestLinearOperator:
  13. def setup_method(self):
  14. self.A = np.array([[1,2,3],
  15. [4,5,6]])
  16. self.B = np.array([[1,2],
  17. [3,4],
  18. [5,6]])
  19. self.C = np.array([[1,2],
  20. [3,4]])
  21. def test_matvec(self):
  22. def get_matvecs(A):
  23. return [{
  24. 'shape': A.shape,
  25. 'matvec': lambda x: np.dot(A, x).reshape(A.shape[0]),
  26. 'rmatvec': lambda x: np.dot(A.T.conj(),
  27. x).reshape(A.shape[1])
  28. },
  29. {
  30. 'shape': A.shape,
  31. 'matvec': lambda x: np.dot(A, x),
  32. 'rmatvec': lambda x: np.dot(A.T.conj(), x),
  33. 'rmatmat': lambda x: np.dot(A.T.conj(), x),
  34. 'matmat': lambda x: np.dot(A, x)
  35. }]
  36. for matvecs in get_matvecs(self.A):
  37. A = interface.LinearOperator(**matvecs)
  38. assert_(A.args == ())
  39. assert_equal(A.matvec(np.array([1,2,3])), [14,32])
  40. assert_equal(A.matvec(np.array([[1],[2],[3]])), [[14],[32]])
  41. assert_equal(A * np.array([1,2,3]), [14,32])
  42. assert_equal(A * np.array([[1],[2],[3]]), [[14],[32]])
  43. assert_equal(A.dot(np.array([1,2,3])), [14,32])
  44. assert_equal(A.dot(np.array([[1],[2],[3]])), [[14],[32]])
  45. assert_equal(A.matvec(matrix([[1],[2],[3]])), [[14],[32]])
  46. assert_equal(A * matrix([[1],[2],[3]]), [[14],[32]])
  47. assert_equal(A.dot(matrix([[1],[2],[3]])), [[14],[32]])
  48. assert_equal((2*A)*[1,1,1], [12,30])
  49. assert_equal((2 * A).rmatvec([1, 1]), [10, 14, 18])
  50. assert_equal((2*A).H.matvec([1,1]), [10, 14, 18])
  51. assert_equal((2*A)*[[1],[1],[1]], [[12],[30]])
  52. assert_equal((2 * A).matmat([[1], [1], [1]]), [[12], [30]])
  53. assert_equal((A*2)*[1,1,1], [12,30])
  54. assert_equal((A*2)*[[1],[1],[1]], [[12],[30]])
  55. assert_equal((2j*A)*[1,1,1], [12j,30j])
  56. assert_equal((A+A)*[1,1,1], [12, 30])
  57. assert_equal((A + A).rmatvec([1, 1]), [10, 14, 18])
  58. assert_equal((A+A).H.matvec([1,1]), [10, 14, 18])
  59. assert_equal((A+A)*[[1],[1],[1]], [[12], [30]])
  60. assert_equal((A+A).matmat([[1],[1],[1]]), [[12], [30]])
  61. assert_equal((-A)*[1,1,1], [-6,-15])
  62. assert_equal((-A)*[[1],[1],[1]], [[-6],[-15]])
  63. assert_equal((A-A)*[1,1,1], [0,0])
  64. assert_equal((A - A) * [[1], [1], [1]], [[0], [0]])
  65. X = np.array([[1, 2], [3, 4]])
  66. # A_asarray = np.array([[1, 2, 3], [4, 5, 6]])
  67. assert_equal((2 * A).rmatmat(X), np.dot((2 * self.A).T, X))
  68. assert_equal((A * 2).rmatmat(X), np.dot((self.A * 2).T, X))
  69. assert_equal((2j * A).rmatmat(X),
  70. np.dot((2j * self.A).T.conj(), X))
  71. assert_equal((A * 2j).rmatmat(X),
  72. np.dot((self.A * 2j).T.conj(), X))
  73. assert_equal((A + A).rmatmat(X),
  74. np.dot((self.A + self.A).T, X))
  75. assert_equal((A + 2j * A).rmatmat(X),
  76. np.dot((self.A + 2j * self.A).T.conj(), X))
  77. assert_equal((-A).rmatmat(X), np.dot((-self.A).T, X))
  78. assert_equal((A - A).rmatmat(X),
  79. np.dot((self.A - self.A).T, X))
  80. assert_equal((2j * A).rmatmat(2j * X),
  81. np.dot((2j * self.A).T.conj(), 2j * X))
  82. z = A+A
  83. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] is A)
  84. z = 2*A
  85. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] == 2)
  86. assert_(isinstance(A.matvec([1, 2, 3]), np.ndarray))
  87. assert_(isinstance(A.matvec(np.array([[1],[2],[3]])), np.ndarray))
  88. assert_(isinstance(A * np.array([1,2,3]), np.ndarray))
  89. assert_(isinstance(A * np.array([[1],[2],[3]]), np.ndarray))
  90. assert_(isinstance(A.dot(np.array([1,2,3])), np.ndarray))
  91. assert_(isinstance(A.dot(np.array([[1],[2],[3]])), np.ndarray))
  92. assert_(isinstance(A.matvec(matrix([[1],[2],[3]])), np.ndarray))
  93. assert_(isinstance(A * matrix([[1],[2],[3]]), np.ndarray))
  94. assert_(isinstance(A.dot(matrix([[1],[2],[3]])), np.ndarray))
  95. assert_(isinstance(2*A, interface._ScaledLinearOperator))
  96. assert_(isinstance(2j*A, interface._ScaledLinearOperator))
  97. assert_(isinstance(A+A, interface._SumLinearOperator))
  98. assert_(isinstance(-A, interface._ScaledLinearOperator))
  99. assert_(isinstance(A-A, interface._SumLinearOperator))
  100. assert_((2j*A).dtype == np.complex_)
  101. assert_raises(ValueError, A.matvec, np.array([1,2]))
  102. assert_raises(ValueError, A.matvec, np.array([1,2,3,4]))
  103. assert_raises(ValueError, A.matvec, np.array([[1],[2]]))
  104. assert_raises(ValueError, A.matvec, np.array([[1],[2],[3],[4]]))
  105. assert_raises(ValueError, lambda: A*A)
  106. assert_raises(ValueError, lambda: A**2)
  107. for matvecsA, matvecsB in product(get_matvecs(self.A),
  108. get_matvecs(self.B)):
  109. A = interface.LinearOperator(**matvecsA)
  110. B = interface.LinearOperator(**matvecsB)
  111. # AtimesB = np.array([[22, 28], [49, 64]])
  112. AtimesB = self.A.dot(self.B)
  113. X = np.array([[1, 2], [3, 4]])
  114. assert_equal((A * B).rmatmat(X), np.dot((AtimesB).T, X))
  115. assert_equal((2j * A * B).rmatmat(X),
  116. np.dot((2j * AtimesB).T.conj(), X))
  117. assert_equal((A*B)*[1,1], [50,113])
  118. assert_equal((A*B)*[[1],[1]], [[50],[113]])
  119. assert_equal((A*B).matmat([[1],[1]]), [[50],[113]])
  120. assert_equal((A * B).rmatvec([1, 1]), [71, 92])
  121. assert_equal((A * B).H.matvec([1, 1]), [71, 92])
  122. assert_(isinstance(A*B, interface._ProductLinearOperator))
  123. assert_raises(ValueError, lambda: A+B)
  124. assert_raises(ValueError, lambda: A**2)
  125. z = A*B
  126. assert_(len(z.args) == 2 and z.args[0] is A and z.args[1] is B)
  127. for matvecsC in get_matvecs(self.C):
  128. C = interface.LinearOperator(**matvecsC)
  129. X = np.array([[1, 2], [3, 4]])
  130. assert_equal(C.rmatmat(X), np.dot((self.C).T, X))
  131. assert_equal((C**2).rmatmat(X),
  132. np.dot((np.dot(self.C, self.C)).T, X))
  133. assert_equal((C**2)*[1,1], [17,37])
  134. assert_equal((C**2).rmatvec([1, 1]), [22, 32])
  135. assert_equal((C**2).H.matvec([1, 1]), [22, 32])
  136. assert_equal((C**2).matmat([[1],[1]]), [[17],[37]])
  137. assert_(isinstance(C**2, interface._PowerLinearOperator))
  138. def test_matmul(self):
  139. D = {'shape': self.A.shape,
  140. 'matvec': lambda x: np.dot(self.A, x).reshape(self.A.shape[0]),
  141. 'rmatvec': lambda x: np.dot(self.A.T.conj(),
  142. x).reshape(self.A.shape[1]),
  143. 'rmatmat': lambda x: np.dot(self.A.T.conj(), x),
  144. 'matmat': lambda x: np.dot(self.A, x)}
  145. A = interface.LinearOperator(**D)
  146. B = np.array([[1, 2, 3],
  147. [4, 5, 6],
  148. [7, 8, 9]])
  149. b = B[0]
  150. assert_equal(operator.matmul(A, b), A * b)
  151. assert_equal(operator.matmul(A, B), A * B)
  152. assert_raises(ValueError, operator.matmul, A, 2)
  153. assert_raises(ValueError, operator.matmul, 2, A)
  154. class TestAsLinearOperator:
  155. def setup_method(self):
  156. self.cases = []
  157. def make_cases(original, dtype):
  158. cases = []
  159. cases.append((matrix(original, dtype=dtype), original))
  160. cases.append((np.array(original, dtype=dtype), original))
  161. cases.append((sparse.csr_matrix(original, dtype=dtype), original))
  162. # Test default implementations of _adjoint and _rmatvec, which
  163. # refer to each other.
  164. def mv(x, dtype):
  165. y = original.dot(x)
  166. if len(x.shape) == 2:
  167. y = y.reshape(-1, 1)
  168. return y
  169. def rmv(x, dtype):
  170. return original.T.conj().dot(x)
  171. class BaseMatlike(interface.LinearOperator):
  172. args = ()
  173. def __init__(self, dtype):
  174. self.dtype = np.dtype(dtype)
  175. self.shape = original.shape
  176. def _matvec(self, x):
  177. return mv(x, self.dtype)
  178. class HasRmatvec(BaseMatlike):
  179. args = ()
  180. def _rmatvec(self,x):
  181. return rmv(x, self.dtype)
  182. class HasAdjoint(BaseMatlike):
  183. args = ()
  184. def _adjoint(self):
  185. shape = self.shape[1], self.shape[0]
  186. matvec = partial(rmv, dtype=self.dtype)
  187. rmatvec = partial(mv, dtype=self.dtype)
  188. return interface.LinearOperator(matvec=matvec,
  189. rmatvec=rmatvec,
  190. dtype=self.dtype,
  191. shape=shape)
  192. class HasRmatmat(HasRmatvec):
  193. def _matmat(self, x):
  194. return original.dot(x)
  195. def _rmatmat(self, x):
  196. return original.T.conj().dot(x)
  197. cases.append((HasRmatvec(dtype), original))
  198. cases.append((HasAdjoint(dtype), original))
  199. cases.append((HasRmatmat(dtype), original))
  200. return cases
  201. original = np.array([[1,2,3], [4,5,6]])
  202. self.cases += make_cases(original, np.int32)
  203. self.cases += make_cases(original, np.float32)
  204. self.cases += make_cases(original, np.float64)
  205. self.cases += [(interface.aslinearoperator(M).T, A.T)
  206. for M, A in make_cases(original.T, np.float64)]
  207. self.cases += [(interface.aslinearoperator(M).H, A.T.conj())
  208. for M, A in make_cases(original.T, np.float64)]
  209. original = np.array([[1, 2j, 3j], [4j, 5j, 6]])
  210. self.cases += make_cases(original, np.complex_)
  211. self.cases += [(interface.aslinearoperator(M).T, A.T)
  212. for M, A in make_cases(original.T, np.complex_)]
  213. self.cases += [(interface.aslinearoperator(M).H, A.T.conj())
  214. for M, A in make_cases(original.T, np.complex_)]
  215. def test_basic(self):
  216. for M, A_array in self.cases:
  217. A = interface.aslinearoperator(M)
  218. M,N = A.shape
  219. xs = [np.array([1, 2, 3]),
  220. np.array([[1], [2], [3]])]
  221. ys = [np.array([1, 2]), np.array([[1], [2]])]
  222. if A.dtype == np.complex_:
  223. xs += [np.array([1, 2j, 3j]),
  224. np.array([[1], [2j], [3j]])]
  225. ys += [np.array([1, 2j]), np.array([[1], [2j]])]
  226. x2 = np.array([[1, 4], [2, 5], [3, 6]])
  227. for x in xs:
  228. assert_equal(A.matvec(x), A_array.dot(x))
  229. assert_equal(A * x, A_array.dot(x))
  230. assert_equal(A.matmat(x2), A_array.dot(x2))
  231. assert_equal(A * x2, A_array.dot(x2))
  232. for y in ys:
  233. assert_equal(A.rmatvec(y), A_array.T.conj().dot(y))
  234. assert_equal(A.T.matvec(y), A_array.T.dot(y))
  235. assert_equal(A.H.matvec(y), A_array.T.conj().dot(y))
  236. for y in ys:
  237. if y.ndim < 2:
  238. continue
  239. assert_equal(A.rmatmat(y), A_array.T.conj().dot(y))
  240. assert_equal(A.T.matmat(y), A_array.T.dot(y))
  241. assert_equal(A.H.matmat(y), A_array.T.conj().dot(y))
  242. if hasattr(M,'dtype'):
  243. assert_equal(A.dtype, M.dtype)
  244. assert_(hasattr(A, 'args'))
  245. def test_dot(self):
  246. for M, A_array in self.cases:
  247. A = interface.aslinearoperator(M)
  248. M,N = A.shape
  249. x0 = np.array([1, 2, 3])
  250. x1 = np.array([[1], [2], [3]])
  251. x2 = np.array([[1, 4], [2, 5], [3, 6]])
  252. assert_equal(A.dot(x0), A_array.dot(x0))
  253. assert_equal(A.dot(x1), A_array.dot(x1))
  254. assert_equal(A.dot(x2), A_array.dot(x2))
  255. def test_repr():
  256. A = interface.LinearOperator(shape=(1, 1), matvec=lambda x: 1)
  257. repr_A = repr(A)
  258. assert_('unspecified dtype' not in repr_A, repr_A)
  259. def test_identity():
  260. ident = interface.IdentityOperator((3, 3))
  261. assert_equal(ident * [1, 2, 3], [1, 2, 3])
  262. assert_equal(ident.dot(np.arange(9).reshape(3, 3)).ravel(), np.arange(9))
  263. assert_raises(ValueError, ident.matvec, [1, 2, 3, 4])
  264. def test_attributes():
  265. A = interface.aslinearoperator(np.arange(16).reshape(4, 4))
  266. def always_four_ones(x):
  267. x = np.asarray(x)
  268. assert_(x.shape == (3,) or x.shape == (3, 1))
  269. return np.ones(4)
  270. B = interface.LinearOperator(shape=(4, 3), matvec=always_four_ones)
  271. for op in [A, B, A * B, A.H, A + A, B + B, A**4]:
  272. assert_(hasattr(op, "dtype"))
  273. assert_(hasattr(op, "shape"))
  274. assert_(hasattr(op, "_matvec"))
  275. def matvec(x):
  276. """ Needed for test_pickle as local functions are not pickleable """
  277. return np.zeros(3)
  278. def test_pickle():
  279. import pickle
  280. for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
  281. A = interface.LinearOperator((3, 3), matvec)
  282. s = pickle.dumps(A, protocol=protocol)
  283. B = pickle.loads(s)
  284. for k in A.__dict__:
  285. assert_equal(getattr(A, k), getattr(B, k))
  286. def test_inheritance():
  287. class Empty(interface.LinearOperator):
  288. pass
  289. with warns(RuntimeWarning, match="should implement at least"):
  290. assert_raises(TypeError, Empty)
  291. class Identity(interface.LinearOperator):
  292. def __init__(self, n):
  293. super().__init__(dtype=None, shape=(n, n))
  294. def _matvec(self, x):
  295. return x
  296. id3 = Identity(3)
  297. assert_equal(id3.matvec([1, 2, 3]), [1, 2, 3])
  298. assert_raises(NotImplementedError, id3.rmatvec, [4, 5, 6])
  299. class MatmatOnly(interface.LinearOperator):
  300. def __init__(self, A):
  301. super().__init__(A.dtype, A.shape)
  302. self.A = A
  303. def _matmat(self, x):
  304. return self.A.dot(x)
  305. mm = MatmatOnly(np.random.randn(5, 3))
  306. assert_equal(mm.matvec(np.random.randn(3)).shape, (5,))
  307. def test_dtypes_of_operator_sum():
  308. # gh-6078
  309. mat_complex = np.random.rand(2,2) + 1j * np.random.rand(2,2)
  310. mat_real = np.random.rand(2,2)
  311. complex_operator = interface.aslinearoperator(mat_complex)
  312. real_operator = interface.aslinearoperator(mat_real)
  313. sum_complex = complex_operator + complex_operator
  314. sum_real = real_operator + real_operator
  315. assert_equal(sum_real.dtype, np.float64)
  316. assert_equal(sum_complex.dtype, np.complex128)
  317. def test_no_double_init():
  318. call_count = [0]
  319. def matvec(v):
  320. call_count[0] += 1
  321. return v
  322. # It should call matvec exactly once (in order to determine the
  323. # operator dtype)
  324. interface.LinearOperator((2, 2), matvec=matvec)
  325. assert_equal(call_count[0], 1)
  326. def test_adjoint_conjugate():
  327. X = np.array([[1j]])
  328. A = interface.aslinearoperator(X)
  329. B = 1j * A
  330. Y = 1j * X
  331. v = np.array([1])
  332. assert_equal(B.dot(v), Y.dot(v))
  333. assert_equal(B.H.dot(v), Y.T.conj().dot(v))
  334. def test_ndim():
  335. X = np.array([[1]])
  336. A = interface.aslinearoperator(X)
  337. assert_equal(A.ndim, 2)
  338. def test_transpose_noconjugate():
  339. X = np.array([[1j]])
  340. A = interface.aslinearoperator(X)
  341. B = 1j * A
  342. Y = 1j * X
  343. v = np.array([1])
  344. assert_equal(B.dot(v), Y.dot(v))
  345. assert_equal(B.T.dot(v), Y.T.dot(v))