test_riccati.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877
  1. from sympy.core.random import randint
  2. from sympy.core.function import Function
  3. from sympy.core.mul import Mul
  4. from sympy.core.numbers import (I, Rational, oo)
  5. from sympy.core.relational import Eq
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import (Dummy, symbols)
  8. from sympy.functions.elementary.exponential import (exp, log)
  9. from sympy.functions.elementary.hyperbolic import tanh
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.trigonometric import sin
  12. from sympy.polys.polytools import Poly
  13. from sympy.simplify.ratsimp import ratsimp
  14. from sympy.solvers.ode.subscheck import checkodesol
  15. from sympy.testing.pytest import slow
  16. from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal,
  17. riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf,
  18. check_necessary_conds, val_at_inf, construct_c_case_1,
  19. construct_c_case_2, construct_c_case_3, construct_d_case_4,
  20. construct_d_case_5, construct_d_case_6, rational_laurent_series,
  21. solve_riccati)
# Undefined function; f(x) is the unknown function in the ODEs used below.
f = Function('f')
# Independent variable used by the expressions and polynomials in this module.
x = symbols('x')
# The following helper functions are used to generate test cases.
# THEY SHOULD NOT BE USED DIRECTLY IN TESTS.
  26. def rand_rational(maxint):
  27. return Rational(randint(-maxint, maxint), randint(1, maxint))
  28. def rand_poly(x, degree, maxint):
  29. return Poly([rand_rational(maxint) for _ in range(degree+1)], x)
  30. def rand_rational_function(x, degree, maxint):
  31. degnum = randint(1, degree)
  32. degden = randint(1, degree)
  33. num = rand_poly(x, degnum, maxint)
  34. den = rand_poly(x, degden, maxint)
  35. while den == Poly(0, x):
  36. den = rand_poly(x, degden, maxint)
  37. return num / den
  38. def find_riccati_ode(ratfunc, x, yf):
  39. y = ratfunc
  40. yp = y.diff(x)
  41. q1 = rand_rational_function(x, 1, 3)
  42. q2 = rand_rational_function(x, 1, 3)
  43. while q2 == 0:
  44. q2 = rand_rational_function(x, 1, 3)
  45. q0 = ratsimp(yp - q1*y - q2*y**2)
  46. eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2)
  47. sol = Eq(yf, y)
  48. assert checkodesol(eq, sol) == (True, 0)
  49. return eq, q0, q1, q2
  50. # Testing functions start
  51. def test_riccati_transformation():
  52. """
  53. This function tests the transformation of the
  54. solution of a Riccati ODE to the solution of
  55. its corresponding normal Riccati ODE.
  56. Each test case 4 values -
  57. 1. w - The solution to be transformed
  58. 2. b1 - The coefficient of f(x) in the ODE.
  59. 3. b2 - The coefficient of f(x)**2 in the ODE.
  60. 4. y - The solution to the normal Riccati ODE.
  61. """
  62. tests = [
  63. (
  64. x/(x - 1),
  65. (x**2 + 7)/3*x,
  66. x,
  67. -x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x)
  68. ),
  69. (
  70. (2*x + 3)/(2*x + 2),
  71. (3 - 3*x)/(x + 1),
  72. 5*x,
  73. -5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x)
  74. ),
  75. (
  76. -1/(2*x**2 - 1),
  77. 0,
  78. (2 - x)/(4*x - 2),
  79. (2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \
  80. 2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False))
  81. ),
  82. (
  83. x,
  84. (8*x - 12)/(12*x + 9),
  85. x**3/(6*x - 9),
  86. -x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \
  87. - 9)**2 + 3*x**2/(6*x - 9))/(2*x**3)
  88. )]
  89. for w, b1, b2, y in tests:
  90. assert y == riccati_normal(w, x, b1, b2)
  91. assert w == riccati_inverse_normal(y, x, b1, b2).cancel()
  92. # Test bp parameter in riccati_inverse_normal
  93. tests = [
  94. (
  95. (-2*x - 1)/(2*x**2 + 2*x - 2),
  96. -2/x,
  97. (-x - 1)/(4*x),
  98. 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
  99. -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \
  100. - 2)) + 1/x
  101. ),
  102. (
  103. 3/(2*x**2),
  104. -2/x,
  105. (-x - 1)/(4*x),
  106. 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1),
  107. -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3)
  108. )]
  109. for w, b1, b2, bp, y in tests:
  110. assert y == riccati_normal(w, x, b1, b2)
  111. assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel()
  112. def test_riccati_reduced():
  113. """
  114. This function tests the transformation of a
  115. Riccati ODE to its normal Riccati ODE.
  116. Each test case 2 values -
  117. 1. eq - A Riccati ODE.
  118. 2. normal_eq - The normal Riccati ODE of eq.
  119. """
  120. tests = [
  121. (
  122. f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2,
  123. f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2)
  124. ),
  125. (
  126. 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x,
  127. -3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \
  128. -x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \
  129. - (-1 + (x + 1)/x)/(x*(-x - 1))
  130. ),
  131. (
  132. f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2),
  133. -(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \
  134. f(x)**2 + f(x).diff(x) - 1/(2*x + 1)
  135. ),
  136. (
  137. f(x).diff(x) - f(x)**2/x,
  138. f(x)**2 + f(x).diff(x) + 1/(4*x**2)
  139. ),
  140. (
  141. -3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x,
  142. f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \
  143. + 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2)
  144. ),
  145. (
  146. 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x,
  147. False
  148. ),
  149. (
  150. f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2),
  151. False
  152. )]
  153. for eq, normal_eq in tests:
  154. assert normal_eq == riccati_reduced(eq, f, x)
  155. def test_match_riccati():
  156. """
  157. This function tests if an ODE is Riccati or not.
  158. Each test case has 5 values -
  159. 1. eq - The Riccati ODE.
  160. 2. match - Boolean indicating if eq is a Riccati ODE.
  161. 3. b0 -
  162. 4. b1 - Coefficient of f(x) in eq.
  163. 5. b2 - Coefficient of f(x)**2 in eq.
  164. """
  165. tests = [
  166. # Test Rational Riccati ODEs
  167. (
  168. f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \
  169. - 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \
  170. - (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2),
  171. True,
  172. 45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \
  173. (27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \
  174. 315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \
  175. + 846*x**2 + 180*x - 72),
  176. Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2),
  177. 1/(3*x + 1)
  178. ),
  179. (
  180. f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \
  181. 1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \
  182. /(324*x**3 + 216*x**2),
  183. True,
  184. -4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \
  185. + 24*x) + 265/(324*x + 216),
  186. Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6),
  187. x/3 - 1
  188. ),
  189. (
  190. f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \
  191. + 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \
  192. 360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \
  193. S(3)/2) - (x/3 - 3)*f(x)**2/(3*x),
  194. True,
  195. 304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \
  196. 108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \
  197. 360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \
  198. x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \
  199. 189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \
  200. 53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \
  201. + 53*x**3 - 63*x**2 + 40*x - 12),
  202. Mul(-1, 3 - 2*x, evaluate=False)/(x - 3),
  203. Mul(-1, 9 - x, evaluate=False)/(9*x)
  204. ),
  205. # Test Non-Rational Riccati ODEs
  206. (
  207. f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \
  208. x*f(x)**2/(x**(S(3)/4)),
  209. False, 0, 0, 0
  210. ),
  211. (
  212. f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2,
  213. False, 0, 0, 0
  214. ),
  215. (
  216. f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2,
  217. False, 0, 0, 0
  218. ),
  219. # Test Non-Riccati ODEs
  220. (
  221. (1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x),
  222. False, 0, 0, 0
  223. ),
  224. (
  225. f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3,
  226. False, 0, 0, 0
  227. ),
  228. (
  229. f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \
  230. + 3) + f(x)**2,
  231. False, 0, 0, 0
  232. )]
  233. for eq, res, b0, b1, b2 in tests:
  234. match, funcs = match_riccati(eq, f, x)
  235. assert match == res
  236. if res:
  237. assert [b0, b1, b2] == funcs
  238. def test_val_at_inf():
  239. """
  240. This function tests the valuation of rational
  241. function at oo.
  242. Each test case has 3 values -
  243. 1. num - Numerator of rational function.
  244. 2. den - Denominator of rational function.
  245. 3. val_inf - Valuation of rational function at oo
  246. """
  247. tests = [
  248. # degree(denom) > degree(numer)
  249. (
  250. Poly(10*x**3 + 8*x**2 - 13*x + 6, x),
  251. Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x),
  252. 7
  253. ),
  254. (
  255. Poly(1, x),
  256. Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x),
  257. 4
  258. ),
  259. # degree(denom) == degree(numer)
  260. (
  261. Poly(-6*x**3 - 8*x**2 + 8*x - 6, x),
  262. Poly(-5*x**3 + 12*x**2 - 6*x - 9, x),
  263. 0
  264. ),
  265. # degree(denom) < degree(numer)
  266. (
  267. Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x),
  268. Poly(-14*x**2 + x, x),
  269. -6
  270. ),
  271. (
  272. Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x),
  273. Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x),
  274. -2
  275. )]
  276. for num, den, val in tests:
  277. assert val_at_inf(num, den, x) == val
  278. def test_necessary_conds():
  279. """
  280. This function tests the necessary conditions for
  281. a Riccati ODE to have a rational particular solution.
  282. """
  283. # Valuation at Infinity is an odd negative integer
  284. assert check_necessary_conds(-3, [1, 2, 4]) == False
  285. # Valuation at Infinity is a positive integer lesser than 2
  286. assert check_necessary_conds(1, [1, 2, 4]) == False
  287. # Multiplicity of a pole is an odd integer greater than 1
  288. assert check_necessary_conds(2, [3, 1, 6]) == False
  289. # All values are correct
  290. assert check_necessary_conds(-10, [1, 2, 8, 12]) == True
  291. def test_inverse_transform_poly():
  292. """
  293. This function tests the substitution x -> 1/x
  294. in rational functions represented using Poly.
  295. """
  296. fns = [
  297. (15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6),
  298. (180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12),
  299. (-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80),
  300. (60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30),
  301. (30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \
  302. + 15*x**2 + 45*x - 15)
  303. ]
  304. for f in fns:
  305. num, den = [Poly(e, x) for e in f.as_numer_denom()]
  306. num, den = inverse_transform_poly(num, den, x)
  307. assert f.subs(x, 1/x).cancel() == num/den
  308. def test_limit_at_inf():
  309. """
  310. This function tests the limit at oo of a
  311. rational function.
  312. Each test case has 3 values -
  313. 1. num - Numerator of rational function.
  314. 2. den - Denominator of rational function.
  315. 3. limit_at_inf - Limit of rational function at oo
  316. """
  317. tests = [
  318. # deg(denom) > deg(numer)
  319. (
  320. Poly(-12*x**2 + 20*x + 32, x),
  321. Poly(32*x**3 + 72*x**2 + 3*x - 32, x),
  322. 0
  323. ),
  324. # deg(denom) < deg(numer)
  325. (
  326. Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x),
  327. Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x),
  328. oo
  329. ),
  330. # deg(denom) < deg(numer), one of the leading coefficients is negative
  331. (
  332. Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \
  333. + 630*x + 3780, x),
  334. Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x),
  335. -oo
  336. ),
  337. # deg(denom) == deg(numer)
  338. (
  339. Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x),
  340. Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x),
  341. S(1)/7
  342. ),
  343. (
  344. Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x),
  345. Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x),
  346. S(288)/607
  347. )]
  348. for num, den, lim in tests:
  349. assert limit_at_inf(num, den, x) == lim
  350. def test_construct_c_case_1():
  351. """
  352. This function tests the Case 1 in the step
  353. to calculate coefficients of c-vectors.
  354. Each test case has 4 values -
  355. 1. num - Numerator of the rational function a(x).
  356. 2. den - Denominator of the rational function a(x).
  357. 3. pole - Pole of a(x) for which c-vector is being
  358. calculated.
  359. 4. c - The c-vector for the pole.
  360. """
  361. tests = [
  362. (
  363. Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True),
  364. Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True),
  365. S(0),
  366. [[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]]
  367. ),
  368. (
  369. Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True),
  370. Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True),
  371. S(7)/4,
  372. [[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]]
  373. ),
  374. (
  375. Poly(4*x + 2, x, extension=True),
  376. Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \
  377. + 8*sqrt(3) + 16, x, domain='QQ<sqrt(3)>'),
  378. (S(1) + sqrt(3))/2,
  379. [[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \
  380. [S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]]
  381. )]
  382. for num, den, pole, c in tests:
  383. assert construct_c_case_1(num, den, x, pole) == c
  384. def test_construct_c_case_2():
  385. """
  386. This function tests the Case 2 in the step
  387. to calculate coefficients of c-vectors.
  388. Each test case has 5 values -
  389. 1. num - Numerator of the rational function a(x).
  390. 2. den - Denominator of the rational function a(x).
  391. 3. pole - Pole of a(x) for which c-vector is being
  392. calculated.
  393. 4. mul - The multiplicity of the pole.
  394. 5. c - The c-vector for the pole.
  395. """
  396. tests = [
  397. # Testing poles with multiplicity 2
  398. (
  399. Poly(1, x, extension=True),
  400. Poly((x - 1)**2*(x - 2), x, extension=True),
  401. 1, 2,
  402. [[-I*(-1 - I)/2], [I*(-1 + I)/2]]
  403. ),
  404. (
  405. Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True),
  406. Poly((3*x - 1)**2*(x + 2)**2, x, extension=True),
  407. S(1)/3, 2,
  408. [[-S(89)/98], [-S(9)/98]]
  409. ),
  410. # Testing poles with multiplicity 4
  411. (
  412. Poly(x**3 - x**2 + 4*x, x, extension=True),
  413. Poly((x - 2)**4*(x + 5)**2, x, extension=True),
  414. 2, 4,
  415. [[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \
  416. [-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]]
  417. ),
  418. (
  419. Poly(3*x**5 + x**4 + 3, x, extension=True),
  420. Poly((4*x + 1)**4*(x + 2), x, extension=True),
  421. -S(1)/4, 4,
  422. [[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \
  423. [-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]]
  424. ),
  425. # Testing poles with multiplicity 6
  426. (
  427. Poly(x**3 + 2, x, extension=True),
  428. Poly((3*x - 1)**6*(x**2 + 1), x, extension=True),
  429. S(1)/3, 6,
  430. [[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \
  431. [-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]]
  432. ),
  433. (
  434. Poly(x**2 + 12, x, extension=True),
  435. Poly((x - sqrt(2))**6, x, extension=True),
  436. sqrt(2), 6,
  437. [[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \
  438. [-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]]
  439. )]
  440. for num, den, pole, mul, c in tests:
  441. assert construct_c_case_2(num, den, x, pole, mul) == c
  442. def test_construct_c_case_3():
  443. """
  444. This function tests the Case 3 in the step
  445. to calculate coefficients of c-vectors.
  446. """
  447. assert construct_c_case_3() == [[1]]
  448. def test_construct_d_case_4():
  449. """
  450. This function tests the Case 4 in the step
  451. to calculate coefficients of the d-vector.
  452. Each test case has 4 values -
  453. 1. num - Numerator of the rational function a(x).
  454. 2. den - Denominator of the rational function a(x).
  455. 3. mul - Multiplicity of oo as a pole.
  456. 4. d - The d-vector.
  457. """
  458. tests = [
  459. # Tests with multiplicity at oo = 2
  460. (
  461. Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True),
  462. Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True),
  463. 2,
  464. [[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \
  465. [-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]]
  466. ),
  467. (
  468. Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True),
  469. Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True),
  470. 2,
  471. [[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]]
  472. ),
  473. # Tests with multiplicity at oo = 4
  474. (
  475. Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True),
  476. Poly(3*x**2 + 10*x + 7, x, extension=True),
  477. 4,
  478. [[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \
  479. - 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \
  480. sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]]
  481. ),
  482. (
  483. Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True),
  484. Poly(12*x - 2, x, extension=True),
  485. 4,
  486. [[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \
  487. [-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]]
  488. ),
  489. # Tests with multiplicity at oo = 4
  490. (
  491. Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True),
  492. Poly(x + 2, x, extension=True),
  493. 6,
  494. [[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]]
  495. ),
  496. (
  497. Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True),
  498. Poly(2*x - 2, x, extension=True),
  499. 6,
  500. [[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \
  501. 3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \
  502. sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]]
  503. )]
  504. for num, den, mul, d in tests:
  505. ser = rational_laurent_series(num, den, x, oo, mul, 1)
  506. assert construct_d_case_4(ser, mul//2) == d
  507. def test_construct_d_case_5():
  508. """
  509. This function tests the Case 5 in the step
  510. to calculate coefficients of the d-vector.
  511. Each test case has 3 values -
  512. 1. num - Numerator of the rational function a(x).
  513. 2. den - Denominator of the rational function a(x).
  514. 3. d - The d-vector.
  515. """
  516. tests = [
  517. (
  518. Poly(2*x**3 + x**2 + x - 2, x, extension=True),
  519. Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True),
  520. [[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]]
  521. ),
  522. (
  523. Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'),
  524. Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'),
  525. [[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]]
  526. ),
  527. (
  528. Poly(x**2 - x + 1, x, domain='ZZ'),
  529. Poly(3*x**2 + 7*x + 3, x, domain='ZZ'),
  530. [[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]]
  531. )]
  532. for num, den, d in tests:
  533. # Multiplicity of oo is 0
  534. ser = rational_laurent_series(num, den, x, oo, 0, 1)
  535. assert construct_d_case_5(ser) == d
  536. def test_construct_d_case_6():
  537. """
  538. This function tests the Case 6 in the step
  539. to calculate coefficients of the d-vector.
  540. Each test case has 3 values -
  541. 1. num - Numerator of the rational function a(x).
  542. 2. den - Denominator of the rational function a(x).
  543. 3. d - The d-vector.
  544. """
  545. tests = [
  546. (
  547. Poly(-2*x**2 - 5, x, domain='ZZ'),
  548. Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'),
  549. [[S(1)/2 + I/2], [S(1)/2 - I/2]]
  550. ),
  551. (
  552. Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'),
  553. Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'),
  554. [[1], [0]]
  555. ),
  556. (
  557. Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'),
  558. Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \
  559. - 27*x - 42, x, domain='ZZ'),
  560. [[1], [0]]
  561. )]
  562. for num, den, d in tests:
  563. assert construct_d_case_6(num, den, x) == d
  564. def test_rational_laurent_series():
  565. """
  566. This function tests the computation of coefficients
  567. of Laurent series of a rational function.
  568. Each test case has 5 values -
  569. 1. num - Numerator of the rational function.
  570. 2. den - Denominator of the rational function.
  571. 3. x0 - Point about which Laurent series is to
  572. be calculated.
  573. 4. mul - Multiplicity of x0 if x0 is a pole of
  574. the rational function (0 otherwise).
  575. 5. n - Number of terms upto which the series
  576. is to be calculated.
  577. """
  578. tests = [
  579. # Laurent series about simple pole (Multiplicity = 1)
  580. (
  581. Poly(x**2 - 3*x + 9, x, extension=True),
  582. Poly(x**2 - x, x, extension=True),
  583. S(1), 1, 6,
  584. {1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9}
  585. ),
  586. # Laurent series about multiple pole (Multiplicity > 1)
  587. (
  588. Poly(64*x**3 - 1728*x + 1216, x, extension=True),
  589. Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True),
  590. S(9)/8, 2, 3,
  591. {0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \
  592. 1: S(209149)/75645}
  593. ),
  594. (
  595. Poly(1, x, extension=True),
  596. Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \
  597. + (4 + 8*sqrt(2))*x - 4, x, extension=True),
  598. sqrt(2), 4, 6,
  599. {4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \
  600. + sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \
  601. )/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4}
  602. ),
  603. # Laurent series about oo
  604. (
  605. Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True),
  606. Poly(x**2 - 5, x, extension=True),
  607. oo, 3, 6,
  608. {3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17}
  609. ),
  610. # Laurent series at x0 where x0 is not a pole of the function
  611. # Using multiplicity as 0 (as x0 will not be a pole)
  612. (
  613. Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True),
  614. Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True),
  615. S(2)/5, 0, 1,
  616. {0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \
  617. -3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016}
  618. ),
  619. (
  620. Poly(-7*x**2 + 2*x - 4, x, extension=True),
  621. Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True),
  622. oo, 0, 6,
  623. {0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7}
  624. )]
  625. for num, den, x0, mul, n, ser in tests:
  626. assert ser == rational_laurent_series(num, den, x, x0, mul, n)
  627. def check_dummy_sol(eq, solse, dummy_sym):
  628. """
  629. Helper function to check if actual solution
  630. matches expected solution if actual solution
  631. contains dummy symbols.
  632. """
  633. if isinstance(eq, Eq):
  634. eq = eq.lhs - eq.rhs
  635. _, funcs = match_riccati(eq, f, x)
  636. sols = solve_riccati(f(x), x, *funcs)
  637. C1 = Dummy('C1')
  638. sols = [sol.subs(C1, dummy_sym) for sol in sols]
  639. assert all([x[0] for x in checkodesol(eq, sols)])
  640. assert all([s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse)])
  641. def test_solve_riccati():
  642. """
  643. This function tests the computation of rational
  644. particular solutions for a Riccati ODE.
  645. Each test case has 2 values -
  646. 1. eq - Riccati ODE to be solved.
  647. 2. sol - Expected solution to the equation.
  648. Some examples have been taken from the paper - "Statistical Investigation of
  649. First-Order Algebraic ODEs and their Rational General Solutions" by
  650. Georg Grasegger, N. Thieu Vo, Franz Winkler
  651. https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf
  652. """
  653. C0 = Dummy('C0')
  654. # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2,
  655. # a, b, c are rational functions of x
  656. tests = [
  657. # a(x) is a constant
  658. (
  659. Eq(f(x).diff(x) + f(x)**2 - 2, 0),
  660. [Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))]
  661. ),
  662. # a(x) is a constant
  663. (
  664. f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2,
  665. [Eq(f(x), (-2*C0 - x)/(C0*x + x**2))]
  666. ),
  667. # a(x) is a constant
  668. (
  669. 2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x),
  670. [Eq(f(x), (C0 + 2*x**2)/(C0 + x))]
  671. ),
  672. # Pole with multiplicity 1
  673. (
  674. Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)),
  675. [Eq(f(x), 1/(x**2 - x))]
  676. ),
  677. # One pole of multiplicity 2
  678. (
  679. x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x),
  680. [Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)]
  681. ),
  682. (
  683. x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x),
  684. [Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)]
  685. ),
  686. # Multiple poles of multiplicity 2
  687. (
  688. -f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \
  689. - 1)**2),
  690. [Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \
  691. - 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \
  692. 57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \
  693. - 3*x + 1))]
  694. ),
  695. # Regression: Poles with even multiplicity > 2 fixed
  696. (
  697. f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \
  698. 7*x**2 - 20*x + 4)/(4*x**4),
  699. [Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \
  700. - 2*x**2))]
  701. ),
  702. # Regression: Poles with even multiplicity > 2 fixed
  703. (
  704. Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\
  705. (x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \
  706. 792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)),
  707. [Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))]
  708. ),
  709. # More than 2 poles with multiplicity 2
  710. # Regression: Fixed mistake in necessary conditions
  711. (
  712. Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \
  713. (8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2),
  714. [Eq(f(x), (1 - 4*x)/(2*x - 2))]
  715. ),
  716. # Regression: Fixed mistake in necessary conditions
  717. (
  718. Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \
  719. + 3*f(x)**2/(6*x + 2)),
  720. [Eq(f(x), (2*x + 1)/(2*x - 2))]
  721. ),
  722. # Imaginary poles
  723. (
  724. f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \
  725. - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2),
  726. [Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \
  727. - x)), Eq(f(x), -1/(x - 1))],
  728. ),
  729. # Imaginary coefficients in equation
  730. (
  731. f(x).diff(x) - 2*I*(f(x)**2 + 1)/x,
  732. [Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)]
  733. ),
  734. # Regression: linsolve returning empty solution
  735. # Large value of m (> 10)
  736. (
  737. Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\
  738. (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)),
  739. [Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \
  740. 18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\
  741. x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\
  742. ((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \
  743. 415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\
  744. *x**3 + 32744250*x**2 + 16372125*x)))]
  745. ),
  746. # Regression: Fixed bug due to a typo in paper
  747. (
  748. Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6),
  749. [Eq(f(x), 6*x)]
  750. ),
  751. # Regression: Fixed bug due to a typo in paper
  752. (
  753. Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \
  754. + 9 + (1 - x)*f(x)/x + 3/x),
  755. [Eq(f(x), -3*x/2 - 3)]
  756. )]
  757. for eq, sol in tests:
  758. check_dummy_sol(eq, sol, C0)
  759. @slow
  760. def test_solve_riccati_slow():
  761. """
  762. This function tests the computation of rational
  763. particular solutions for a Riccati ODE.
  764. Each test case has 2 values -
  765. 1. eq - Riccati ODE to be solved.
  766. 2. sol - Expected solution to the equation.
  767. """
  768. C0 = Dummy('C0')
  769. tests = [
  770. # Very large values of m (989 and 991)
  771. (
  772. Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \
  773. (54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \
  774. 2187*x + 2187) + 495),
  775. [Eq(f(x), (18*x + 6)/(2*x - 9))]
  776. )]
  777. for eq, sol in tests:
  778. check_dummy_sol(eq, sol, C0)