_expm_multiply.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. """Compute the action of the matrix exponential."""
  2. from warnings import warn
  3. import numpy as np
  4. import scipy.linalg
  5. import scipy.sparse.linalg
  6. from scipy.linalg._decomp_qr import qr
  7. from scipy.sparse._sputils import is_pydata_spmatrix
  8. from scipy.sparse.linalg import aslinearoperator
  9. from scipy.sparse.linalg._interface import IdentityOperator
  10. from scipy.sparse.linalg._onenormest import onenormest
  # Only the top-level driver is exported; everything else is a helper.
  __all__ = ["expm_multiply"]
  def _exact_inf_norm(A):
      """Return the exact infinity-norm (maximum absolute row sum) of `A`.

      Accepts SciPy sparse matrices and sparse arrays, pydata/sparse arrays,
      and anything `numpy.linalg.norm` accepts.
      """
      # A compatibility function which should eventually disappear.
      # `issparse` also recognises SciPy sparse *arrays* (e.g. ``csr_array``),
      # which `isspmatrix` rejects in recent SciPy versions; those would
      # otherwise reach `numpy.linalg.norm`, which cannot handle them.
      if scipy.sparse.issparse(A):
          return max(abs(A).sum(axis=1).flat)
      elif is_pydata_spmatrix(A):
          return max(abs(A).sum(axis=1))
      else:
          return np.linalg.norm(A, np.inf)
  def _exact_1_norm(A):
      """Return the exact 1-norm (maximum absolute column sum) of `A`.

      Accepts SciPy sparse matrices and sparse arrays, pydata/sparse arrays,
      and anything `numpy.linalg.norm` accepts.
      """
      # A compatibility function which should eventually disappear.
      # `issparse` also recognises SciPy sparse *arrays* (e.g. ``csr_array``),
      # which `isspmatrix` rejects in recent SciPy versions; those would
      # otherwise reach `numpy.linalg.norm`, which cannot handle them.
      if scipy.sparse.issparse(A):
          return max(abs(A).sum(axis=0).flat)
      elif is_pydata_spmatrix(A):
          return max(abs(A).sum(axis=0))
      else:
          return np.linalg.norm(A, 1)
  def _trace(A):
      """Return the exact trace of `A`.

      pydata/sparse arrays are first converted to a SciPy sparse matrix;
      every other input must provide its own ``trace()`` method.
      """
      # A compatibility function which should eventually disappear.
      if not is_pydata_spmatrix(A):
          return A.trace()
      return A.to_scipy_sparse().trace()
  def traceest(A, m3, seed=None):
      """Estimate `np.trace(A)` using `3*m3` matrix-vector products.

      The result is random unless `seed` is fixed.

      Parameters
      ----------
      A : LinearOperator
          Linear operator whose trace will be estimated. Has to be square.
      m3 : int
          Number of matrix-vector products divided by 3 used to estimate the
          trace.
      seed : optional
          Seed for `numpy.random.default_rng`.
          Can be provided to obtain deterministic results.

      Returns
      -------
      trace : LinearOperator.dtype
          Estimate of the trace

      Raises
      ------
      ValueError
          If `A` is not shaped like a square matrix.

      Notes
      -----
      This is the Hutch++ algorithm given in [1]_: the trace over a randomly
      sketched subspace is computed exactly, and the remainder is estimated
      with Hutchinson's method on the orthogonal complement.

      References
      ----------
      .. [1] Meyer, Raphael A., Cameron Musco, Christopher Musco, and David P.
         Woodruff. "Hutch++: Optimal Stochastic Trace Estimation." In Symposium
         on Simplicity in Algorithms (SOSA), pp. 142-155. Society for Industrial
         and Applied Mathematics, 2021
         https://doi.org/10.1137/1.9781611976496.16
      """
      rng = np.random.default_rng(seed)
      shape = A.shape
      if len(shape) != 2 or shape[-1] != shape[-2]:
          raise ValueError("Expected A to be like a square matrix.")
      n = shape[-1]

      # Sketch the range of A with random sign vectors and orthonormalise it.
      sketch = rng.choice([-1.0, +1.0], [n, m3])
      Q, _ = qr(A.matmat(sketch), overwrite_a=True, mode='economic')
      Qh = Q.conj().T
      # Exact trace of A compressed onto the sketched subspace.
      low_rank_part = np.trace(Qh @ A.matmat(Q))

      # Hutchinson estimate restricted to the orthogonal complement.
      probes = rng.choice([-1, +1], [n, m3])
      projected = probes - Q @ (Qh @ probes)
      residual_part = np.trace(projected.conj().T @ A.matmat(projected))
      return low_rank_part + residual_part / m3
  def _ident_like(A):
      """Return an identity of the same shape, dtype and kind as `A`.

      * SciPy sparse input -> sparse identity in ``A.format`` (a sparse array
        for sparse-array input, a sparse matrix otherwise);
      * pydata/sparse input -> ``sparse.eye``;
      * LinearOperator input -> ``IdentityOperator``;
      * anything else -> dense ``numpy.eye``.
      """
      # A compatibility function which should eventually disappear.
      if scipy.sparse.issparse(A):
          # `issparse` also recognises sparse *arrays* (e.g. ``csr_array``),
          # which `isspmatrix` rejects in recent SciPy; those previously got a
          # dense ``np.eye``, so ``A - mu * ident`` silently densified `A`.
          ident = scipy.sparse.eye(A.shape[0], A.shape[1], dtype=A.dtype)
          if not isinstance(A, scipy.sparse.spmatrix):
              # Keep sparse-array inputs within the sparse-array API.
              ident = scipy.sparse.dia_array(ident)
          return ident.asformat(A.format)
      elif is_pydata_spmatrix(A):
          import sparse
          return sparse.eye(A.shape[0], A.shape[1], dtype=A.dtype)
      elif isinstance(A, scipy.sparse.linalg.LinearOperator):
          return IdentityOperator(A.shape, dtype=A.dtype)
      else:
          return np.eye(A.shape[0], A.shape[1], dtype=A.dtype)
  def expm_multiply(A, B, start=None, stop=None, num=None,
                    endpoint=None, traceA=None):
      """
      Compute the action of the matrix exponential of A on B.

      Parameters
      ----------
      A : transposable linear operator
          The operator whose exponential is of interest.
      B : ndarray
          The matrix or vector to be multiplied by the matrix exponential of A.
      start : scalar, optional
          The starting time point of the sequence.
      stop : scalar, optional
          The end time point of the sequence, unless `endpoint` is set to False.
          In that case, the sequence consists of all but the last of ``num + 1``
          evenly spaced time points, so that `stop` is excluded.
          Note that the step size changes when `endpoint` is False.
      num : int, optional
          Number of time points to use.
      endpoint : bool, optional
          If True, `stop` is the last time point. Otherwise, it is not included.
      traceA : scalar, optional
          Trace of `A`. If not given the trace is estimated for linear operators,
          or calculated exactly for sparse matrices. It is used to precondition
          `A`, thus an approximate trace is acceptable.
          For linear operators, `traceA` should be provided to ensure performance
          as the estimation is not guaranteed to be reliable for all cases.

          .. versionadded:: 1.9.0

      Returns
      -------
      expm_A_B : ndarray
          The result of the action :math:`e^{t_k A} B`.

      Warns
      -----
      UserWarning
          If `A` is a linear operator and ``traceA=None`` (default).

      Notes
      -----
      The optional arguments defining the sequence of evenly spaced time points
      are compatible with the arguments of `numpy.linspace`. When all of
      `start`, `stop`, `num` and `endpoint` are None, the action is computed at
      the single time point ``t = 1``.

      The shape of the output depends on the inputs; its ndim is 1, 2 or 3:

      * 1 for the action on a single vector at a single time point;
      * 2 for the action on a vector at multiple time points, or on a matrix
        at a single time point;
      * 3 for the action on a matrix with multiple columns at multiple time
        points.

      Whenever multiple time points are requested, ``expm_A_B[0]`` is the
      action of the exponential at the first time point, whether the action
      is on a vector or on a matrix.

      References
      ----------
      .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2011)
             "Computing the Action of the Matrix Exponential,
             with an Application to Exponential Integrators."
             SIAM Journal on Scientific Computing,
             33 (2). pp. 488-511. ISSN 1064-8275
             http://eprints.ma.man.ac.uk/1591/

      .. [2] Nicholas J. Higham and Awad H. Al-Mohy (2010)
             "Computing Matrix Functions."
             Acta Numerica,
             19. 159-208. ISSN 0962-4929
             http://eprints.ma.man.ac.uk/1451/

      Examples
      --------
      >>> import numpy as np
      >>> from scipy.sparse import csc_matrix
      >>> from scipy.sparse.linalg import expm, expm_multiply
      >>> A = csc_matrix([[1, 0], [0, 1]])
      >>> A.toarray()
      array([[1, 0],
             [0, 1]], dtype=int64)
      >>> B = np.array([np.exp(-1.), np.exp(-2.)])
      >>> B
      array([ 0.36787944,  0.13533528])
      >>> expm_multiply(A, B, start=1, stop=2, num=3, endpoint=True)
      array([[ 1.        ,  0.36787944],
             [ 1.64872127,  0.60653066],
             [ 2.71828183,  1.        ]])
      >>> expm(A).dot(B)                  # Verify 1st timestep
      array([ 1.        ,  0.36787944])
      >>> expm(1.5*A).dot(B)              # Verify 2nd timestep
      array([ 1.64872127,  0.60653066])
      >>> expm(2*A).dot(B)                # Verify 3rd timestep
      array([ 2.71828183,  1.        ])
      """
      interval_args = (start, stop, num, endpoint)
      if any(arg is not None for arg in interval_args):
          # At least one linspace argument given: evaluate on a time grid.
          result, _status = _expm_multiply_interval(A, B, start, stop, num,
                                                    endpoint, traceA=traceA)
          return result
      return _expm_multiply_simple(A, B, traceA=traceA)
  def _expm_multiply_simple(A, B, t=1.0, traceA=None, balance=False):
      """
      Compute the action of the matrix exponential at a single time point.

      Parameters
      ----------
      A : transposable linear operator
          The operator whose exponential is of interest.
      B : ndarray
          The matrix to be multiplied by the matrix exponential of A.
      t : float
          A time point. Negative values are allowed.
      traceA : scalar, optional
          Trace of `A`. If not given the trace is estimated for linear operators,
          or calculated exactly for sparse matrices. It is used to precondition
          `A`, thus an approximate trace is acceptable.
      balance : bool
          Indicates whether or not to apply balancing. Balancing is not
          implemented; passing True raises NotImplementedError.

      Returns
      -------
      F : ndarray
          :math:`e^{t A} B`

      Raises
      ------
      ValueError
          If `A` is not square, if `A` and `B` have incompatible shapes, or if
          `B` is neither 1-D nor 2-D.
      NotImplementedError
          If `balance` is True.

      Notes
      -----
      This is algorithm (3.2) in Al-Mohy and Higham (2011).
      """
      if balance:
          raise NotImplementedError
      if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
          raise ValueError('expected A to be like a square matrix')
      if A.shape[1] != B.shape[0]:
          raise ValueError('shapes of matrices A {} and B {} are incompatible'
                           .format(A.shape, B.shape))
      ident = _ident_like(A)
      is_linear_operator = isinstance(A, scipy.sparse.linalg.LinearOperator)
      n = A.shape[0]
      # n0 is the number of columns of B (a vector counts as one column).
      if len(B.shape) == 1:
          n0 = 1
      elif len(B.shape) == 2:
          n0 = B.shape[1]
      else:
          raise ValueError('expected B to be like a matrix or a vector')
      u_d = 2**-53  # unit roundoff of IEEE double precision
      tol = u_d
      if traceA is None:
          if is_linear_operator:
              warn("Trace of LinearOperator not available, it will be estimated."
                   " Provide `traceA` to ensure performance.", stacklevel=3)
          # m3=1 is bit arbitrary choice, a more accurate trace (larger m3) might
          # speed up exponential calculation, but trace estimation is more costly
          traceA = traceest(A, m3=1) if is_linear_operator else _trace(A)
      # Shift A by mu*I to reduce its norm; the shift is undone by the scalar
      # factor exp(t*mu/s) inside _expm_multiply_simple_core.
      mu = traceA / float(n)
      A = A - mu * ident
      A_1_norm = onenormest(A) if is_linear_operator else _exact_1_norm(A)
      if t*A_1_norm == 0:
          m_star, s = 0, 1
      else:
          ell = 2
          # Norms are non-negative: use |t| so a negative time point does not
          # feed a negative "1-norm" into the parameter selection (which would
          # otherwise collapse to a single first-order Taylor step).
          norm_info = LazyOperatorNormInfo(t*A, A_1_norm=abs(t)*A_1_norm,
                                           ell=ell)
          m_star, s = _fragment_3_1(norm_info, n0, tol, ell=ell)
      return _expm_multiply_simple_core(A, B, t, mu, m_star, s, tol, balance)
  def _expm_multiply_simple_core(A, B, t, mu, m_star, s, tol=None, balance=False):
      """
      Apply ``exp(t*(A + mu*I))`` to `B` using fixed Taylor parameters.

      Callers in this module pass an `A` that has already been shifted by
      ``mu * I``; the shift is restored by the scalar factor ``exp(t*mu/s)``
      applied after each of the `s` sub-steps. Each sub-step adds at most
      `m_star` Taylor terms and stops early once two consecutive terms are
      negligible relative to `tol` times the running sum (infinity-norm).
      """
      if balance:
          raise NotImplementedError
      if tol is None:
          tol = 2 ** -53  # unit roundoff of IEEE double precision
      F = B
      eta = np.exp(t * mu / float(s))
      for _ in range(s):
          prev_term_norm = _exact_inf_norm(B)
          for j in range(1, m_star + 1):
              # Next Taylor term of exp((t/s) * A) applied to the sub-step input.
              B = (t / float(s * j)) * A.dot(B)
              term_norm = _exact_inf_norm(B)
              F = F + B
              if prev_term_norm + term_norm <= tol * _exact_inf_norm(F):
                  break
              prev_term_norm = term_norm
          F = eta * F
          B = F
      return F
  # Lookup table of theta_m, keyed by the Taylor degree m.
  # _fragment_3_1 and _compute_cost_div_m divide a (scaled) 1-norm by theta_m
  # and round up to obtain a number of scaling steps s, so a larger theta_m
  # permits fewer steps at degree m.
  # The values seem to have been difficult to calculate, involving symbolic
  # manipulation of equations followed by numerical root finding; they are
  # copied from the tables cited below rather than computed here.
  _theta = {
      # The first 30 values are from table A.3 of Computing Matrix Functions.
      1: 2.29e-16,
      2: 2.58e-8,
      3: 1.39e-5,
      4: 3.40e-4,
      5: 2.40e-3,
      6: 9.07e-3,
      7: 2.38e-2,
      8: 5.00e-2,
      9: 8.96e-2,
      10: 1.44e-1,
      # m = 11 .. 20
      11: 2.14e-1,
      12: 3.00e-1,
      13: 4.00e-1,
      14: 5.14e-1,
      15: 6.41e-1,
      16: 7.81e-1,
      17: 9.31e-1,
      18: 1.09,
      19: 1.26,
      20: 1.44,
      # m = 21 .. 30
      21: 1.62,
      22: 1.82,
      23: 2.01,
      24: 2.22,
      25: 2.43,
      26: 2.64,
      27: 2.86,
      28: 3.08,
      29: 3.31,
      30: 3.54,
      # The rest are from table 3.1 of
      # Computing the Action of the Matrix Exponential.
      # Only multiples of 5 are tabulated above m = 30; 55 is the default m_max.
      35: 4.7,
      40: 6.0,
      45: 7.2,
      50: 8.5,
      55: 9.9,
  }
  def _onenormest_matrix_power(A, p,
          t=2, itmax=5, compute_v=False, compute_w=False):
      """
      Efficiently estimate the 1-norm of A^p.

      Parameters
      ----------
      A : ndarray
          Matrix whose 1-norm of a power is to be computed.
      p : int
          Non-negative integer power.
      t : int, optional
          A positive parameter controlling the tradeoff between
          accuracy versus time and memory usage.
          Larger values take longer and use more memory
          but give more accurate output.
      itmax : int, optional
          Use at most this many iterations.
      compute_v : bool, optional
          Request a norm-maximizing linear operator input vector if True.
      compute_w : bool, optional
          Request a norm-maximizing linear operator output vector if True.

      Returns
      -------
      est : float
          An underestimate of the 1-norm of A^p.
      v : ndarray, optional
          Returned only if `compute_v` is True.
          The vector such that ||A^p v||_1 == est*||v||_1.
          It can be thought of as an input to the linear operator
          that gives an output with particularly large norm.
      w : ndarray, optional
          Returned only if `compute_w` is True.
          The vector A^p v which has relatively large 1-norm.
          It can be thought of as an output of the linear operator
          that is relatively large in norm compared to the input.
      """
      #XXX Eventually turn this into an API function in the _onenormest module,
      #XXX and remove its underscore,
      #XXX but wait until expm_multiply goes into scipy.
      from scipy.sparse.linalg._onenormest import onenormest
      # Forward the tuning and certificate options; they were documented but
      # not passed on. The defaults equal onenormest's own defaults, so callers
      # relying on them see identical results.
      return onenormest(aslinearoperator(A) ** p, t=t, itmax=itmax,
                        compute_v=compute_v, compute_w=compute_w)
  class LazyOperatorNormInfo:
      """
      Lazily computed 1-norm information about a linear operator.

      Holds the exact 1-norm of the operator (computed on first request when
      it was not supplied) and caches estimates of ``||A^p||_1 ** (1/p)``.
      Every value handed out is multiplied by the current scale factor, so
      ``scale * A`` can be described without recomputing anything.

      Notation follows "Computing the Action of the Matrix Exponential"
      (Al-Mohy and Higham, 2011). The class is specialised to this module.
      """

      def __init__(self, A, A_1_norm=None, ell=2, scale=1):
          """
          Store the operator and any norm information already known.

          Parameters
          ----------
          A : linear operator
              The operator of interest.
          A_1_norm : float, optional
              The exact 1-norm of A, if already known.
          ell : int, optional
              Passed to the 1-norm estimator as its ``t`` parameter.
          scale : float, optional
              Factor applied to every returned quantity, i.e. report the
              norms of ``scale*A`` instead of ``A``.
          """
          self._A = A
          self._A_1_norm = A_1_norm
          self._ell = ell
          self._d = {}  # cache: p -> unscaled ||A^p||_1 ** (1/p) estimate
          self._scale = scale

      def set_scale(self, scale):
          """Set the factor applied to every returned norm quantity."""
          self._scale = scale

      def onenorm(self):
          """Return ``scale * ||A||_1``, computing the exact norm on first use."""
          if self._A_1_norm is None:
              self._A_1_norm = _exact_1_norm(self._A)
          return self._scale * self._A_1_norm

      def d(self, p):
          """
          Return ``scale * d_p(A)`` with ``d_p(A) ~= ||A^p||_1 ** (1/p)``.

          The unscaled estimate is computed once per `p` and cached.
          """
          try:
              d_p = self._d[p]
          except KeyError:
              estimate = _onenormest_matrix_power(self._A, p, self._ell)
              d_p = self._d[p] = estimate ** (1.0 / p)
          return self._scale * d_p

      def alpha(self, p):
          """Return ``max(d(p), d(p+1))`` (both already scaled)."""
          return max(self.d(p), self.d(p + 1))
  def _compute_cost_div_m(m, p, norm_info):
      """
      Cost of using Taylor degree `m` with power `p`, divided by `m`.

      This is equation (3.10) of Al-Mohy and Higham (2011): the number of
      matrix products required, divided by m, is ``ceil(alpha_p / theta_m)``.

      Parameters
      ----------
      m : int
          A valid key of _theta.
      p : int
          A matrix power.
      norm_info : LazyOperatorNormInfo
          Information about 1-norms of related operators.

      Returns
      -------
      cost_div_m : int
          Required number of matrix products divided by m.
      """
      ratio = norm_info.alpha(p) / _theta[m]
      return int(np.ceil(ratio))
  def _compute_p_max(m_max):
      """
      Return the largest positive integer p such that ``p*(p-1) <= m_max + 1``.

      Candidates are checked exhaustively in the small window
      ``[floor(sqrt(m_max)), ceil(sqrt(m_max) + 1)]``, which is simple and
      safe.

      Parameters
      ----------
      m_max : int
          A count related to bounds.
      """
      root = np.sqrt(m_max)
      lowest = int(np.floor(root))
      highest = int(np.ceil(root + 1))
      candidates = [p for p in range(lowest, highest + 1)
                    if p * (p - 1) <= m_max + 1]
      return max(candidates)
  def _fragment_3_1(norm_info, n0, tol, m_max=55, ell=2):
      """
      Choose the Taylor degree and the number of scaling steps.

      Parameters
      ----------
      norm_info : LazyOperatorNormInfo
          Information about norms of certain linear operators of interest.
      n0 : int
          Number of columns in the _expm_multiply_* B matrix.
      tol : float
          Expected to be
          :math:`2^{-24}` for single precision or
          :math:`2^{-53}` for double precision.
          Currently unused here: the _theta table is for double precision.
      m_max : int
          A value related to a bound.
      ell : int
          The number of columns used in the 1-norm approximation.
          This is usually taken to be small, maybe between 1 and 5.

      Returns
      -------
      best_m : int
          Related to bounds for error control.
      best_s : int
          Amount of scaling (at least 1).

      Raises
      ------
      ValueError
          If `ell` is less than 1.

      Notes
      -----
      This is code fragment (3.1) in Al-Mohy and Higham (2011).
      The discussion of default values for m_max and ell
      is given between the definitions of equation (3.11)
      and the definition of equation (3.12).
      Among candidates of equal cost ``m * s`` the first one generated wins.
      """
      if ell < 1:
          raise ValueError('expected ell to be a positive integer')
      if _condition_3_13(norm_info.onenorm(), n0, m_max, ell):
          # Cheap case: derive s directly from the exact 1-norm.
          candidates = [(m, int(np.ceil(norm_info.onenorm() / theta)))
                        for m, theta in _theta.items()]
      else:
          # Equation (3.11): use estimates of ||A^p||^(1/p).
          candidates = [(m, _compute_cost_div_m(m, p, norm_info))
                        for p in range(2, _compute_p_max(m_max) + 1)
                        for m in range(p*(p-1)-1, m_max+1)
                        if m in _theta]
      # min() keeps the first minimal entry, matching a strict "<" update.
      best_m, best_s = min(candidates, key=lambda pair: pair[0] * pair[1])
      return best_m, max(best_s, 1)
  def _condition_3_13(A_1_norm, n0, m_max, ell):
      """
      Test condition (3.13) of Al-Mohy and Higham (2011).

      Parameters
      ----------
      A_1_norm : float
          The precomputed 1-norm of A.
      n0 : int
          Number of columns in the _expm_multiply_* B matrix.
      m_max : int
          A value related to a bound.
      ell : int
          The number of columns used in the 1-norm approximation.
          This is usually taken to be small, maybe between 1 and 5.

      Returns
      -------
      value : bool
          True when ``A_1_norm`` is small enough that the parameters can be
          chosen from the 1-norm alone.
      """
      p_max = _compute_p_max(m_max)
      # Right-hand side of equation (3.12) times theta_{m_max} / (n0 * m_max).
      per_column = _theta[m_max] / float(n0 * m_max)
      bound = 2 * ell * p_max * (p_max + 3) * per_column
      return A_1_norm <= bound
  def _expm_multiply_interval(A, B, start=None, stop=None, num=None,
                              endpoint=None, traceA=None, balance=False,
                              status_only=False):
      """
      Compute the action of the matrix exponential at multiple time points.

      Parameters
      ----------
      A : transposable linear operator
          The operator whose exponential is of interest.
      B : ndarray
          The matrix to be multiplied by the matrix exponential of A.
      start : scalar, optional
          The starting time point of the sequence.
      stop : scalar, optional
          The end time point of the sequence, unless `endpoint` is set to False.
          In that case, the sequence consists of all but the last of ``num + 1``
          evenly spaced time points, so that `stop` is excluded.
          Note that the step size changes when `endpoint` is False.
          `stop` may be smaller than `start` (decreasing time points).
      num : int, optional
          Number of time points to use.
      traceA : scalar, optional
          Trace of `A`. If not given the trace is estimated for linear operators,
          or calculated exactly for sparse matrices. It is used to precondition
          `A`, thus an approximate trace is acceptable.
      endpoint : bool, optional
          If True, `stop` is the last time point. Otherwise, it is not included.
      balance : bool
          Indicates whether or not to apply balancing. Balancing is not
          implemented; passing True raises NotImplementedError.
      status_only : bool
          A flag that is set to True for some debugging and testing operations.

      Returns
      -------
      F : ndarray
          :math:`e^{t_k A} B`
      status : int
          An integer status for testing and debugging (0, 1 or 2, naming the
          core routine used). When `status_only` is True, only this integer
          is returned.

      Raises
      ------
      ValueError
          If `A` is not square, the shapes of `A` and `B` are incompatible,
          `B` is neither 1-D nor 2-D, or fewer than two time points result.
      NotImplementedError
          If `balance` is True.

      Notes
      -----
      This is algorithm (5.2) in Al-Mohy and Higham (2011).
      There seems to be a typo, where line 15 of the algorithm should be
      moved to line 6.5 (between lines 6 and 7).

      The Taylor parameters for the initial point ``t_0`` are selected on their
      own, because the parameters chosen for the interval length
      ``|t_q - t_0|`` do not bound the truncation error of ``e^{t_0 A} B`` when
      ``|t_0|`` is larger.
      """
      if balance:
          raise NotImplementedError
      if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
          raise ValueError('expected A to be like a square matrix')
      if A.shape[1] != B.shape[0]:
          raise ValueError('shapes of matrices A {} and B {} are incompatible'
                           .format(A.shape, B.shape))
      ident = _ident_like(A)
      is_linear_operator = isinstance(A, scipy.sparse.linalg.LinearOperator)
      n = A.shape[0]
      # n0 is the number of columns of B (a vector counts as one column).
      if len(B.shape) == 1:
          n0 = 1
      elif len(B.shape) == 2:
          n0 = B.shape[1]
      else:
          raise ValueError('expected B to be like a matrix or a vector')
      u_d = 2**-53  # unit roundoff of IEEE double precision
      tol = u_d
      if traceA is None:
          if is_linear_operator:
              warn("Trace of LinearOperator not available, it will be estimated."
                   " Provide `traceA` to ensure performance.", stacklevel=3)
          # m3=5 is bit arbitrary choice, a more accurate trace (larger m3) might
          # speed up exponential calculation, but trace estimation is also costly
          # an educated guess would need to consider the number of time points
          traceA = traceest(A, m3=5) if is_linear_operator else _trace(A)
      mu = traceA / float(n)

      # Get the linspace samples, attempting to preserve the linspace defaults.
      linspace_kwargs = {'retstep': True}
      if num is not None:
          linspace_kwargs['num'] = num
      if endpoint is not None:
          linspace_kwargs['endpoint'] = endpoint
      samples, step = np.linspace(start, stop, **linspace_kwargs)

      # Convert the linspace output to the notation used by the publication.
      nsamples = len(samples)
      if nsamples < 2:
          raise ValueError('at least two time points are required')
      q = nsamples - 1
      h = step
      t_0 = samples[0]
      t_q = samples[q]

      # Define the output ndarray.
      # Use an ndim=3 shape, such that the last two indices
      # are the ones that may be involved in level 3 BLAS operations.
      X_shape = (nsamples,) + B.shape
      X = np.empty(X_shape, dtype=np.result_type(A.dtype, B.dtype, float))
      t = t_q - t_0
      A = A - mu * ident
      A_1_norm = onenormest(A) if is_linear_operator else _exact_1_norm(A)
      ell = 2
      # Norms are non-negative: use |t| so that decreasing time points (t < 0)
      # do not feed a negative "1-norm" into the parameter selection.
      norm_info = LazyOperatorNormInfo(t*A, A_1_norm=abs(t)*A_1_norm, ell=ell)
      if t*A_1_norm == 0:
          m_star, s = 0, 1
      else:
          m_star, s = _fragment_3_1(norm_info, n0, tol, ell=ell)

      # Compute the expm action up to the initial time point. (m_star, s) above
      # are sized for the interval length |t|, not for |t_0|, so select
      # parameters for t_0 separately (t_0 == 0 still yields X[0] == B).
      if t_0*A_1_norm == 0:
          m_star_0, s_0 = 0, 1
      else:
          norm_info_0 = LazyOperatorNormInfo(t_0*A, A_1_norm=abs(t_0)*A_1_norm,
                                             ell=ell)
          m_star_0, s_0 = _fragment_3_1(norm_info_0, n0, tol, ell=ell)
      X[0] = _expm_multiply_simple_core(A, B, t_0, mu, m_star_0, s_0)

      # Compute the expm action at the rest of the time points.
      if q <= s:
          if status_only:
              return 0
          return _expm_multiply_interval_core_0(A, X,
                  h, mu, q, norm_info, tol, ell, n0)
      elif q % s == 0:
          if status_only:
              return 1
          return _expm_multiply_interval_core_1(A, X,
                  h, mu, m_star, s, q, tol)
      else:
          if status_only:
              return 2
          return _expm_multiply_interval_core_2(A, X,
                  h, mu, m_star, s, q, tol)
  def _expm_multiply_interval_core_0(A, X, h, mu, q, norm_info, tol, ell, n0):
      """
      Helper for the case q <= s.

      Re-selects the Taylor parameters for a single sub-interval of length
      ``t/q`` and then advances from each time point to the next with
      _expm_multiply_simple_core. ``X[0]`` must already be filled in.
      """
      if norm_info.onenorm() == 0:
          step_m, step_s = 0, 1
      else:
          # Temporarily describe (1/q) * t * A, i.e. one sub-interval.
          norm_info.set_scale(1./q)
          step_m, step_s = _fragment_3_1(norm_info, n0, tol, ell=ell)
          norm_info.set_scale(1)
      for dest in range(1, q + 1):
          X[dest] = _expm_multiply_simple_core(A, X[dest - 1], h, mu,
                                               step_m, step_s)
      return X, 0
  def _expm_multiply_interval_core_1(A, X, h, mu, m_star, s, q, tol):
      """
      Helper for the case q > s and q % s == 0.

      The q sub-intervals are grouped into s blocks of d = q // s steps. Within
      a block, the scaled powers ``K[p] = (h A)^p Z / p!`` of the block's start
      value Z are computed once and reused for every step k, since
      ``exp(k h A) Z = sum_p k^p K[p]``. ``X[0]`` must already be filled in.
      """
      d = q // s
      input_shape = X.shape[1:]
      K_shape = (m_star + 1, ) + input_shape
      K = np.empty(K_shape, dtype=X.dtype)
      for i in range(s):
          Z = X[i*d]
          K[0] = Z
          # K[1..high_p] have already been computed from this Z.
          high_p = 0
          for k in range(1, d+1):
              F = K[0]
              c1 = _exact_inf_norm(F)
              for p in range(1, m_star+1):
                  if p > high_p:
                      K[p] = h * A.dot(K[p-1]) / float(p)
                      # Record progress so later steps k reuse K[p] instead of
                      # repeating the same matrix-vector product.
                      high_p = p
                  coeff = float(pow(k, p))
                  F = F + coeff * K[p]
                  inf_norm_K_p_1 = _exact_inf_norm(K[p])
                  c2 = coeff * inf_norm_K_p_1
                  # Stop once two consecutive terms are negligible.
                  if c1 + c2 <= tol * _exact_inf_norm(F):
                      break
                  c1 = c2
              X[k + i*d] = np.exp(k*h*mu) * F
      return X, 1
  def _expm_multiply_interval_core_2(A, X, h, mu, m_star, s, q, tol):
      """
      Helper for the case q > s and q % s > 0.

      The q sub-intervals are split into j = q // d full blocks of d = q // s
      steps plus a final partial block of r = q - d*j steps. Within a block the
      scaled powers ``K[p] = (h A)^p Z / p!`` of the block's start value Z are
      computed on demand and reused for every step. ``X[0]`` must already be
      filled in.
      """
      d = q // s
      j = q // d
      r = q - d * j
      K = np.empty((m_star + 1,) + X.shape[1:], dtype=X.dtype)
      for i in range(j + 1):
          base = i * d
          K[0] = X[base]
          high_p = 0  # K[1..high_p] are valid for the current block
          steps = d if i < j else r
          for k in range(1, steps + 1):
              F = K[0]
              c1 = _exact_inf_norm(F)
              for p in range(1, m_star + 1):
                  if p == high_p + 1:
                      K[p] = h * A.dot(K[p - 1]) / float(p)
                      high_p = p
                  coeff = float(pow(k, p))
                  F = F + coeff * K[p]
                  c2 = coeff * _exact_inf_norm(K[p])
                  # Stop once two consecutive terms are negligible.
                  if c1 + c2 <= tol * _exact_inf_norm(F):
                      break
                  c1 = c2
              X[base + k] = np.exp(k * h * mu) * F
      return X, 2