test_banded_ode_solvers.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. import itertools
  2. import numpy as np
  3. from numpy.testing import assert_allclose
  4. from scipy.integrate import ode
  5. def _band_count(a):
  6. """Returns ml and mu, the lower and upper band sizes of a."""
  7. nrows, ncols = a.shape
  8. ml = 0
  9. for k in range(-nrows+1, 0):
  10. if np.diag(a, k).any():
  11. ml = -k
  12. break
  13. mu = 0
  14. for k in range(nrows-1, 0, -1):
  15. if np.diag(a, k).any():
  16. mu = k
  17. break
  18. return ml, mu
  19. def _linear_func(t, y, a):
  20. """Linear system dy/dt = a * y"""
  21. return a.dot(y)
  22. def _linear_jac(t, y, a):
  23. """Jacobian of a * y is a."""
  24. return a
  25. def _linear_banded_jac(t, y, a):
  26. """Banded Jacobian."""
  27. ml, mu = _band_count(a)
  28. bjac = [np.r_[[0] * k, np.diag(a, k)] for k in range(mu, 0, -1)]
  29. bjac.append(np.diag(a))
  30. for k in range(-1, -ml-1, -1):
  31. bjac.append(np.r_[np.diag(a, k), [0] * (-k)])
  32. return bjac
  33. def _solve_linear_sys(a, y0, tend=1, dt=0.1,
  34. solver=None, method='bdf', use_jac=True,
  35. with_jacobian=False, banded=False):
  36. """Use scipy.integrate.ode to solve a linear system of ODEs.
  37. a : square ndarray
  38. Matrix of the linear system to be solved.
  39. y0 : ndarray
  40. Initial condition
  41. tend : float
  42. Stop time.
  43. dt : float
  44. Step size of the output.
  45. solver : str
  46. If not None, this must be "vode", "lsoda" or "zvode".
  47. method : str
  48. Either "bdf" or "adams".
  49. use_jac : bool
  50. Determines if the jacobian function is passed to ode().
  51. with_jacobian : bool
  52. Passed to ode.set_integrator().
  53. banded : bool
  54. Determines whether a banded or full jacobian is used.
  55. If `banded` is True, `lband` and `uband` are determined by the
  56. values in `a`.
  57. """
  58. if banded:
  59. lband, uband = _band_count(a)
  60. else:
  61. lband = None
  62. uband = None
  63. if use_jac:
  64. if banded:
  65. r = ode(_linear_func, _linear_banded_jac)
  66. else:
  67. r = ode(_linear_func, _linear_jac)
  68. else:
  69. r = ode(_linear_func)
  70. if solver is None:
  71. if np.iscomplexobj(a):
  72. solver = "zvode"
  73. else:
  74. solver = "vode"
  75. r.set_integrator(solver,
  76. with_jacobian=with_jacobian,
  77. method=method,
  78. lband=lband, uband=uband,
  79. rtol=1e-9, atol=1e-10,
  80. )
  81. t0 = 0
  82. r.set_initial_value(y0, t0)
  83. r.set_f_params(a)
  84. r.set_jac_params(a)
  85. t = [t0]
  86. y = [y0]
  87. while r.successful() and r.t < tend:
  88. r.integrate(r.t + dt)
  89. t.append(r.t)
  90. y.append(r.y)
  91. t = np.array(t)
  92. y = np.array(y)
  93. return t, y
  94. def _analytical_solution(a, y0, t):
  95. """
  96. Analytical solution to the linear differential equations dy/dt = a*y.
  97. The solution is only valid if `a` is diagonalizable.
  98. Returns a 2-D array with shape (len(t), len(y0)).
  99. """
  100. lam, v = np.linalg.eig(a)
  101. c = np.linalg.solve(v, y0)
  102. e = c * np.exp(lam * t.reshape(-1, 1))
  103. sol = e.dot(v.T)
  104. return sol
  105. def test_banded_ode_solvers():
  106. # Test the "lsoda", "vode" and "zvode" solvers of the `ode` class
  107. # with a system that has a banded Jacobian matrix.
  108. t_exact = np.linspace(0, 1.0, 5)
  109. # --- Real arrays for testing the "lsoda" and "vode" solvers ---
  110. # lband = 2, uband = 1:
  111. a_real = np.array([[-0.6, 0.1, 0.0, 0.0, 0.0],
  112. [0.2, -0.5, 0.9, 0.0, 0.0],
  113. [0.1, 0.1, -0.4, 0.1, 0.0],
  114. [0.0, 0.3, -0.1, -0.9, -0.3],
  115. [0.0, 0.0, 0.1, 0.1, -0.7]])
  116. # lband = 0, uband = 1:
  117. a_real_upper = np.triu(a_real)
  118. # lband = 2, uband = 0:
  119. a_real_lower = np.tril(a_real)
  120. # lband = 0, uband = 0:
  121. a_real_diag = np.triu(a_real_lower)
  122. real_matrices = [a_real, a_real_upper, a_real_lower, a_real_diag]
  123. real_solutions = []
  124. for a in real_matrices:
  125. y0 = np.arange(1, a.shape[0] + 1)
  126. y_exact = _analytical_solution(a, y0, t_exact)
  127. real_solutions.append((y0, t_exact, y_exact))
  128. def check_real(idx, solver, meth, use_jac, with_jac, banded):
  129. a = real_matrices[idx]
  130. y0, t_exact, y_exact = real_solutions[idx]
  131. t, y = _solve_linear_sys(a, y0,
  132. tend=t_exact[-1],
  133. dt=t_exact[1] - t_exact[0],
  134. solver=solver,
  135. method=meth,
  136. use_jac=use_jac,
  137. with_jacobian=with_jac,
  138. banded=banded)
  139. assert_allclose(t, t_exact)
  140. assert_allclose(y, y_exact)
  141. for idx in range(len(real_matrices)):
  142. p = [['vode', 'lsoda'], # solver
  143. ['bdf', 'adams'], # method
  144. [False, True], # use_jac
  145. [False, True], # with_jacobian
  146. [False, True]] # banded
  147. for solver, meth, use_jac, with_jac, banded in itertools.product(*p):
  148. check_real(idx, solver, meth, use_jac, with_jac, banded)
  149. # --- Complex arrays for testing the "zvode" solver ---
  150. # complex, lband = 2, uband = 1:
  151. a_complex = a_real - 0.5j * a_real
  152. # complex, lband = 0, uband = 0:
  153. a_complex_diag = np.diag(np.diag(a_complex))
  154. complex_matrices = [a_complex, a_complex_diag]
  155. complex_solutions = []
  156. for a in complex_matrices:
  157. y0 = np.arange(1, a.shape[0] + 1) + 1j
  158. y_exact = _analytical_solution(a, y0, t_exact)
  159. complex_solutions.append((y0, t_exact, y_exact))
  160. def check_complex(idx, solver, meth, use_jac, with_jac, banded):
  161. a = complex_matrices[idx]
  162. y0, t_exact, y_exact = complex_solutions[idx]
  163. t, y = _solve_linear_sys(a, y0,
  164. tend=t_exact[-1],
  165. dt=t_exact[1] - t_exact[0],
  166. solver=solver,
  167. method=meth,
  168. use_jac=use_jac,
  169. with_jacobian=with_jac,
  170. banded=banded)
  171. assert_allclose(t, t_exact)
  172. assert_allclose(y, y_exact)
  173. for idx in range(len(complex_matrices)):
  174. p = [['bdf', 'adams'], # method
  175. [False, True], # use_jac
  176. [False, True], # with_jacobian
  177. [False, True]] # banded
  178. for meth, use_jac, with_jac, banded in itertools.product(*p):
  179. check_complex(idx, "zvode", meth, use_jac, with_jac, banded)