base.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import numpy as np
  2. def check_arguments(fun, y0, support_complex):
  3. """Helper function for checking arguments common to all solvers."""
  4. y0 = np.asarray(y0)
  5. if np.issubdtype(y0.dtype, np.complexfloating):
  6. if not support_complex:
  7. raise ValueError("`y0` is complex, but the chosen solver does "
  8. "not support integration in a complex domain.")
  9. dtype = complex
  10. else:
  11. dtype = float
  12. y0 = y0.astype(dtype, copy=False)
  13. if y0.ndim != 1:
  14. raise ValueError("`y0` must be 1-dimensional.")
  15. def fun_wrapped(t, y):
  16. return np.asarray(fun(t, y), dtype=dtype)
  17. return fun_wrapped, y0
  18. class OdeSolver:
  19. """Base class for ODE solvers.
  20. In order to implement a new solver you need to follow the guidelines:
  21. 1. A constructor must accept parameters presented in the base class
  22. (listed below) along with any other parameters specific to a solver.
  23. 2. A constructor must accept arbitrary extraneous arguments
  24. ``**extraneous``, but warn that these arguments are irrelevant
  25. using `common.warn_extraneous` function. Do not pass these
  26. arguments to the base class.
  27. 3. A solver must implement a private method `_step_impl(self)` which
  28. propagates a solver one step further. It must return tuple
  29. ``(success, message)``, where ``success`` is a boolean indicating
  30. whether a step was successful, and ``message`` is a string
  31. containing description of a failure if a step failed or None
  32. otherwise.
  33. 4. A solver must implement a private method `_dense_output_impl(self)`,
  34. which returns a `DenseOutput` object covering the last successful
  35. step.
  36. 5. A solver must have attributes listed below in Attributes section.
  37. Note that ``t_old`` and ``step_size`` are updated automatically.
  38. 6. Use `fun(self, t, y)` method for the system rhs evaluation, this
  39. way the number of function evaluations (`nfev`) will be tracked
  40. automatically.
  41. 7. For convenience, a base class provides `fun_single(self, t, y)` and
  42. `fun_vectorized(self, t, y)` for evaluating the rhs in
  43. non-vectorized and vectorized fashions respectively (regardless of
  44. how `fun` from the constructor is implemented). These calls don't
  45. increment `nfev`.
  46. 8. If a solver uses a Jacobian matrix and LU decompositions, it should
  47. track the number of Jacobian evaluations (`njev`) and the number of
  48. LU decompositions (`nlu`).
  49. 9. By convention, the function evaluations used to compute a finite
  50. difference approximation of the Jacobian should not be counted in
  51. `nfev`, thus use `fun_single(self, t, y)` or
  52. `fun_vectorized(self, t, y)` when computing a finite difference
  53. approximation of the Jacobian.
  54. Parameters
  55. ----------
  56. fun : callable
  57. Right-hand side of the system. The calling signature is ``fun(t, y)``.
  58. Here ``t`` is a scalar and there are two options for ndarray ``y``.
  59. It can either have shape (n,), then ``fun`` must return array_like with
  60. shape (n,). Or, alternatively, it can have shape (n, n_points), then
  61. ``fun`` must return array_like with shape (n, n_points) (each column
  62. corresponds to a single column in ``y``). The choice between the two
  63. options is determined by `vectorized` argument (see below).
  64. t0 : float
  65. Initial time.
  66. y0 : array_like, shape (n,)
  67. Initial state.
  68. t_bound : float
  69. Boundary time --- the integration won't continue beyond it. It also
  70. determines the direction of the integration.
  71. vectorized : bool
  72. Whether `fun` is implemented in a vectorized fashion.
  73. support_complex : bool, optional
  74. Whether integration in a complex domain should be supported.
  75. Generally determined by a derived solver class capabilities.
  76. Default is False.
  77. Attributes
  78. ----------
  79. n : int
  80. Number of equations.
  81. status : string
  82. Current status of the solver: 'running', 'finished' or 'failed'.
  83. t_bound : float
  84. Boundary time.
  85. direction : float
  86. Integration direction: +1 or -1.
  87. t : float
  88. Current time.
  89. y : ndarray
  90. Current state.
  91. t_old : float
  92. Previous time. None if no steps were made yet.
  93. step_size : float
  94. Size of the last successful step. None if no steps were made yet.
  95. nfev : int
  96. Number of the system's rhs evaluations.
  97. njev : int
  98. Number of the Jacobian evaluations.
  99. nlu : int
  100. Number of LU decompositions.
  101. """
  102. TOO_SMALL_STEP = "Required step size is less than spacing between numbers."
  103. def __init__(self, fun, t0, y0, t_bound, vectorized,
  104. support_complex=False):
  105. self.t_old = None
  106. self.t = t0
  107. self._fun, self.y = check_arguments(fun, y0, support_complex)
  108. self.t_bound = t_bound
  109. self.vectorized = vectorized
  110. if vectorized:
  111. def fun_single(t, y):
  112. return self._fun(t, y[:, None]).ravel()
  113. fun_vectorized = self._fun
  114. else:
  115. fun_single = self._fun
  116. def fun_vectorized(t, y):
  117. f = np.empty_like(y)
  118. for i, yi in enumerate(y.T):
  119. f[:, i] = self._fun(t, yi)
  120. return f
  121. def fun(t, y):
  122. self.nfev += 1
  123. return self.fun_single(t, y)
  124. self.fun = fun
  125. self.fun_single = fun_single
  126. self.fun_vectorized = fun_vectorized
  127. self.direction = np.sign(t_bound - t0) if t_bound != t0 else 1
  128. self.n = self.y.size
  129. self.status = 'running'
  130. self.nfev = 0
  131. self.njev = 0
  132. self.nlu = 0
  133. @property
  134. def step_size(self):
  135. if self.t_old is None:
  136. return None
  137. else:
  138. return np.abs(self.t - self.t_old)
  139. def step(self):
  140. """Perform one integration step.
  141. Returns
  142. -------
  143. message : string or None
  144. Report from the solver. Typically a reason for a failure if
  145. `self.status` is 'failed' after the step was taken or None
  146. otherwise.
  147. """
  148. if self.status != 'running':
  149. raise RuntimeError("Attempt to step on a failed or finished "
  150. "solver.")
  151. if self.n == 0 or self.t == self.t_bound:
  152. # Handle corner cases of empty solver or no integration.
  153. self.t_old = self.t
  154. self.t = self.t_bound
  155. message = None
  156. self.status = 'finished'
  157. else:
  158. t = self.t
  159. success, message = self._step_impl()
  160. if not success:
  161. self.status = 'failed'
  162. else:
  163. self.t_old = t
  164. if self.direction * (self.t - self.t_bound) >= 0:
  165. self.status = 'finished'
  166. return message
  167. def dense_output(self):
  168. """Compute a local interpolant over the last successful step.
  169. Returns
  170. -------
  171. sol : `DenseOutput`
  172. Local interpolant over the last successful step.
  173. """
  174. if self.t_old is None:
  175. raise RuntimeError("Dense output is available after a successful "
  176. "step was made.")
  177. if self.n == 0 or self.t == self.t_old:
  178. # Handle corner cases of empty solver and no integration.
  179. return ConstantDenseOutput(self.t_old, self.t, self.y)
  180. else:
  181. return self._dense_output_impl()
  182. def _step_impl(self):
  183. raise NotImplementedError
  184. def _dense_output_impl(self):
  185. raise NotImplementedError
  186. class DenseOutput:
  187. """Base class for local interpolant over step made by an ODE solver.
  188. It interpolates between `t_min` and `t_max` (see Attributes below).
  189. Evaluation outside this interval is not forbidden, but the accuracy is not
  190. guaranteed.
  191. Attributes
  192. ----------
  193. t_min, t_max : float
  194. Time range of the interpolation.
  195. """
  196. def __init__(self, t_old, t):
  197. self.t_old = t_old
  198. self.t = t
  199. self.t_min = min(t, t_old)
  200. self.t_max = max(t, t_old)
  201. def __call__(self, t):
  202. """Evaluate the interpolant.
  203. Parameters
  204. ----------
  205. t : float or array_like with shape (n_points,)
  206. Points to evaluate the solution at.
  207. Returns
  208. -------
  209. y : ndarray, shape (n,) or (n, n_points)
  210. Computed values. Shape depends on whether `t` was a scalar or a
  211. 1-D array.
  212. """
  213. t = np.asarray(t)
  214. if t.ndim > 1:
  215. raise ValueError("`t` must be a float or a 1-D array.")
  216. return self._call_impl(t)
  217. def _call_impl(self, t):
  218. raise NotImplementedError
  219. class ConstantDenseOutput(DenseOutput):
  220. """Constant value interpolator.
  221. This class used for degenerate integration cases: equal integration limits
  222. or a system with 0 equations.
  223. """
  224. def __init__(self, t_old, t, value):
  225. super().__init__(t_old, t)
  226. self.value = value
  227. def _call_impl(self, t):
  228. if t.ndim == 0:
  229. return self.value
  230. else:
  231. ret = np.empty((self.value.shape[0], t.shape[0]))
  232. ret[:] = self.value[:, None]
  233. return ret