test_expm_multiply.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. """Test functions for the sparse.linalg._expm_multiply module."""
from functools import partial, wraps
from itertools import product

import numpy as np
import pytest
from numpy.testing import (assert_allclose, assert_, assert_equal,
                           suppress_warnings)

import scipy.linalg
from scipy.sparse import SparseEfficiencyWarning
from scipy.sparse.linalg import aslinearoperator
from scipy.sparse.linalg import expm as sp_expm
from scipy.sparse.linalg._expm_multiply import (_theta, _compute_p_max,
    _onenormest_matrix_power, expm_multiply, _expm_multiply_simple,
    _expm_multiply_interval)
  15. IMPRECISE = {np.single, np.csingle}
  16. REAL_DTYPES = {np.intc, np.int_, np.longlong,
  17. np.single, np.double, np.longdouble}
  18. COMPLEX_DTYPES = {np.csingle, np.cdouble, np.clongdouble}
  19. # use sorted tuple to ensure fixed order of tests
  20. DTYPES = tuple(sorted(REAL_DTYPES ^ COMPLEX_DTYPES, key=str))
  21. def estimated(func):
  22. """If trace is estimated, it should warn.
  23. We warn that estimation of trace might impact performance.
  24. All result have to be correct nevertheless!
  25. """
  26. def wrapped(*args, **kwds):
  27. with pytest.warns(UserWarning,
  28. match="Trace of LinearOperator not available"):
  29. return func(*args, **kwds)
  30. return wrapped
  31. def less_than_or_close(a, b):
  32. return np.allclose(a, b) or (a < b)
  33. class TestExpmActionSimple:
  34. """
  35. These tests do not consider the case of multiple time steps in one call.
  36. """
  37. def test_theta_monotonicity(self):
  38. pairs = sorted(_theta.items())
  39. for (m_a, theta_a), (m_b, theta_b) in zip(pairs[:-1], pairs[1:]):
  40. assert_(theta_a < theta_b)
  41. def test_p_max_default(self):
  42. m_max = 55
  43. expected_p_max = 8
  44. observed_p_max = _compute_p_max(m_max)
  45. assert_equal(observed_p_max, expected_p_max)
  46. def test_p_max_range(self):
  47. for m_max in range(1, 55+1):
  48. p_max = _compute_p_max(m_max)
  49. assert_(p_max*(p_max - 1) <= m_max + 1)
  50. p_too_big = p_max + 1
  51. assert_(p_too_big*(p_too_big - 1) > m_max + 1)
  52. def test_onenormest_matrix_power(self):
  53. np.random.seed(1234)
  54. n = 40
  55. nsamples = 10
  56. for i in range(nsamples):
  57. A = scipy.linalg.inv(np.random.randn(n, n))
  58. for p in range(4):
  59. if not p:
  60. M = np.identity(n)
  61. else:
  62. M = np.dot(M, A)
  63. estimated = _onenormest_matrix_power(A, p)
  64. exact = np.linalg.norm(M, 1)
  65. assert_(less_than_or_close(estimated, exact))
  66. assert_(less_than_or_close(exact, 3*estimated))
  67. def test_expm_multiply(self):
  68. np.random.seed(1234)
  69. n = 40
  70. k = 3
  71. nsamples = 10
  72. for i in range(nsamples):
  73. A = scipy.linalg.inv(np.random.randn(n, n))
  74. B = np.random.randn(n, k)
  75. observed = expm_multiply(A, B)
  76. expected = np.dot(sp_expm(A), B)
  77. assert_allclose(observed, expected)
  78. observed = estimated(expm_multiply)(aslinearoperator(A), B)
  79. assert_allclose(observed, expected)
  80. traceA = np.trace(A)
  81. observed = expm_multiply(aslinearoperator(A), B, traceA=traceA)
  82. assert_allclose(observed, expected)
  83. def test_matrix_vector_multiply(self):
  84. np.random.seed(1234)
  85. n = 40
  86. nsamples = 10
  87. for i in range(nsamples):
  88. A = scipy.linalg.inv(np.random.randn(n, n))
  89. v = np.random.randn(n)
  90. observed = expm_multiply(A, v)
  91. expected = np.dot(sp_expm(A), v)
  92. assert_allclose(observed, expected)
  93. observed = estimated(expm_multiply)(aslinearoperator(A), v)
  94. assert_allclose(observed, expected)
  95. def test_scaled_expm_multiply(self):
  96. np.random.seed(1234)
  97. n = 40
  98. k = 3
  99. nsamples = 10
  100. for i, t in product(range(nsamples), [0.2, 1.0, 1.5]):
  101. with np.errstate(invalid='ignore'):
  102. A = scipy.linalg.inv(np.random.randn(n, n))
  103. B = np.random.randn(n, k)
  104. observed = _expm_multiply_simple(A, B, t=t)
  105. expected = np.dot(sp_expm(t*A), B)
  106. assert_allclose(observed, expected)
  107. observed = estimated(_expm_multiply_simple)(
  108. aslinearoperator(A), B, t=t
  109. )
  110. assert_allclose(observed, expected)
  111. def test_scaled_expm_multiply_single_timepoint(self):
  112. np.random.seed(1234)
  113. t = 0.1
  114. n = 5
  115. k = 2
  116. A = np.random.randn(n, n)
  117. B = np.random.randn(n, k)
  118. observed = _expm_multiply_simple(A, B, t=t)
  119. expected = sp_expm(t*A).dot(B)
  120. assert_allclose(observed, expected)
  121. observed = estimated(_expm_multiply_simple)(
  122. aslinearoperator(A), B, t=t
  123. )
  124. assert_allclose(observed, expected)
  125. def test_sparse_expm_multiply(self):
  126. np.random.seed(1234)
  127. n = 40
  128. k = 3
  129. nsamples = 10
  130. for i in range(nsamples):
  131. A = scipy.sparse.rand(n, n, density=0.05)
  132. B = np.random.randn(n, k)
  133. observed = expm_multiply(A, B)
  134. with suppress_warnings() as sup:
  135. sup.filter(SparseEfficiencyWarning,
  136. "splu converted its input to CSC format")
  137. sup.filter(SparseEfficiencyWarning,
  138. "spsolve is more efficient when sparse b is in the"
  139. " CSC matrix format")
  140. expected = sp_expm(A).dot(B)
  141. assert_allclose(observed, expected)
  142. observed = estimated(expm_multiply)(aslinearoperator(A), B)
  143. assert_allclose(observed, expected)
  144. def test_complex(self):
  145. A = np.array([
  146. [1j, 1j],
  147. [0, 1j]], dtype=complex)
  148. B = np.array([1j, 1j])
  149. observed = expm_multiply(A, B)
  150. expected = np.array([
  151. 1j * np.exp(1j) + 1j * (1j*np.cos(1) - np.sin(1)),
  152. 1j * np.exp(1j)], dtype=complex)
  153. assert_allclose(observed, expected)
  154. observed = estimated(expm_multiply)(aslinearoperator(A), B)
  155. assert_allclose(observed, expected)
  156. class TestExpmActionInterval:
  157. def test_sparse_expm_multiply_interval(self):
  158. np.random.seed(1234)
  159. start = 0.1
  160. stop = 3.2
  161. n = 40
  162. k = 3
  163. endpoint = True
  164. for num in (14, 13, 2):
  165. A = scipy.sparse.rand(n, n, density=0.05)
  166. B = np.random.randn(n, k)
  167. v = np.random.randn(n)
  168. for target in (B, v):
  169. X = expm_multiply(A, target, start=start, stop=stop,
  170. num=num, endpoint=endpoint)
  171. samples = np.linspace(start=start, stop=stop,
  172. num=num, endpoint=endpoint)
  173. with suppress_warnings() as sup:
  174. sup.filter(SparseEfficiencyWarning,
  175. "splu converted its input to CSC format")
  176. sup.filter(SparseEfficiencyWarning,
  177. "spsolve is more efficient when sparse b is in"
  178. " the CSC matrix format")
  179. for solution, t in zip(X, samples):
  180. assert_allclose(solution, sp_expm(t*A).dot(target))
  181. def test_expm_multiply_interval_vector(self):
  182. np.random.seed(1234)
  183. interval = {'start': 0.1, 'stop': 3.2, 'endpoint': True}
  184. for num, n in product([14, 13, 2], [1, 2, 5, 20, 40]):
  185. A = scipy.linalg.inv(np.random.randn(n, n))
  186. v = np.random.randn(n)
  187. samples = np.linspace(num=num, **interval)
  188. X = expm_multiply(A, v, num=num, **interval)
  189. for solution, t in zip(X, samples):
  190. assert_allclose(solution, sp_expm(t*A).dot(v))
  191. # test for linear operator with unknown trace -> estimate trace
  192. Xguess = estimated(expm_multiply)(aslinearoperator(A), v,
  193. num=num, **interval)
  194. # test for linear operator with given trace
  195. Xgiven = expm_multiply(aslinearoperator(A), v, num=num, **interval,
  196. traceA=np.trace(A))
  197. # test robustness for linear operator with wrong trace
  198. Xwrong = expm_multiply(aslinearoperator(A), v, num=num, **interval,
  199. traceA=np.trace(A)*5)
  200. for sol_guess, sol_given, sol_wrong, t in zip(Xguess, Xgiven,
  201. Xwrong, samples):
  202. correct = sp_expm(t*A).dot(v)
  203. assert_allclose(sol_guess, correct)
  204. assert_allclose(sol_given, correct)
  205. assert_allclose(sol_wrong, correct)
  206. def test_expm_multiply_interval_matrix(self):
  207. np.random.seed(1234)
  208. interval = {'start': 0.1, 'stop': 3.2, 'endpoint': True}
  209. for num, n, k in product([14, 13, 2], [1, 2, 5, 20, 40], [1, 2]):
  210. A = scipy.linalg.inv(np.random.randn(n, n))
  211. B = np.random.randn(n, k)
  212. samples = np.linspace(num=num, **interval)
  213. X = expm_multiply(A, B, num=num, **interval)
  214. for solution, t in zip(X, samples):
  215. assert_allclose(solution, sp_expm(t*A).dot(B))
  216. X = estimated(expm_multiply)(aslinearoperator(A), B, num=num,
  217. **interval)
  218. for solution, t in zip(X, samples):
  219. assert_allclose(solution, sp_expm(t*A).dot(B))
  220. def test_sparse_expm_multiply_interval_dtypes(self):
  221. # Test A & B int
  222. A = scipy.sparse.diags(np.arange(5),format='csr', dtype=int)
  223. B = np.ones(5, dtype=int)
  224. Aexpm = scipy.sparse.diags(np.exp(np.arange(5)),format='csr')
  225. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  226. # Test A complex, B int
  227. A = scipy.sparse.diags(-1j*np.arange(5),format='csr', dtype=complex)
  228. B = np.ones(5, dtype=int)
  229. Aexpm = scipy.sparse.diags(np.exp(-1j*np.arange(5)),format='csr')
  230. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  231. # Test A int, B complex
  232. A = scipy.sparse.diags(np.arange(5),format='csr', dtype=int)
  233. B = np.full(5, 1j, dtype=complex)
  234. Aexpm = scipy.sparse.diags(np.exp(np.arange(5)),format='csr')
  235. assert_allclose(expm_multiply(A,B,0,1)[-1], Aexpm.dot(B))
  236. def test_expm_multiply_interval_status_0(self):
  237. self._help_test_specific_expm_interval_status(0)
  238. def test_expm_multiply_interval_status_1(self):
  239. self._help_test_specific_expm_interval_status(1)
  240. def test_expm_multiply_interval_status_2(self):
  241. self._help_test_specific_expm_interval_status(2)
  242. def _help_test_specific_expm_interval_status(self, target_status):
  243. np.random.seed(1234)
  244. start = 0.1
  245. stop = 3.2
  246. num = 13
  247. endpoint = True
  248. n = 5
  249. k = 2
  250. nrepeats = 10
  251. nsuccesses = 0
  252. for num in [14, 13, 2] * nrepeats:
  253. A = np.random.randn(n, n)
  254. B = np.random.randn(n, k)
  255. status = _expm_multiply_interval(A, B,
  256. start=start, stop=stop, num=num, endpoint=endpoint,
  257. status_only=True)
  258. if status == target_status:
  259. X, status = _expm_multiply_interval(A, B,
  260. start=start, stop=stop, num=num, endpoint=endpoint,
  261. status_only=False)
  262. assert_equal(X.shape, (num, n, k))
  263. samples = np.linspace(start=start, stop=stop,
  264. num=num, endpoint=endpoint)
  265. for solution, t in zip(X, samples):
  266. assert_allclose(solution, sp_expm(t*A).dot(B))
  267. nsuccesses += 1
  268. if not nsuccesses:
  269. msg = 'failed to find a status-' + str(target_status) + ' interval'
  270. raise Exception(msg)
  271. @pytest.mark.parametrize("dtype_a", DTYPES)
  272. @pytest.mark.parametrize("dtype_b", DTYPES)
  273. @pytest.mark.parametrize("b_is_matrix", [False, True])
  274. def test_expm_multiply_dtype(dtype_a, dtype_b, b_is_matrix):
  275. """Make sure `expm_multiply` handles all numerical dtypes correctly."""
  276. assert_allclose_ = (partial(assert_allclose, rtol=1.2e-3, atol=1e-5)
  277. if {dtype_a, dtype_b} & IMPRECISE else assert_allclose)
  278. rng = np.random.default_rng(1234)
  279. # test data
  280. n = 7
  281. b_shape = (n, 3) if b_is_matrix else (n, )
  282. if dtype_a in REAL_DTYPES:
  283. A = scipy.linalg.inv(rng.random([n, n])).astype(dtype_a)
  284. else:
  285. A = scipy.linalg.inv(
  286. rng.random([n, n]) + 1j*rng.random([n, n])
  287. ).astype(dtype_a)
  288. if dtype_b in REAL_DTYPES:
  289. B = (2*rng.random(b_shape)).astype(dtype_b)
  290. else:
  291. B = (rng.random(b_shape) + 1j*rng.random(b_shape)).astype(dtype_b)
  292. # single application
  293. sol_mat = expm_multiply(A, B)
  294. sol_op = estimated(expm_multiply)(aslinearoperator(A), B)
  295. direct_sol = np.dot(sp_expm(A), B)
  296. assert_allclose_(sol_mat, direct_sol)
  297. assert_allclose_(sol_op, direct_sol)
  298. sol_op = expm_multiply(aslinearoperator(A), B, traceA=np.trace(A))
  299. assert_allclose_(sol_op, direct_sol)
  300. # for time points
  301. interval = {'start': 0.1, 'stop': 3.2, 'num': 13, 'endpoint': True}
  302. samples = np.linspace(**interval)
  303. X_mat = expm_multiply(A, B, **interval)
  304. X_op = estimated(expm_multiply)(aslinearoperator(A), B, **interval)
  305. for sol_mat, sol_op, t in zip(X_mat, X_op, samples):
  306. direct_sol = sp_expm(t*A).dot(B)
  307. assert_allclose_(sol_mat, direct_sol)
  308. assert_allclose_(sol_op, direct_sol)