# test_ivp.py (34 KB)
  1. from itertools import product
  2. from numpy.testing import (assert_, assert_allclose,
  3. assert_equal, assert_no_warnings, suppress_warnings)
  4. import pytest
  5. from pytest import raises as assert_raises
  6. import numpy as np
  7. from scipy.optimize._numdiff import group_columns
  8. from scipy.integrate import solve_ivp, RK23, RK45, DOP853, Radau, BDF, LSODA
  9. from scipy.integrate import OdeSolution
  10. from scipy.integrate._ivp.common import num_jac
  11. from scipy.integrate._ivp.base import ConstantDenseOutput
  12. from scipy.sparse import coo_matrix, csc_matrix
  13. def fun_zero(t, y):
  14. return np.zeros_like(y)
  15. def fun_linear(t, y):
  16. return np.array([-y[0] - 5 * y[1], y[0] + y[1]])
  17. def jac_linear():
  18. return np.array([[-1, -5], [1, 1]])
  19. def sol_linear(t):
  20. return np.vstack((-5 * np.sin(2 * t),
  21. 2 * np.cos(2 * t) + np.sin(2 * t)))
  22. def fun_rational(t, y):
  23. return np.array([y[1] / t,
  24. y[1] * (y[0] + 2 * y[1] - 1) / (t * (y[0] - 1))])
  25. def fun_rational_vectorized(t, y):
  26. return np.vstack((y[1] / t,
  27. y[1] * (y[0] + 2 * y[1] - 1) / (t * (y[0] - 1))))
  28. def jac_rational(t, y):
  29. return np.array([
  30. [0, 1 / t],
  31. [-2 * y[1] ** 2 / (t * (y[0] - 1) ** 2),
  32. (y[0] + 4 * y[1] - 1) / (t * (y[0] - 1))]
  33. ])
  34. def jac_rational_sparse(t, y):
  35. return csc_matrix([
  36. [0, 1 / t],
  37. [-2 * y[1] ** 2 / (t * (y[0] - 1) ** 2),
  38. (y[0] + 4 * y[1] - 1) / (t * (y[0] - 1))]
  39. ])
  40. def sol_rational(t):
  41. return np.asarray((t / (t + 10), 10 * t / (t + 10) ** 2))
def fun_medazko(t, y):
    """Right-hand side of a 2*n-component advection-diffusion-reaction system.

    NOTE(review): presumably the semi-discretized MEDAKZO ("Medical Akzo
    Nobel") benchmark from the IVP test set -- confirm against the reference.

    Even components y[0], y[2], ... get advection, diffusion and a reaction
    term. Odd components y[1], y[3], ... only get the reaction term.

    Parameters
    ----------
    t : float
        Time. The left boundary value is 2 for ``t <= 5`` and 0 afterwards.
    y : ndarray
        Interleaved state vector of length 2*n (assumed 1-D, since it is
        hstacked with scalars below).

    Returns
    -------
    f : ndarray, shape (2*n,)
        Time derivative of ``y``.
    """
    n = y.shape[0] // 2
    k = 100
    c = 4

    # Left boundary value for the even components, switched off after t = 5.
    phi = 2 if t <= 5 else 0
    # Pad so that padded index p holds the original y[p - 2]:
    # - index 0 is the left boundary value phi;
    # - index 1 is a placeholder that none of the indices below ever reads;
    # - the trailing copy of y[-2] (the last even component) closes the right
    #   boundary.
    y = np.hstack((phi, 0, y, y[-2]))

    # Grid spacing and the position-dependent coefficients.
    d = 1 / n
    j = np.arange(n) + 1
    alpha = 2 * (j * d - 1) ** 3 / c ** 2
    beta = (j * d - 1) ** 4 / c ** 2

    # Indices into the padded array. Each suffix names the matching 1-based
    # index in the unpadded vector, e.g. j_2_m1 -> y_{2j-1}, j_2 -> y_{2j}.
    j_2_p1 = 2 * j + 2
    j_2_m3 = 2 * j - 2
    j_2_m1 = 2 * j
    j_2 = 2 * j + 1

    f = np.empty(2 * n)
    # Even components: central-difference advection, plus second-difference
    # diffusion, minus reaction with the paired odd component.
    f[::2] = (alpha * (y[j_2_p1] - y[j_2_m3]) / (2 * d) +
              beta * (y[j_2_m3] - 2 * y[j_2_m1] + y[j_2_p1]) / d ** 2 -
              k * y[j_2_m1] * y[j_2])
    # Odd components: reaction term only.
    f[1::2] = -k * y[j_2] * y[j_2_m1]

    return f
  62. def medazko_sparsity(n):
  63. cols = []
  64. rows = []
  65. i = np.arange(n) * 2
  66. cols.append(i[1:])
  67. rows.append(i[1:] - 2)
  68. cols.append(i)
  69. rows.append(i)
  70. cols.append(i)
  71. rows.append(i + 1)
  72. cols.append(i[:-1])
  73. rows.append(i[:-1] + 2)
  74. i = np.arange(n) * 2 + 1
  75. cols.append(i)
  76. rows.append(i)
  77. cols.append(i)
  78. rows.append(i - 1)
  79. cols = np.hstack(cols)
  80. rows = np.hstack(rows)
  81. return coo_matrix((np.ones_like(cols), (cols, rows)))
  82. def fun_complex(t, y):
  83. return -y
  84. def jac_complex(t, y):
  85. return -np.eye(y.shape[0])
  86. def jac_complex_sparse(t, y):
  87. return csc_matrix(jac_complex(t, y))
  88. def sol_complex(t):
  89. y = (0.5 + 1j) * np.exp(-t)
  90. return y.reshape((1, -1))
  91. def compute_error(y, y_true, rtol, atol):
  92. e = (y - y_true) / (atol + rtol * np.abs(y_true))
  93. return np.linalg.norm(e, axis=0) / np.sqrt(e.shape[0])
  94. def test_integration():
  95. rtol = 1e-3
  96. atol = 1e-6
  97. y0 = [1/3, 2/9]
  98. for vectorized, method, t_span, jac in product(
  99. [False, True],
  100. ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA'],
  101. [[5, 9], [5, 1]],
  102. [None, jac_rational, jac_rational_sparse]):
  103. if vectorized:
  104. fun = fun_rational_vectorized
  105. else:
  106. fun = fun_rational
  107. with suppress_warnings() as sup:
  108. sup.filter(UserWarning,
  109. "The following arguments have no effect for a chosen "
  110. "solver: `jac`")
  111. res = solve_ivp(fun, t_span, y0, rtol=rtol,
  112. atol=atol, method=method, dense_output=True,
  113. jac=jac, vectorized=vectorized)
  114. assert_equal(res.t[0], t_span[0])
  115. assert_(res.t_events is None)
  116. assert_(res.y_events is None)
  117. assert_(res.success)
  118. assert_equal(res.status, 0)
  119. if method == 'DOP853':
  120. # DOP853 spends more functions evaluation because it doesn't
  121. # have enough time to develop big enough step size.
  122. assert_(res.nfev < 50)
  123. else:
  124. assert_(res.nfev < 40)
  125. if method in ['RK23', 'RK45', 'DOP853', 'LSODA']:
  126. assert_equal(res.njev, 0)
  127. assert_equal(res.nlu, 0)
  128. else:
  129. assert_(0 < res.njev < 3)
  130. assert_(0 < res.nlu < 10)
  131. y_true = sol_rational(res.t)
  132. e = compute_error(res.y, y_true, rtol, atol)
  133. assert_(np.all(e < 5))
  134. tc = np.linspace(*t_span)
  135. yc_true = sol_rational(tc)
  136. yc = res.sol(tc)
  137. e = compute_error(yc, yc_true, rtol, atol)
  138. assert_(np.all(e < 5))
  139. tc = (t_span[0] + t_span[-1]) / 2
  140. yc_true = sol_rational(tc)
  141. yc = res.sol(tc)
  142. e = compute_error(yc, yc_true, rtol, atol)
  143. assert_(np.all(e < 5))
  144. # LSODA for some reasons doesn't pass the polynomial through the
  145. # previous points exactly after the order change. It might be some
  146. # bug in LSOSA implementation or maybe we missing something.
  147. if method != 'LSODA':
  148. assert_allclose(res.sol(res.t), res.y, rtol=1e-15, atol=1e-15)
  149. def test_integration_complex():
  150. rtol = 1e-3
  151. atol = 1e-6
  152. y0 = [0.5 + 1j]
  153. t_span = [0, 1]
  154. tc = np.linspace(t_span[0], t_span[1])
  155. for method, jac in product(['RK23', 'RK45', 'DOP853', 'BDF'],
  156. [None, jac_complex, jac_complex_sparse]):
  157. with suppress_warnings() as sup:
  158. sup.filter(UserWarning,
  159. "The following arguments have no effect for a chosen "
  160. "solver: `jac`")
  161. res = solve_ivp(fun_complex, t_span, y0, method=method,
  162. dense_output=True, rtol=rtol, atol=atol, jac=jac)
  163. assert_equal(res.t[0], t_span[0])
  164. assert_(res.t_events is None)
  165. assert_(res.y_events is None)
  166. assert_(res.success)
  167. assert_equal(res.status, 0)
  168. if method == 'DOP853':
  169. assert res.nfev < 35
  170. else:
  171. assert res.nfev < 25
  172. if method == 'BDF':
  173. assert_equal(res.njev, 1)
  174. assert res.nlu < 6
  175. else:
  176. assert res.njev == 0
  177. assert res.nlu == 0
  178. y_true = sol_complex(res.t)
  179. e = compute_error(res.y, y_true, rtol, atol)
  180. assert np.all(e < 5)
  181. yc_true = sol_complex(tc)
  182. yc = res.sol(tc)
  183. e = compute_error(yc, yc_true, rtol, atol)
  184. assert np.all(e < 5)
  185. def test_integration_sparse_difference():
  186. n = 200
  187. t_span = [0, 20]
  188. y0 = np.zeros(2 * n)
  189. y0[1::2] = 1
  190. sparsity = medazko_sparsity(n)
  191. for method in ['BDF', 'Radau']:
  192. res = solve_ivp(fun_medazko, t_span, y0, method=method,
  193. jac_sparsity=sparsity)
  194. assert_equal(res.t[0], t_span[0])
  195. assert_(res.t_events is None)
  196. assert_(res.y_events is None)
  197. assert_(res.success)
  198. assert_equal(res.status, 0)
  199. assert_allclose(res.y[78, -1], 0.233994e-3, rtol=1e-2)
  200. assert_allclose(res.y[79, -1], 0, atol=1e-3)
  201. assert_allclose(res.y[148, -1], 0.359561e-3, rtol=1e-2)
  202. assert_allclose(res.y[149, -1], 0, atol=1e-3)
  203. assert_allclose(res.y[198, -1], 0.117374129e-3, rtol=1e-2)
  204. assert_allclose(res.y[199, -1], 0.6190807e-5, atol=1e-3)
  205. assert_allclose(res.y[238, -1], 0, atol=1e-3)
  206. assert_allclose(res.y[239, -1], 0.9999997, rtol=1e-2)
  207. def test_integration_const_jac():
  208. rtol = 1e-3
  209. atol = 1e-6
  210. y0 = [0, 2]
  211. t_span = [0, 2]
  212. J = jac_linear()
  213. J_sparse = csc_matrix(J)
  214. for method, jac in product(['Radau', 'BDF'], [J, J_sparse]):
  215. res = solve_ivp(fun_linear, t_span, y0, rtol=rtol, atol=atol,
  216. method=method, dense_output=True, jac=jac)
  217. assert_equal(res.t[0], t_span[0])
  218. assert_(res.t_events is None)
  219. assert_(res.y_events is None)
  220. assert_(res.success)
  221. assert_equal(res.status, 0)
  222. assert_(res.nfev < 100)
  223. assert_equal(res.njev, 0)
  224. assert_(0 < res.nlu < 15)
  225. y_true = sol_linear(res.t)
  226. e = compute_error(res.y, y_true, rtol, atol)
  227. assert_(np.all(e < 10))
  228. tc = np.linspace(*t_span)
  229. yc_true = sol_linear(tc)
  230. yc = res.sol(tc)
  231. e = compute_error(yc, yc_true, rtol, atol)
  232. assert_(np.all(e < 15))
  233. assert_allclose(res.sol(res.t), res.y, rtol=1e-14, atol=1e-14)
  234. @pytest.mark.slow
  235. @pytest.mark.parametrize('method', ['Radau', 'BDF', 'LSODA'])
  236. def test_integration_stiff(method):
  237. rtol = 1e-6
  238. atol = 1e-6
  239. y0 = [1e4, 0, 0]
  240. tspan = [0, 1e8]
  241. def fun_robertson(t, state):
  242. x, y, z = state
  243. return [
  244. -0.04 * x + 1e4 * y * z,
  245. 0.04 * x - 1e4 * y * z - 3e7 * y * y,
  246. 3e7 * y * y,
  247. ]
  248. res = solve_ivp(fun_robertson, tspan, y0, rtol=rtol,
  249. atol=atol, method=method)
  250. # If the stiff mode is not activated correctly, these numbers will be much bigger
  251. assert res.nfev < 5000
  252. assert res.njev < 200
def test_events():
    """Check event location, ``direction`` filtering and terminal events.

    Every solver integrates ``fun_rational`` forward over [5, 8] and then
    backward over [8, 5]. The event functions' ``direction`` attribute is
    mutated between calls and persists across loop iterations, so the order
    of the statements below matters.
    """
    # Expected to cross zero once in (5.3, 5.7) (see the assertions below).
    def event_rational_1(t, y):
        return y[0] - y[1] ** 0.7

    # Expected to cross zero once in (7.3, 7.7).
    def event_rational_2(t, y):
        return y[1] ** 0.6 - y[0]

    # Crosses zero at t = 7.4. Marked terminal, so it stops the integration.
    def event_rational_3(t, y):
        return t - 7.4

    event_rational_3.terminal = True

    for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
        # Direction unset (first iteration) or 0 (later ones): both events fire.
        res = solve_ivp(fun_rational, [5, 8], [1/3, 2/9], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 1)
        assert_equal(res.t_events[1].size, 1)
        assert_(5.3 < res.t_events[0][0] < 5.7)
        assert_(7.3 < res.t_events[1][0] < 7.7)

        assert_equal(res.y_events[0].shape, (1, 2))
        assert_equal(res.y_events[1].shape, (1, 2))
        # The reported event states should be (numerical) zeros of the events.
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)
        assert np.isclose(
            event_rational_2(res.t_events[1][0], res.y_events[1][0]), 0)

        # direction = 1 going forward: only event_rational_1 is expected.
        event_rational_1.direction = 1
        event_rational_2.direction = 1
        res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 1)
        assert_equal(res.t_events[1].size, 0)
        assert_(5.3 < res.t_events[0][0] < 5.7)
        assert_equal(res.y_events[0].shape, (1, 2))
        assert_equal(res.y_events[1].shape, (0,))
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)

        # direction = -1 going forward: only event_rational_2 is expected.
        event_rational_1.direction = -1
        event_rational_2.direction = -1
        res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 0)
        assert_equal(res.t_events[1].size, 1)
        assert_(7.3 < res.t_events[1][0] < 7.7)
        assert_equal(res.y_events[0].shape, (0,))
        assert_equal(res.y_events[1].shape, (1, 2))
        assert np.isclose(
            event_rational_2(res.t_events[1][0], res.y_events[1][0]), 0)

        # Terminal event at t = 7.4 stops integration (status 1) before
        # event_rational_2 (expected after 7.3) is recorded.
        event_rational_1.direction = 0
        event_rational_2.direction = 0

        res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method,
                        events=(event_rational_1, event_rational_2,
                                event_rational_3), dense_output=True)
        assert_equal(res.status, 1)
        assert_equal(res.t_events[0].size, 1)
        assert_equal(res.t_events[1].size, 0)
        assert_equal(res.t_events[2].size, 1)
        assert_(5.3 < res.t_events[0][0] < 5.7)
        assert_(7.3 < res.t_events[2][0] < 7.5)
        assert_equal(res.y_events[0].shape, (1, 2))
        assert_equal(res.y_events[1].shape, (0,))
        assert_equal(res.y_events[2].shape, (1, 2))
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)
        assert np.isclose(
            event_rational_3(res.t_events[2][0], res.y_events[2][0]), 0)

        # A single callable (not a tuple) is also accepted as `events`.
        res = solve_ivp(fun_rational, [5, 8], [1 / 3, 2 / 9], method=method,
                        events=event_rational_1, dense_output=True)
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 1)
        assert_(5.3 < res.t_events[0][0] < 5.7)

        assert_equal(res.y_events[0].shape, (1, 2))
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)

        # Also test that termination by event doesn't break interpolants.
        tc = np.linspace(res.t[0], res.t[-1])
        yc_true = sol_rational(tc)
        yc = res.sol(tc)
        e = compute_error(yc, yc_true, 1e-3, 1e-6)
        assert_(np.all(e < 5))

        # Test that the y_event matches the exact solution.
        assert np.allclose(sol_rational(res.t_events[0][0]), res.y_events[0][0], rtol=1e-3, atol=1e-6)

    # Test in backward direction. Reset the directions left over from the
    # forward loop.
    event_rational_1.direction = 0
    event_rational_2.direction = 0
    for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
        # Direction 0: both events are expected once each.
        res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 1)
        assert_equal(res.t_events[1].size, 1)
        assert_(5.3 < res.t_events[0][0] < 5.7)
        assert_(7.3 < res.t_events[1][0] < 7.7)

        assert_equal(res.y_events[0].shape, (1, 2))
        assert_equal(res.y_events[1].shape, (1, 2))
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)
        assert np.isclose(
            event_rational_2(res.t_events[1][0], res.y_events[1][0]), 0)

        # direction = -1 going backward: only event_rational_1 is expected.
        event_rational_1.direction = -1
        event_rational_2.direction = -1
        res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 1)
        assert_equal(res.t_events[1].size, 0)
        assert_(5.3 < res.t_events[0][0] < 5.7)

        assert_equal(res.y_events[0].shape, (1, 2))
        assert_equal(res.y_events[1].shape, (0,))
        assert np.isclose(
            event_rational_1(res.t_events[0][0], res.y_events[0][0]), 0)

        # direction = 1 going backward: only event_rational_2 is expected.
        event_rational_1.direction = 1
        event_rational_2.direction = 1
        res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method,
                        events=(event_rational_1, event_rational_2))
        assert_equal(res.status, 0)
        assert_equal(res.t_events[0].size, 0)
        assert_equal(res.t_events[1].size, 1)
        assert_(7.3 < res.t_events[1][0] < 7.7)

        assert_equal(res.y_events[0].shape, (0,))
        assert_equal(res.y_events[1].shape, (1, 2))
        assert np.isclose(
            event_rational_2(res.t_events[1][0], res.y_events[1][0]), 0)

        # Going backward from 8, event_rational_2 is recorded before the
        # terminal event at 7.4 stops integration; event_rational_1 is
        # never reached.
        event_rational_1.direction = 0
        event_rational_2.direction = 0

        res = solve_ivp(fun_rational, [8, 5], [4/9, 20/81], method=method,
                        events=(event_rational_1, event_rational_2,
                                event_rational_3), dense_output=True)
        assert_equal(res.status, 1)
        assert_equal(res.t_events[0].size, 0)
        assert_equal(res.t_events[1].size, 1)
        assert_equal(res.t_events[2].size, 1)
        assert_(7.3 < res.t_events[1][0] < 7.7)
        assert_(7.3 < res.t_events[2][0] < 7.5)

        assert_equal(res.y_events[0].shape, (0,))
        assert_equal(res.y_events[1].shape, (1, 2))
        assert_equal(res.y_events[2].shape, (1, 2))
        assert np.isclose(
            event_rational_2(res.t_events[1][0], res.y_events[1][0]), 0)
        assert np.isclose(
            event_rational_3(res.t_events[2][0], res.y_events[2][0]), 0)

        # Also test that termination by event doesn't break interpolants.
        tc = np.linspace(res.t[-1], res.t[0])
        yc_true = sol_rational(tc)
        yc = res.sol(tc)
        e = compute_error(yc, yc_true, 1e-3, 1e-6)
        assert_(np.all(e < 5))

        # The reported event states should match the exact solution.
        assert np.allclose(sol_rational(res.t_events[1][0]), res.y_events[1][0], rtol=1e-3, atol=1e-6)
        assert np.allclose(sol_rational(res.t_events[2][0]), res.y_events[2][0], rtol=1e-3, atol=1e-6)
  400. def test_max_step():
  401. rtol = 1e-3
  402. atol = 1e-6
  403. y0 = [1/3, 2/9]
  404. for method in [RK23, RK45, DOP853, Radau, BDF, LSODA]:
  405. for t_span in ([5, 9], [5, 1]):
  406. res = solve_ivp(fun_rational, t_span, y0, rtol=rtol,
  407. max_step=0.5, atol=atol, method=method,
  408. dense_output=True)
  409. assert_equal(res.t[0], t_span[0])
  410. assert_equal(res.t[-1], t_span[-1])
  411. assert_(np.all(np.abs(np.diff(res.t)) <= 0.5 + 1e-15))
  412. assert_(res.t_events is None)
  413. assert_(res.success)
  414. assert_equal(res.status, 0)
  415. y_true = sol_rational(res.t)
  416. e = compute_error(res.y, y_true, rtol, atol)
  417. assert_(np.all(e < 5))
  418. tc = np.linspace(*t_span)
  419. yc_true = sol_rational(tc)
  420. yc = res.sol(tc)
  421. e = compute_error(yc, yc_true, rtol, atol)
  422. assert_(np.all(e < 5))
  423. # See comment in test_integration.
  424. if method is not LSODA:
  425. assert_allclose(res.sol(res.t), res.y, rtol=1e-15, atol=1e-15)
  426. assert_raises(ValueError, method, fun_rational, t_span[0], y0,
  427. t_span[1], max_step=-1)
  428. if method is not LSODA:
  429. solver = method(fun_rational, t_span[0], y0, t_span[1],
  430. rtol=rtol, atol=atol, max_step=1e-20)
  431. message = solver.step()
  432. assert_equal(solver.status, 'failed')
  433. assert_("step size is less" in message)
  434. assert_raises(RuntimeError, solver.step)
  435. def test_first_step():
  436. rtol = 1e-3
  437. atol = 1e-6
  438. y0 = [1/3, 2/9]
  439. first_step = 0.1
  440. for method in [RK23, RK45, DOP853, Radau, BDF, LSODA]:
  441. for t_span in ([5, 9], [5, 1]):
  442. res = solve_ivp(fun_rational, t_span, y0, rtol=rtol,
  443. max_step=0.5, atol=atol, method=method,
  444. dense_output=True, first_step=first_step)
  445. assert_equal(res.t[0], t_span[0])
  446. assert_equal(res.t[-1], t_span[-1])
  447. assert_allclose(first_step, np.abs(res.t[1] - 5))
  448. assert_(res.t_events is None)
  449. assert_(res.success)
  450. assert_equal(res.status, 0)
  451. y_true = sol_rational(res.t)
  452. e = compute_error(res.y, y_true, rtol, atol)
  453. assert_(np.all(e < 5))
  454. tc = np.linspace(*t_span)
  455. yc_true = sol_rational(tc)
  456. yc = res.sol(tc)
  457. e = compute_error(yc, yc_true, rtol, atol)
  458. assert_(np.all(e < 5))
  459. # See comment in test_integration.
  460. if method is not LSODA:
  461. assert_allclose(res.sol(res.t), res.y, rtol=1e-15, atol=1e-15)
  462. assert_raises(ValueError, method, fun_rational, t_span[0], y0,
  463. t_span[1], first_step=-1)
  464. assert_raises(ValueError, method, fun_rational, t_span[0], y0,
  465. t_span[1], first_step=5)
  466. def test_t_eval():
  467. rtol = 1e-3
  468. atol = 1e-6
  469. y0 = [1/3, 2/9]
  470. for t_span in ([5, 9], [5, 1]):
  471. t_eval = np.linspace(t_span[0], t_span[1], 10)
  472. res = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
  473. t_eval=t_eval)
  474. assert_equal(res.t, t_eval)
  475. assert_(res.t_events is None)
  476. assert_(res.success)
  477. assert_equal(res.status, 0)
  478. y_true = sol_rational(res.t)
  479. e = compute_error(res.y, y_true, rtol, atol)
  480. assert_(np.all(e < 5))
  481. t_eval = [5, 5.01, 7, 8, 8.01, 9]
  482. res = solve_ivp(fun_rational, [5, 9], y0, rtol=rtol, atol=atol,
  483. t_eval=t_eval)
  484. assert_equal(res.t, t_eval)
  485. assert_(res.t_events is None)
  486. assert_(res.success)
  487. assert_equal(res.status, 0)
  488. y_true = sol_rational(res.t)
  489. e = compute_error(res.y, y_true, rtol, atol)
  490. assert_(np.all(e < 5))
  491. t_eval = [5, 4.99, 3, 1.5, 1.1, 1.01, 1]
  492. res = solve_ivp(fun_rational, [5, 1], y0, rtol=rtol, atol=atol,
  493. t_eval=t_eval)
  494. assert_equal(res.t, t_eval)
  495. assert_(res.t_events is None)
  496. assert_(res.success)
  497. assert_equal(res.status, 0)
  498. t_eval = [5.01, 7, 8, 8.01]
  499. res = solve_ivp(fun_rational, [5, 9], y0, rtol=rtol, atol=atol,
  500. t_eval=t_eval)
  501. assert_equal(res.t, t_eval)
  502. assert_(res.t_events is None)
  503. assert_(res.success)
  504. assert_equal(res.status, 0)
  505. y_true = sol_rational(res.t)
  506. e = compute_error(res.y, y_true, rtol, atol)
  507. assert_(np.all(e < 5))
  508. t_eval = [4.99, 3, 1.5, 1.1, 1.01]
  509. res = solve_ivp(fun_rational, [5, 1], y0, rtol=rtol, atol=atol,
  510. t_eval=t_eval)
  511. assert_equal(res.t, t_eval)
  512. assert_(res.t_events is None)
  513. assert_(res.success)
  514. assert_equal(res.status, 0)
  515. t_eval = [4, 6]
  516. assert_raises(ValueError, solve_ivp, fun_rational, [5, 9], y0,
  517. rtol=rtol, atol=atol, t_eval=t_eval)
  518. def test_t_eval_dense_output():
  519. rtol = 1e-3
  520. atol = 1e-6
  521. y0 = [1/3, 2/9]
  522. t_span = [5, 9]
  523. t_eval = np.linspace(t_span[0], t_span[1], 10)
  524. res = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
  525. t_eval=t_eval)
  526. res_d = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
  527. t_eval=t_eval, dense_output=True)
  528. assert_equal(res.t, t_eval)
  529. assert_(res.t_events is None)
  530. assert_(res.success)
  531. assert_equal(res.status, 0)
  532. assert_equal(res.t, res_d.t)
  533. assert_equal(res.y, res_d.y)
  534. assert_(res_d.t_events is None)
  535. assert_(res_d.success)
  536. assert_equal(res_d.status, 0)
  537. # if t and y are equal only test values for one case
  538. y_true = sol_rational(res.t)
  539. e = compute_error(res.y, y_true, rtol, atol)
  540. assert_(np.all(e < 5))
  541. def test_t_eval_early_event():
  542. def early_event(t, y):
  543. return t - 7
  544. early_event.terminal = True
  545. rtol = 1e-3
  546. atol = 1e-6
  547. y0 = [1/3, 2/9]
  548. t_span = [5, 9]
  549. t_eval = np.linspace(7.5, 9, 16)
  550. for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
  551. with suppress_warnings() as sup:
  552. sup.filter(UserWarning,
  553. "The following arguments have no effect for a chosen "
  554. "solver: `jac`")
  555. res = solve_ivp(fun_rational, t_span, y0, rtol=rtol, atol=atol,
  556. method=method, t_eval=t_eval, events=early_event,
  557. jac=jac_rational)
  558. assert res.success
  559. assert res.message == 'A termination event occurred.'
  560. assert res.status == 1
  561. assert not res.t and not res.y
  562. assert len(res.t_events) == 1
  563. assert res.t_events[0].size == 1
  564. assert res.t_events[0][0] == 7
  565. def test_no_integration():
  566. for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
  567. sol = solve_ivp(lambda t, y: -y, [4, 4], [2, 3],
  568. method=method, dense_output=True)
  569. assert_equal(sol.sol(4), [2, 3])
  570. assert_equal(sol.sol([4, 5, 6]), [[2, 2, 2], [3, 3, 3]])
  571. def test_no_integration_class():
  572. for method in [RK23, RK45, DOP853, Radau, BDF, LSODA]:
  573. solver = method(lambda t, y: -y, 0.0, [10.0, 0.0], 0.0)
  574. solver.step()
  575. assert_equal(solver.status, 'finished')
  576. sol = solver.dense_output()
  577. assert_equal(sol(0.0), [10.0, 0.0])
  578. assert_equal(sol([0, 1, 2]), [[10, 10, 10], [0, 0, 0]])
  579. solver = method(lambda t, y: -y, 0.0, [], np.inf)
  580. solver.step()
  581. assert_equal(solver.status, 'finished')
  582. sol = solver.dense_output()
  583. assert_equal(sol(100.0), [])
  584. assert_equal(sol([0, 1, 2]), np.empty((0, 3)))
  585. def test_empty():
  586. def fun(t, y):
  587. return np.zeros((0,))
  588. y0 = np.zeros((0,))
  589. for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
  590. sol = assert_no_warnings(solve_ivp, fun, [0, 10], y0,
  591. method=method, dense_output=True)
  592. assert_equal(sol.sol(10), np.zeros((0,)))
  593. assert_equal(sol.sol([1, 2, 3]), np.zeros((0, 3)))
  594. for method in ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA']:
  595. sol = assert_no_warnings(solve_ivp, fun, [0, np.inf], y0,
  596. method=method, dense_output=True)
  597. assert_equal(sol.sol(10), np.zeros((0,)))
  598. assert_equal(sol.sol([1, 2, 3]), np.zeros((0, 3)))
  599. def test_ConstantDenseOutput():
  600. sol = ConstantDenseOutput(0, 1, np.array([1, 2]))
  601. assert_allclose(sol(1.5), [1, 2])
  602. assert_allclose(sol([1, 1.5, 2]), [[1, 1, 1], [2, 2, 2]])
  603. sol = ConstantDenseOutput(0, 1, np.array([]))
  604. assert_allclose(sol(1.5), np.empty(0))
  605. assert_allclose(sol([1, 1.5, 2]), np.empty((0, 3)))
  606. def test_classes():
  607. y0 = [1 / 3, 2 / 9]
  608. for cls in [RK23, RK45, DOP853, Radau, BDF, LSODA]:
  609. solver = cls(fun_rational, 5, y0, np.inf)
  610. assert_equal(solver.n, 2)
  611. assert_equal(solver.status, 'running')
  612. assert_equal(solver.t_bound, np.inf)
  613. assert_equal(solver.direction, 1)
  614. assert_equal(solver.t, 5)
  615. assert_equal(solver.y, y0)
  616. assert_(solver.step_size is None)
  617. if cls is not LSODA:
  618. assert_(solver.nfev > 0)
  619. assert_(solver.njev >= 0)
  620. assert_equal(solver.nlu, 0)
  621. else:
  622. assert_equal(solver.nfev, 0)
  623. assert_equal(solver.njev, 0)
  624. assert_equal(solver.nlu, 0)
  625. assert_raises(RuntimeError, solver.dense_output)
  626. message = solver.step()
  627. assert_equal(solver.status, 'running')
  628. assert_equal(message, None)
  629. assert_equal(solver.n, 2)
  630. assert_equal(solver.t_bound, np.inf)
  631. assert_equal(solver.direction, 1)
  632. assert_(solver.t > 5)
  633. assert_(not np.all(np.equal(solver.y, y0)))
  634. assert_(solver.step_size > 0)
  635. assert_(solver.nfev > 0)
  636. assert_(solver.njev >= 0)
  637. assert_(solver.nlu >= 0)
  638. sol = solver.dense_output()
  639. assert_allclose(sol(5), y0, rtol=1e-15, atol=0)
  640. def test_OdeSolution():
  641. ts = np.array([0, 2, 5], dtype=float)
  642. s1 = ConstantDenseOutput(ts[0], ts[1], np.array([-1]))
  643. s2 = ConstantDenseOutput(ts[1], ts[2], np.array([1]))
  644. sol = OdeSolution(ts, [s1, s2])
  645. assert_equal(sol(-1), [-1])
  646. assert_equal(sol(1), [-1])
  647. assert_equal(sol(2), [-1])
  648. assert_equal(sol(3), [1])
  649. assert_equal(sol(5), [1])
  650. assert_equal(sol(6), [1])
  651. assert_equal(sol([0, 6, -2, 1.5, 4.5, 2.5, 5, 5.5, 2]),
  652. np.array([[-1, 1, -1, -1, 1, 1, 1, 1, -1]]))
  653. ts = np.array([10, 4, -3])
  654. s1 = ConstantDenseOutput(ts[0], ts[1], np.array([-1]))
  655. s2 = ConstantDenseOutput(ts[1], ts[2], np.array([1]))
  656. sol = OdeSolution(ts, [s1, s2])
  657. assert_equal(sol(11), [-1])
  658. assert_equal(sol(10), [-1])
  659. assert_equal(sol(5), [-1])
  660. assert_equal(sol(4), [-1])
  661. assert_equal(sol(0), [1])
  662. assert_equal(sol(-3), [1])
  663. assert_equal(sol(-4), [1])
  664. assert_equal(sol([12, -5, 10, -3, 6, 1, 4]),
  665. np.array([[-1, 1, -1, 1, -1, 1, -1]]))
  666. ts = np.array([1, 1])
  667. s = ConstantDenseOutput(1, 1, np.array([10]))
  668. sol = OdeSolution(ts, [s])
  669. assert_equal(sol(0), [10])
  670. assert_equal(sol(1), [10])
  671. assert_equal(sol(2), [10])
  672. assert_equal(sol([2, 1, 0]), np.array([[10, 10, 10]]))
  673. def test_num_jac():
  674. def fun(t, y):
  675. return np.vstack([
  676. -0.04 * y[0] + 1e4 * y[1] * y[2],
  677. 0.04 * y[0] - 1e4 * y[1] * y[2] - 3e7 * y[1] ** 2,
  678. 3e7 * y[1] ** 2
  679. ])
  680. def jac(t, y):
  681. return np.array([
  682. [-0.04, 1e4 * y[2], 1e4 * y[1]],
  683. [0.04, -1e4 * y[2] - 6e7 * y[1], -1e4 * y[1]],
  684. [0, 6e7 * y[1], 0]
  685. ])
  686. t = 1
  687. y = np.array([1, 0, 0])
  688. J_true = jac(t, y)
  689. threshold = 1e-5
  690. f = fun(t, y).ravel()
  691. J_num, factor = num_jac(fun, t, y, f, threshold, None)
  692. assert_allclose(J_num, J_true, rtol=1e-5, atol=1e-5)
  693. J_num, factor = num_jac(fun, t, y, f, threshold, factor)
  694. assert_allclose(J_num, J_true, rtol=1e-5, atol=1e-5)
  695. def test_num_jac_sparse():
  696. def fun(t, y):
  697. e = y[1:]**3 - y[:-1]**2
  698. z = np.zeros(y.shape[1])
  699. return np.vstack((z, 3 * e)) + np.vstack((2 * e, z))
  700. def structure(n):
  701. A = np.zeros((n, n), dtype=int)
  702. A[0, 0] = 1
  703. A[0, 1] = 1
  704. for i in range(1, n - 1):
  705. A[i, i - 1: i + 2] = 1
  706. A[-1, -1] = 1
  707. A[-1, -2] = 1
  708. return A
  709. np.random.seed(0)
  710. n = 20
  711. y = np.random.randn(n)
  712. A = structure(n)
  713. groups = group_columns(A)
  714. f = fun(0, y[:, None]).ravel()
  715. # Compare dense and sparse results, assuming that dense implementation
  716. # is correct (as it is straightforward).
  717. J_num_sparse, factor_sparse = num_jac(fun, 0, y.ravel(), f, 1e-8, None,
  718. sparsity=(A, groups))
  719. J_num_dense, factor_dense = num_jac(fun, 0, y.ravel(), f, 1e-8, None)
  720. assert_allclose(J_num_dense, J_num_sparse.toarray(),
  721. rtol=1e-12, atol=1e-14)
  722. assert_allclose(factor_dense, factor_sparse, rtol=1e-12, atol=1e-14)
  723. # Take small factors to trigger their recomputing inside.
  724. factor = np.random.uniform(0, 1e-12, size=n)
  725. J_num_sparse, factor_sparse = num_jac(fun, 0, y.ravel(), f, 1e-8, factor,
  726. sparsity=(A, groups))
  727. J_num_dense, factor_dense = num_jac(fun, 0, y.ravel(), f, 1e-8, factor)
  728. assert_allclose(J_num_dense, J_num_sparse.toarray(),
  729. rtol=1e-12, atol=1e-14)
  730. assert_allclose(factor_dense, factor_sparse, rtol=1e-12, atol=1e-14)
  731. def test_args():
  732. # sys3 is actually two decoupled systems. (x, y) form a
  733. # linear oscillator, while z is a nonlinear first order
  734. # system with equilibria at z=0 and z=1. If k > 0, z=1
  735. # is stable and z=0 is unstable.
  736. def sys3(t, w, omega, k, zfinal):
  737. x, y, z = w
  738. return [-omega*y, omega*x, k*z*(1 - z)]
  739. def sys3_jac(t, w, omega, k, zfinal):
  740. x, y, z = w
  741. J = np.array([[0, -omega, 0],
  742. [omega, 0, 0],
  743. [0, 0, k*(1 - 2*z)]])
  744. return J
  745. def sys3_x0decreasing(t, w, omega, k, zfinal):
  746. x, y, z = w
  747. return x
  748. def sys3_y0increasing(t, w, omega, k, zfinal):
  749. x, y, z = w
  750. return y
  751. def sys3_zfinal(t, w, omega, k, zfinal):
  752. x, y, z = w
  753. return z - zfinal
  754. # Set the event flags for the event functions.
  755. sys3_x0decreasing.direction = -1
  756. sys3_y0increasing.direction = 1
  757. sys3_zfinal.terminal = True
  758. omega = 2
  759. k = 4
  760. tfinal = 5
  761. zfinal = 0.99
  762. # Find z0 such that when z(0) = z0, z(tfinal) = zfinal.
  763. # The condition z(tfinal) = zfinal is the terminal event.
  764. z0 = np.exp(-k*tfinal)/((1 - zfinal)/zfinal + np.exp(-k*tfinal))
  765. w0 = [0, -1, z0]
  766. # Provide the jac argument and use the Radau method to ensure that the use
  767. # of the Jacobian function is exercised.
  768. # If event handling is working, the solution will stop at tfinal, not tend.
  769. tend = 2*tfinal
  770. sol = solve_ivp(sys3, [0, tend], w0,
  771. events=[sys3_x0decreasing, sys3_y0increasing, sys3_zfinal],
  772. dense_output=True, args=(omega, k, zfinal),
  773. method='Radau', jac=sys3_jac,
  774. rtol=1e-10, atol=1e-13)
  775. # Check that we got the expected events at the expected times.
  776. x0events_t = sol.t_events[0]
  777. y0events_t = sol.t_events[1]
  778. zfinalevents_t = sol.t_events[2]
  779. assert_allclose(x0events_t, [0.5*np.pi, 1.5*np.pi])
  780. assert_allclose(y0events_t, [0.25*np.pi, 1.25*np.pi])
  781. assert_allclose(zfinalevents_t, [tfinal])
  782. # Check that the solution agrees with the known exact solution.
  783. t = np.linspace(0, zfinalevents_t[0], 250)
  784. w = sol.sol(t)
  785. assert_allclose(w[0], np.sin(omega*t), rtol=1e-9, atol=1e-12)
  786. assert_allclose(w[1], -np.cos(omega*t), rtol=1e-9, atol=1e-12)
  787. assert_allclose(w[2], 1/(((1 - z0)/z0)*np.exp(-k*t) + 1),
  788. rtol=1e-9, atol=1e-12)
  789. # Check that the state variables have the expected values at the events.
  790. x0events = sol.sol(x0events_t)
  791. y0events = sol.sol(y0events_t)
  792. zfinalevents = sol.sol(zfinalevents_t)
  793. assert_allclose(x0events[0], np.zeros_like(x0events[0]), atol=5e-14)
  794. assert_allclose(x0events[1], np.ones_like(x0events[1]))
  795. assert_allclose(y0events[0], np.ones_like(y0events[0]))
  796. assert_allclose(y0events[1], np.zeros_like(y0events[1]), atol=5e-14)
  797. assert_allclose(zfinalevents[2], [zfinal])
  798. def test_array_rtol():
  799. # solve_ivp had a bug with array_like `rtol`; see gh-15482
  800. # check that it's fixed
  801. def f(t, y):
  802. return y[0], y[1]
  803. # no warning (or error) when `rtol` is array_like
  804. sol = solve_ivp(f, (0, 1), [1., 1.], rtol=[1e-1, 1e-1])
  805. err1 = np.abs(np.linalg.norm(sol.y[:, -1] - np.exp(1)))
  806. # warning when an element of `rtol` is too small
  807. with pytest.warns(UserWarning, match="At least one element..."):
  808. sol = solve_ivp(f, (0, 1), [1., 1.], rtol=[1e-1, 1e-16])
  809. err2 = np.abs(np.linalg.norm(sol.y[:, -1] - np.exp(1)))
  810. # tighter rtol improves the error
  811. assert err2 < err1
  812. @pytest.mark.parametrize('method', ['RK23', 'RK45', 'DOP853', 'Radau', 'BDF', 'LSODA'])
  813. def test_integration_zero_rhs(method):
  814. result = solve_ivp(fun_zero, [0, 10], np.ones(3), method=method)
  815. assert_(result.success)
  816. assert_equal(result.status, 0)
  817. assert_allclose(result.y, 1.0, rtol=1e-15)
  818. def test_args_single_value():
  819. def fun_with_arg(t, y, a):
  820. return a*y
  821. message = "Supplied 'args' cannot be unpacked."
  822. with pytest.raises(TypeError, match=message):
  823. solve_ivp(fun_with_arg, (0, 0.1), [1], args=-1)
  824. sol = solve_ivp(fun_with_arg, (0, 0.1), [1], args=(-1,))
  825. assert_allclose(sol.y[0, -1], np.exp(-0.1))