# test_match.py (22 KB)

  1. from sympy import abc
  2. from sympy.concrete.summations import Sum
  3. from sympy.core.add import Add
  4. from sympy.core.function import (Derivative, Function, diff)
  5. from sympy.core.mul import Mul
  6. from sympy.core.numbers import (Float, I, Integer, Rational, oo, pi)
  7. from sympy.core.singleton import S
  8. from sympy.core.symbol import (Symbol, Wild, symbols)
  9. from sympy.functions.elementary.exponential import (exp, log)
  10. from sympy.functions.elementary.miscellaneous import sqrt
  11. from sympy.functions.elementary.trigonometric import (cos, sin)
  12. from sympy.functions.special.hyper import meijerg
  13. from sympy.polys.polytools import Poly
  14. from sympy.simplify.radsimp import collect
  15. from sympy.simplify.simplify import signsimp
  16. from sympy.testing.pytest import XFAIL
  17. def test_symbol():
  18. x = Symbol('x')
  19. a, b, c, p, q = map(Wild, 'abcpq')
  20. e = x
  21. assert e.match(x) == {}
  22. assert e.matches(x) == {}
  23. assert e.match(a) == {a: x}
  24. e = Rational(5)
  25. assert e.match(c) == {c: 5}
  26. assert e.match(e) == {}
  27. assert e.match(e + 1) is None
  28. def test_add():
  29. x, y, a, b, c = map(Symbol, 'xyabc')
  30. p, q, r = map(Wild, 'pqr')
  31. e = a + b
  32. assert e.match(p + b) == {p: a}
  33. assert e.match(p + a) == {p: b}
  34. e = 1 + b
  35. assert e.match(p + b) == {p: 1}
  36. e = a + b + c
  37. assert e.match(a + p + c) == {p: b}
  38. assert e.match(b + p + c) == {p: a}
  39. e = a + b + c + x
  40. assert e.match(a + p + x + c) == {p: b}
  41. assert e.match(b + p + c + x) == {p: a}
  42. assert e.match(b) is None
  43. assert e.match(b + p) == {p: a + c + x}
  44. assert e.match(a + p + c) == {p: b + x}
  45. assert e.match(b + p + c) == {p: a + x}
  46. e = 4*x + 5
  47. assert e.match(4*x + p) == {p: 5}
  48. assert e.match(3*x + p) == {p: x + 5}
  49. assert e.match(p*x + 5) == {p: 4}
  50. def test_power():
  51. x, y, a, b, c = map(Symbol, 'xyabc')
  52. p, q, r = map(Wild, 'pqr')
  53. e = (x + y)**a
  54. assert e.match(p**q) == {p: x + y, q: a}
  55. assert e.match(p**p) is None
  56. e = (x + y)**(x + y)
  57. assert e.match(p**p) == {p: x + y}
  58. assert e.match(p**q) == {p: x + y, q: x + y}
  59. e = (2*x)**2
  60. assert e.match(p*q**r) == {p: 4, q: x, r: 2}
  61. e = Integer(1)
  62. assert e.match(x**p) == {p: 0}
  63. def test_match_exclude():
  64. x = Symbol('x')
  65. y = Symbol('y')
  66. p = Wild("p")
  67. q = Wild("q")
  68. r = Wild("r")
  69. e = Rational(6)
  70. assert e.match(2*p) == {p: 3}
  71. e = 3/(4*x + 5)
  72. assert e.match(3/(p*x + q)) == {p: 4, q: 5}
  73. e = 3/(4*x + 5)
  74. assert e.match(p/(q*x + r)) == {p: 3, q: 4, r: 5}
  75. e = 2/(x + 1)
  76. assert e.match(p/(q*x + r)) == {p: 2, q: 1, r: 1}
  77. e = 1/(x + 1)
  78. assert e.match(p/(q*x + r)) == {p: 1, q: 1, r: 1}
  79. e = 4*x + 5
  80. assert e.match(p*x + q) == {p: 4, q: 5}
  81. e = 4*x + 5*y + 6
  82. assert e.match(p*x + q*y + r) == {p: 4, q: 5, r: 6}
  83. a = Wild('a', exclude=[x])
  84. e = 3*x
  85. assert e.match(p*x) == {p: 3}
  86. assert e.match(a*x) == {a: 3}
  87. e = 3*x**2
  88. assert e.match(p*x) == {p: 3*x}
  89. assert e.match(a*x) is None
  90. e = 3*x + 3 + 6/x
  91. assert e.match(p*x**2 + p*x + 2*p) == {p: 3/x}
  92. assert e.match(a*x**2 + a*x + 2*a) is None
  93. def test_mul():
  94. x, y, a, b, c = map(Symbol, 'xyabc')
  95. p, q = map(Wild, 'pq')
  96. e = 4*x
  97. assert e.match(p*x) == {p: 4}
  98. assert e.match(p*y) is None
  99. assert e.match(e + p*y) == {p: 0}
  100. e = a*x*b*c
  101. assert e.match(p*x) == {p: a*b*c}
  102. assert e.match(c*p*x) == {p: a*b}
  103. e = (a + b)*(a + c)
  104. assert e.match((p + b)*(p + c)) == {p: a}
  105. e = x
  106. assert e.match(p*x) == {p: 1}
  107. e = exp(x)
  108. assert e.match(x**p*exp(x*q)) == {p: 0, q: 1}
  109. e = I*Poly(x, x)
  110. assert e.match(I*p) == {p: x}
  111. def test_mul_noncommutative():
  112. x, y = symbols('x y')
  113. A, B, C = symbols('A B C', commutative=False)
  114. u, v = symbols('u v', cls=Wild)
  115. w, z = symbols('w z', cls=Wild, commutative=False)
  116. assert (u*v).matches(x) in ({v: x, u: 1}, {u: x, v: 1})
  117. assert (u*v).matches(x*y) in ({v: y, u: x}, {u: y, v: x})
  118. assert (u*v).matches(A) is None
  119. assert (u*v).matches(A*B) is None
  120. assert (u*v).matches(x*A) is None
  121. assert (u*v).matches(x*y*A) is None
  122. assert (u*v).matches(x*A*B) is None
  123. assert (u*v).matches(x*y*A*B) is None
  124. assert (v*w).matches(x) is None
  125. assert (v*w).matches(x*y) is None
  126. assert (v*w).matches(A) == {w: A, v: 1}
  127. assert (v*w).matches(A*B) == {w: A*B, v: 1}
  128. assert (v*w).matches(x*A) == {w: A, v: x}
  129. assert (v*w).matches(x*y*A) == {w: A, v: x*y}
  130. assert (v*w).matches(x*A*B) == {w: A*B, v: x}
  131. assert (v*w).matches(x*y*A*B) == {w: A*B, v: x*y}
  132. assert (v*w).matches(-x) is None
  133. assert (v*w).matches(-x*y) is None
  134. assert (v*w).matches(-A) == {w: A, v: -1}
  135. assert (v*w).matches(-A*B) == {w: A*B, v: -1}
  136. assert (v*w).matches(-x*A) == {w: A, v: -x}
  137. assert (v*w).matches(-x*y*A) == {w: A, v: -x*y}
  138. assert (v*w).matches(-x*A*B) == {w: A*B, v: -x}
  139. assert (v*w).matches(-x*y*A*B) == {w: A*B, v: -x*y}
  140. assert (w*z).matches(x) is None
  141. assert (w*z).matches(x*y) is None
  142. assert (w*z).matches(A) is None
  143. assert (w*z).matches(A*B) == {w: A, z: B}
  144. assert (w*z).matches(B*A) == {w: B, z: A}
  145. assert (w*z).matches(A*B*C) in [{w: A, z: B*C}, {w: A*B, z: C}]
  146. assert (w*z).matches(x*A) is None
  147. assert (w*z).matches(x*y*A) is None
  148. assert (w*z).matches(x*A*B) is None
  149. assert (w*z).matches(x*y*A*B) is None
  150. assert (w*A).matches(A) is None
  151. assert (A*w*B).matches(A*B) is None
  152. assert (u*w*z).matches(x) is None
  153. assert (u*w*z).matches(x*y) is None
  154. assert (u*w*z).matches(A) is None
  155. assert (u*w*z).matches(A*B) == {u: 1, w: A, z: B}
  156. assert (u*w*z).matches(B*A) == {u: 1, w: B, z: A}
  157. assert (u*w*z).matches(x*A) is None
  158. assert (u*w*z).matches(x*y*A) is None
  159. assert (u*w*z).matches(x*A*B) == {u: x, w: A, z: B}
  160. assert (u*w*z).matches(x*B*A) == {u: x, w: B, z: A}
  161. assert (u*w*z).matches(x*y*A*B) == {u: x*y, w: A, z: B}
  162. assert (u*w*z).matches(x*y*B*A) == {u: x*y, w: B, z: A}
  163. assert (u*A).matches(x*A) == {u: x}
  164. assert (u*A).matches(x*A*B) is None
  165. assert (u*B).matches(x*A) is None
  166. assert (u*A*B).matches(x*A*B) == {u: x}
  167. assert (u*A*B).matches(x*B*A) is None
  168. assert (u*A*B).matches(x*A) is None
  169. assert (u*w*A).matches(x*A*B) is None
  170. assert (u*w*B).matches(x*A*B) == {u: x, w: A}
  171. assert (u*v*A*B).matches(x*A*B) in [{u: x, v: 1}, {v: x, u: 1}]
  172. assert (u*v*A*B).matches(x*B*A) is None
  173. assert (u*v*A*B).matches(u*v*A*C) is None
  174. def test_mul_noncommutative_mismatch():
  175. A, B, C = symbols('A B C', commutative=False)
  176. w = symbols('w', cls=Wild, commutative=False)
  177. assert (w*B*w).matches(A*B*A) == {w: A}
  178. assert (w*B*w).matches(A*C*B*A*C) == {w: A*C}
  179. assert (w*B*w).matches(A*C*B*A*B) is None
  180. assert (w*B*w).matches(A*B*C) is None
  181. assert (w*w*C).matches(A*B*C) is None
  182. def test_mul_noncommutative_pow():
  183. A, B, C = symbols('A B C', commutative=False)
  184. w = symbols('w', cls=Wild, commutative=False)
  185. assert (A*B*w).matches(A*B**2) == {w: B}
  186. assert (A*(B**2)*w*(B**3)).matches(A*B**8) == {w: B**3}
  187. assert (A*B*w*C).matches(A*(B**4)*C) == {w: B**3}
  188. assert (A*B*(w**(-1))).matches(A*B*(C**(-1))) == {w: C}
  189. assert (A*(B*w)**(-1)*C).matches(A*(B*C)**(-1)*C) == {w: C}
  190. assert ((w**2)*B*C).matches((A**2)*B*C) == {w: A}
  191. assert ((w**2)*B*(w**3)).matches((A**2)*B*(A**3)) == {w: A}
  192. assert ((w**2)*B*(w**4)).matches((A**2)*B*(A**2)) is None
  193. def test_complex():
  194. a, b, c = map(Symbol, 'abc')
  195. x, y = map(Wild, 'xy')
  196. assert (1 + I).match(x + I) == {x: 1}
  197. assert (a + I).match(x + I) == {x: a}
  198. assert (2*I).match(x*I) == {x: 2}
  199. assert (a*I).match(x*I) == {x: a}
  200. assert (a*I).match(x*y) == {x: I, y: a}
  201. assert (2*I).match(x*y) == {x: 2, y: I}
  202. assert (a + b*I).match(x + y*I) == {x: a, y: b}
  203. def test_functions():
  204. from sympy.core.function import WildFunction
  205. x = Symbol('x')
  206. g = WildFunction('g')
  207. p = Wild('p')
  208. q = Wild('q')
  209. f = cos(5*x)
  210. notf = x
  211. assert f.match(p*cos(q*x)) == {p: 1, q: 5}
  212. assert f.match(p*g) == {p: 1, g: cos(5*x)}
  213. assert notf.match(g) is None
  214. @XFAIL
  215. def test_functions_X1():
  216. from sympy.core.function import WildFunction
  217. x = Symbol('x')
  218. g = WildFunction('g')
  219. p = Wild('p')
  220. q = Wild('q')
  221. f = cos(5*x)
  222. assert f.match(p*g(q*x)) == {p: 1, g: cos, q: 5}
  223. def test_interface():
  224. x, y = map(Symbol, 'xy')
  225. p, q = map(Wild, 'pq')
  226. assert (x + 1).match(p + 1) == {p: x}
  227. assert (x*3).match(p*3) == {p: x}
  228. assert (x**3).match(p**3) == {p: x}
  229. assert (x*cos(y)).match(p*cos(q)) == {p: x, q: y}
  230. assert (x*y).match(p*q) in [{p:x, q:y}, {p:y, q:x}]
  231. assert (x + y).match(p + q) in [{p:x, q:y}, {p:y, q:x}]
  232. assert (x*y + 1).match(p*q) in [{p:1, q:1 + x*y}, {p:1 + x*y, q:1}]
  233. def test_derivative1():
  234. x, y = map(Symbol, 'xy')
  235. p, q = map(Wild, 'pq')
  236. f = Function('f', nargs=1)
  237. fd = Derivative(f(x), x)
  238. assert fd.match(p) == {p: fd}
  239. assert (fd + 1).match(p + 1) == {p: fd}
  240. assert (fd).match(fd) == {}
  241. assert (3*fd).match(p*fd) is not None
  242. assert (3*fd - 1).match(p*fd + q) == {p: 3, q: -1}
  243. def test_derivative_bug1():
  244. f = Function("f")
  245. x = Symbol("x")
  246. a = Wild("a", exclude=[f, x])
  247. b = Wild("b", exclude=[f])
  248. pattern = a * Derivative(f(x), x, x) + b
  249. expr = Derivative(f(x), x) + x**2
  250. d1 = {b: x**2}
  251. d2 = pattern.xreplace(d1).matches(expr, d1)
  252. assert d2 is None
  253. def test_derivative2():
  254. f = Function("f")
  255. x = Symbol("x")
  256. a = Wild("a", exclude=[f, x])
  257. b = Wild("b", exclude=[f])
  258. e = Derivative(f(x), x)
  259. assert e.match(Derivative(f(x), x)) == {}
  260. assert e.match(Derivative(f(x), x, x)) is None
  261. e = Derivative(f(x), x, x)
  262. assert e.match(Derivative(f(x), x)) is None
  263. assert e.match(Derivative(f(x), x, x)) == {}
  264. e = Derivative(f(x), x) + x**2
  265. assert e.match(a*Derivative(f(x), x) + b) == {a: 1, b: x**2}
  266. assert e.match(a*Derivative(f(x), x, x) + b) is None
  267. e = Derivative(f(x), x, x) + x**2
  268. assert e.match(a*Derivative(f(x), x) + b) is None
  269. assert e.match(a*Derivative(f(x), x, x) + b) == {a: 1, b: x**2}
  270. def test_match_deriv_bug1():
  271. n = Function('n')
  272. l = Function('l')
  273. x = Symbol('x')
  274. p = Wild('p')
  275. e = diff(l(x), x)/x - diff(diff(n(x), x), x)/2 - \
  276. diff(n(x), x)**2/4 + diff(n(x), x)*diff(l(x), x)/4
  277. e = e.subs(n(x), -l(x)).doit()
  278. t = x*exp(-l(x))
  279. t2 = t.diff(x, x)/t
  280. assert e.match( (p*t2).expand() ) == {p: Rational(-1, 2)}
  281. def test_match_bug2():
  282. x, y = map(Symbol, 'xy')
  283. p, q, r = map(Wild, 'pqr')
  284. res = (x + y).match(p + q + r)
  285. assert (p + q + r).subs(res) == x + y
  286. def test_match_bug3():
  287. x, a, b = map(Symbol, 'xab')
  288. p = Wild('p')
  289. assert (b*x*exp(a*x)).match(x*exp(p*x)) is None
  290. def test_match_bug4():
  291. x = Symbol('x')
  292. p = Wild('p')
  293. e = x
  294. assert e.match(-p*x) == {p: -1}
  295. def test_match_bug5():
  296. x = Symbol('x')
  297. p = Wild('p')
  298. e = -x
  299. assert e.match(-p*x) == {p: 1}
  300. def test_match_bug6():
  301. x = Symbol('x')
  302. p = Wild('p')
  303. e = x
  304. assert e.match(3*p*x) == {p: Rational(1)/3}
  305. def test_match_polynomial():
  306. x = Symbol('x')
  307. a = Wild('a', exclude=[x])
  308. b = Wild('b', exclude=[x])
  309. c = Wild('c', exclude=[x])
  310. d = Wild('d', exclude=[x])
  311. eq = 4*x**3 + 3*x**2 + 2*x + 1
  312. pattern = a*x**3 + b*x**2 + c*x + d
  313. assert eq.match(pattern) == {a: 4, b: 3, c: 2, d: 1}
  314. assert (eq - 3*x**2).match(pattern) == {a: 4, b: 0, c: 2, d: 1}
  315. assert (x + sqrt(2) + 3).match(a + b*x + c*x**2) == \
  316. {b: 1, a: sqrt(2) + 3, c: 0}
  317. def test_exclude():
  318. x, y, a = map(Symbol, 'xya')
  319. p = Wild('p', exclude=[1, x])
  320. q = Wild('q')
  321. r = Wild('r', exclude=[sin, y])
  322. assert sin(x).match(r) is None
  323. assert cos(y).match(r) is None
  324. e = 3*x**2 + y*x + a
  325. assert e.match(p*x**2 + q*x + r) == {p: 3, q: y, r: a}
  326. e = x + 1
  327. assert e.match(x + p) is None
  328. assert e.match(p + 1) is None
  329. assert e.match(x + 1 + p) == {p: 0}
  330. e = cos(x) + 5*sin(y)
  331. assert e.match(r) is None
  332. assert e.match(cos(y) + r) is None
  333. assert e.match(r + p*sin(q)) == {r: cos(x), p: 5, q: y}
  334. def test_floats():
  335. a, b = map(Wild, 'ab')
  336. e = cos(0.12345, evaluate=False)**2
  337. r = e.match(a*cos(b)**2)
  338. assert r == {a: 1, b: Float(0.12345)}
  339. def test_Derivative_bug1():
  340. f = Function("f")
  341. x = abc.x
  342. a = Wild("a", exclude=[f(x)])
  343. b = Wild("b", exclude=[f(x)])
  344. eq = f(x).diff(x)
  345. assert eq.match(a*Derivative(f(x), x) + b) == {a: 1, b: 0}
  346. def test_match_wild_wild():
  347. p = Wild('p')
  348. q = Wild('q')
  349. r = Wild('r')
  350. assert p.match(q + r) in [ {q: p, r: 0}, {q: 0, r: p} ]
  351. assert p.match(q*r) in [ {q: p, r: 1}, {q: 1, r: p} ]
  352. p = Wild('p')
  353. q = Wild('q', exclude=[p])
  354. r = Wild('r')
  355. assert p.match(q + r) == {q: 0, r: p}
  356. assert p.match(q*r) == {q: 1, r: p}
  357. p = Wild('p')
  358. q = Wild('q', exclude=[p])
  359. r = Wild('r', exclude=[p])
  360. assert p.match(q + r) is None
  361. assert p.match(q*r) is None
  362. def test__combine_inverse():
  363. x, y = symbols("x y")
  364. assert Mul._combine_inverse(x*I*y, x*I) == y
  365. assert Mul._combine_inverse(x*x**(1 + y), x**(1 + y)) == x
  366. assert Mul._combine_inverse(x*I*y, y*I) == x
  367. assert Mul._combine_inverse(oo*I*y, y*I) is oo
  368. assert Mul._combine_inverse(oo*I*y, oo*I) == y
  369. assert Mul._combine_inverse(oo*I*y, oo*I) == y
  370. assert Mul._combine_inverse(oo*y, -oo) == -y
  371. assert Mul._combine_inverse(-oo*y, oo) == -y
  372. assert Mul._combine_inverse((1-exp(x/y)),(exp(x/y)-1)) == -1
  373. assert Add._combine_inverse(oo, oo) is S.Zero
  374. assert Add._combine_inverse(oo*I, oo*I) is S.Zero
  375. assert Add._combine_inverse(x*oo, x*oo) is S.Zero
  376. assert Add._combine_inverse(-x*oo, -x*oo) is S.Zero
  377. assert Add._combine_inverse((x - oo)*(x + oo), -oo)
  378. def test_issue_3773():
  379. x = symbols('x')
  380. z, phi, r = symbols('z phi r')
  381. c, A, B, N = symbols('c A B N', cls=Wild)
  382. l = Wild('l', exclude=(0,))
  383. eq = z * sin(2*phi) * r**7
  384. matcher = c * sin(phi*N)**l * r**A * log(r)**B
  385. assert eq.match(matcher) == {c: z, l: 1, N: 2, A: 7, B: 0}
  386. assert (-eq).match(matcher) == {c: -z, l: 1, N: 2, A: 7, B: 0}
  387. assert (x*eq).match(matcher) == {c: x*z, l: 1, N: 2, A: 7, B: 0}
  388. assert (-7*x*eq).match(matcher) == {c: -7*x*z, l: 1, N: 2, A: 7, B: 0}
  389. matcher = c*sin(phi*N)**l * r**A
  390. assert eq.match(matcher) == {c: z, l: 1, N: 2, A: 7}
  391. assert (-eq).match(matcher) == {c: -z, l: 1, N: 2, A: 7}
  392. assert (x*eq).match(matcher) == {c: x*z, l: 1, N: 2, A: 7}
  393. assert (-7*x*eq).match(matcher) == {c: -7*x*z, l: 1, N: 2, A: 7}
  394. def test_issue_3883():
  395. from sympy.abc import gamma, mu, x
  396. f = (-gamma * (x - mu)**2 - log(gamma) + log(2*pi))/2
  397. a, b, c = symbols('a b c', cls=Wild, exclude=(gamma,))
  398. assert f.match(a * log(gamma) + b * gamma + c) == \
  399. {a: Rational(-1, 2), b: -(-mu + x)**2/2, c: log(2*pi)/2}
  400. assert f.expand().collect(gamma).match(a * log(gamma) + b * gamma + c) == \
  401. {a: Rational(-1, 2), b: (-(x - mu)**2/2).expand(), c: (log(2*pi)/2).expand()}
  402. g1 = Wild('g1', exclude=[gamma])
  403. g2 = Wild('g2', exclude=[gamma])
  404. g3 = Wild('g3', exclude=[gamma])
  405. assert f.expand().match(g1 * log(gamma) + g2 * gamma + g3) == \
  406. {g3: log(2)/2 + log(pi)/2, g1: Rational(-1, 2), g2: -mu**2/2 + mu*x - x**2/2}
  407. def test_issue_4418():
  408. x = Symbol('x')
  409. a, b, c = symbols('a b c', cls=Wild, exclude=(x,))
  410. f, g = symbols('f g', cls=Function)
  411. eq = diff(g(x)*f(x).diff(x), x)
  412. assert eq.match(
  413. g(x).diff(x)*f(x).diff(x) + g(x)*f(x).diff(x, x) + c) == {c: 0}
  414. assert eq.match(a*g(x).diff(
  415. x)*f(x).diff(x) + b*g(x)*f(x).diff(x, x) + c) == {a: 1, b: 1, c: 0}
  416. def test_issue_4700():
  417. f = Function('f')
  418. x = Symbol('x')
  419. a, b = symbols('a b', cls=Wild, exclude=(f(x),))
  420. p = a*f(x) + b
  421. eq1 = sin(x)
  422. eq2 = f(x) + sin(x)
  423. eq3 = f(x) + x + sin(x)
  424. eq4 = x + sin(x)
  425. assert eq1.match(p) == {a: 0, b: sin(x)}
  426. assert eq2.match(p) == {a: 1, b: sin(x)}
  427. assert eq3.match(p) == {a: 1, b: x + sin(x)}
  428. assert eq4.match(p) == {a: 0, b: x + sin(x)}
  429. def test_issue_5168():
  430. a, b, c = symbols('a b c', cls=Wild)
  431. x = Symbol('x')
  432. f = Function('f')
  433. assert x.match(a) == {a: x}
  434. assert x.match(a*f(x)**c) == {a: x, c: 0}
  435. assert x.match(a*b) == {a: 1, b: x}
  436. assert x.match(a*b*f(x)**c) == {a: 1, b: x, c: 0}
  437. assert (-x).match(a) == {a: -x}
  438. assert (-x).match(a*f(x)**c) == {a: -x, c: 0}
  439. assert (-x).match(a*b) == {a: -1, b: x}
  440. assert (-x).match(a*b*f(x)**c) == {a: -1, b: x, c: 0}
  441. assert (2*x).match(a) == {a: 2*x}
  442. assert (2*x).match(a*f(x)**c) == {a: 2*x, c: 0}
  443. assert (2*x).match(a*b) == {a: 2, b: x}
  444. assert (2*x).match(a*b*f(x)**c) == {a: 2, b: x, c: 0}
  445. assert (-2*x).match(a) == {a: -2*x}
  446. assert (-2*x).match(a*f(x)**c) == {a: -2*x, c: 0}
  447. assert (-2*x).match(a*b) == {a: -2, b: x}
  448. assert (-2*x).match(a*b*f(x)**c) == {a: -2, b: x, c: 0}
  449. def test_issue_4559():
  450. x = Symbol('x')
  451. e = Symbol('e')
  452. w = Wild('w', exclude=[x])
  453. y = Wild('y')
  454. # this is as it should be
  455. assert (3/x).match(w/y) == {w: 3, y: x}
  456. assert (3*x).match(w*y) == {w: 3, y: x}
  457. assert (x/3).match(y/w) == {w: 3, y: x}
  458. assert (3*x).match(y/w) == {w: S.One/3, y: x}
  459. assert (3*x).match(y/w) == {w: Rational(1, 3), y: x}
  460. # these could be allowed to fail
  461. assert (x/3).match(w/y) == {w: S.One/3, y: 1/x}
  462. assert (3*x).match(w/y) == {w: 3, y: 1/x}
  463. assert (3/x).match(w*y) == {w: 3, y: 1/x}
  464. # Note that solve will give
  465. # multiple roots but match only gives one:
  466. #
  467. # >>> solve(x**r-y**2,y)
  468. # [-x**(r/2), x**(r/2)]
  469. r = Symbol('r', rational=True)
  470. assert (x**r).match(y**2) == {y: x**(r/2)}
  471. assert (x**e).match(y**2) == {y: sqrt(x**e)}
  472. # since (x**i = y) -> x = y**(1/i) where i is an integer
  473. # the following should also be valid as long as y is not
  474. # zero when i is negative.
  475. a = Wild('a')
  476. e = S.Zero
  477. assert e.match(a) == {a: e}
  478. assert e.match(1/a) is None
  479. assert e.match(a**.3) is None
  480. e = S(3)
  481. assert e.match(1/a) == {a: 1/e}
  482. assert e.match(1/a**2) == {a: 1/sqrt(e)}
  483. e = pi
  484. assert e.match(1/a) == {a: 1/e}
  485. assert e.match(1/a**2) == {a: 1/sqrt(e)}
  486. assert (-e).match(sqrt(a)) is None
  487. assert (-e).match(a**2) == {a: I*sqrt(pi)}
  488. # The pattern matcher doesn't know how to handle (x - a)**2 == (a - x)**2. To
  489. # avoid ambiguity in actual applications, don't put a coefficient (including a
  490. # minus sign) in front of a wild.
  491. @XFAIL
  492. def test_issue_4883():
  493. a = Wild('a')
  494. x = Symbol('x')
  495. e = [i**2 for i in (x - 2, 2 - x)]
  496. p = [i**2 for i in (x - a, a- x)]
  497. for eq in e:
  498. for pat in p:
  499. assert eq.match(pat) == {a: 2}
  500. def test_issue_4319():
  501. x, y = symbols('x y')
  502. p = -x*(S.One/8 - y)
  503. ans = {S.Zero, y - S.One/8}
  504. def ok(pat):
  505. assert set(p.match(pat).values()) == ans
  506. ok(Wild("coeff", exclude=[x])*x + Wild("rest"))
  507. ok(Wild("w", exclude=[x])*x + Wild("rest"))
  508. ok(Wild("coeff", exclude=[x])*x + Wild("rest"))
  509. ok(Wild("w", exclude=[x])*x + Wild("rest"))
  510. ok(Wild("e", exclude=[x])*x + Wild("rest"))
  511. ok(Wild("ress", exclude=[x])*x + Wild("rest"))
  512. ok(Wild("resu", exclude=[x])*x + Wild("rest"))
  513. def test_issue_3778():
  514. p, c, q = symbols('p c q', cls=Wild)
  515. x = Symbol('x')
  516. assert (sin(x)**2).match(sin(p)*sin(q)*c) == {q: x, c: 1, p: x}
  517. assert (2*sin(x)).match(sin(p) + sin(q) + c) == {q: x, c: 0, p: x}
  518. def test_issue_6103():
  519. x = Symbol('x')
  520. a = Wild('a')
  521. assert (-I*x*oo).match(I*a*oo) == {a: -x}
  522. def test_issue_3539():
  523. a = Wild('a')
  524. x = Symbol('x')
  525. assert (x - 2).match(a - x) is None
  526. assert (6/x).match(a*x) is None
  527. assert (6/x**2).match(a/x) == {a: 6/x}
  528. def test_gh_issue_2711():
  529. x = Symbol('x')
  530. f = meijerg(((), ()), ((0,), ()), x)
  531. a = Wild('a')
  532. b = Wild('b')
  533. assert f.find(a) == {(S.Zero,), ((), ()), ((S.Zero,), ()), x, S.Zero,
  534. (), meijerg(((), ()), ((S.Zero,), ()), x)}
  535. assert f.find(a + b) == \
  536. {meijerg(((), ()), ((S.Zero,), ()), x), x, S.Zero}
  537. assert f.find(a**2) == {meijerg(((), ()), ((S.Zero,), ()), x), x}
  538. def test_issue_17354():
  539. from sympy.core.symbol import (Wild, symbols)
  540. x, y = symbols("x y", real=True)
  541. a, b = symbols("a b", cls=Wild)
  542. assert ((0 <= x).reversed | (y <= x)).match((1/a <= b) | (a <= b)) is None
  543. def test_match_issue_17397():
  544. f = Function("f")
  545. x = Symbol("x")
  546. a3 = Wild('a3', exclude=[f(x), f(x).diff(x), f(x).diff(x, 2)])
  547. b3 = Wild('b3', exclude=[f(x), f(x).diff(x), f(x).diff(x, 2)])
  548. c3 = Wild('c3', exclude=[f(x), f(x).diff(x), f(x).diff(x, 2)])
  549. deq = a3*(f(x).diff(x, 2)) + b3*f(x).diff(x) + c3*f(x)
  550. eq = (x-2)**2*(f(x).diff(x, 2)) + (x-2)*(f(x).diff(x)) + ((x-2)**2 - 4)*f(x)
  551. r = collect(eq, [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq)
  552. assert r == {a3: (x - 2)**2, c3: (x - 2)**2 - 4, b3: x - 2}
  553. eq =x*f(x) + x*Derivative(f(x), (x, 2)) - 4*f(x) + Derivative(f(x), x) \
  554. - 4*Derivative(f(x), (x, 2)) - 2*Derivative(f(x), x)/x + 4*Derivative(f(x), (x, 2))/x
  555. r = collect(eq, [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq)
  556. assert r == {a3: x - 4 + 4/x, b3: 1 - 2/x, c3: x - 4}
  557. def test_match_issue_21942():
  558. a, r, w = symbols('a, r, w', nonnegative=True)
  559. p = symbols('p', positive=True)
  560. g_ = Wild('g')
  561. pattern = g_ ** (1 / (1 - p))
  562. eq = (a * r ** (1 - p) + w ** (1 - p) * (1 - a)) ** (1 / (1 - p))
  563. m = {g_: a * r ** (1 - p) + w ** (1 - p) * (1 - a)}
  564. assert pattern.matches(eq) == m
  565. assert (-pattern).matches(-eq) == m
  566. assert pattern.matches(signsimp(eq)) is None
  567. def test_match_terms():
  568. X, Y = map(Wild, "XY")
  569. x, y, z = symbols('x y z')
  570. assert (5*y - x).match(5*X - Y) == {X: y, Y: x}
  571. # 15907
  572. assert (x + (y - 1)*z).match(x + X*z) == {X: y - 1}
  573. # 20747
  574. assert (x - log(x/y)*(1-exp(x/y))).match(x - log(X/y)*(1-exp(x/y))) == {X: x}
  575. def test_match_bound():
  576. V, W = map(Wild, "VW")
  577. x, y = symbols('x y')
  578. assert Sum(x, (x, 1, 2)).match(Sum(y, (y, 1, W))) == {W: 2}
  579. assert Sum(x, (x, 1, 2)).match(Sum(V, (V, 1, W))) == {W: 2, V:x}
  580. assert Sum(x, (x, 1, 2)).match(Sum(V, (V, 1, 2))) == {V:x}
  581. def test_issue_22462():
  582. x, f = symbols('x'), Function('f')
  583. n, Q = symbols('n Q', cls=Wild)
  584. pattern = -Q*f(x)**n
  585. eq = 5*f(x)**2
  586. assert pattern.matches(eq) == {n: 2, Q: -5}