# test_cse.py: tests for common subexpression elimination (sympy.simplify.cse_main).
  1. from functools import reduce
  2. import itertools
  3. from operator import add
  4. from sympy.codegen.matrix_nodes import MatrixSolve
  5. from sympy.core.add import Add
  6. from sympy.core.containers import Tuple
  7. from sympy.core.expr import UnevaluatedExpr
  8. from sympy.core.function import Function
  9. from sympy.core.mul import Mul
  10. from sympy.core.power import Pow
  11. from sympy.core.relational import Eq
  12. from sympy.core.singleton import S
  13. from sympy.core.symbol import (Symbol, symbols)
  14. from sympy.core.sympify import sympify
  15. from sympy.functions.elementary.exponential import exp
  16. from sympy.functions.elementary.miscellaneous import sqrt
  17. from sympy.functions.elementary.piecewise import Piecewise
  18. from sympy.functions.elementary.trigonometric import (cos, sin)
  19. from sympy.matrices.dense import Matrix
  20. from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose
  21. from sympy.polys.rootoftools import CRootOf
  22. from sympy.series.order import O
  23. from sympy.simplify.cse_main import cse
  24. from sympy.simplify.simplify import signsimp
  25. from sympy.tensor.indexed import (Idx, IndexedBase)
  26. from sympy.core.function import count_ops
  27. from sympy.simplify.cse_opts import sub_pre, sub_post
  28. from sympy.functions.special.hyper import meijerg
  29. from sympy.simplify import cse_main, cse_opts
  30. from sympy.utilities.iterables import subsets
  31. from sympy.testing.pytest import XFAIL, raises
  32. from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix,
  33. ImmutableDenseMatrix, ImmutableSparseMatrix)
  34. from sympy.matrices.expressions import MatrixSymbol
  35. w, x, y, z = symbols('w,x,y,z')
  36. x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13')
  37. def test_numbered_symbols():
  38. ns = cse_main.numbered_symbols(prefix='y')
  39. assert list(itertools.islice(
  40. ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)]
  41. ns = cse_main.numbered_symbols(prefix='y')
  42. assert list(itertools.islice(
  43. ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)]
  44. ns = cse_main.numbered_symbols()
  45. assert list(itertools.islice(
  46. ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)]
  47. # Dummy "optimization" functions for testing.
  48. def opt1(expr):
  49. return expr + y
  50. def opt2(expr):
  51. return expr*z
  52. def test_preprocess_for_cse():
  53. assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y
  54. assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x
  55. assert cse_main.preprocess_for_cse(x, [(None, None)]) == x
  56. assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y
  57. assert cse_main.preprocess_for_cse(
  58. x, [(opt1, None), (opt2, None)]) == (x + y)*z
  59. def test_postprocess_for_cse():
  60. assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x
  61. assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y
  62. assert cse_main.postprocess_for_cse(x, [(None, None)]) == x
  63. assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z
  64. # Note the reverse order of application.
  65. assert cse_main.postprocess_for_cse(
  66. x, [(None, opt1), (None, opt2)]) == x*z + y
  67. def test_cse_single():
  68. # Simple substitution.
  69. e = Add(Pow(x + y, 2), sqrt(x + y))
  70. substs, reduced = cse([e])
  71. assert substs == [(x0, x + y)]
  72. assert reduced == [sqrt(x0) + x0**2]
  73. subst42, (red42,) = cse([42]) # issue_15082
  74. assert len(subst42) == 0 and red42 == 42
  75. subst_half, (red_half,) = cse([0.5])
  76. assert len(subst_half) == 0 and red_half == 0.5
  77. def test_cse_single2():
  78. # Simple substitution, test for being able to pass the expression directly
  79. e = Add(Pow(x + y, 2), sqrt(x + y))
  80. substs, reduced = cse(e)
  81. assert substs == [(x0, x + y)]
  82. assert reduced == [sqrt(x0) + x0**2]
  83. substs, reduced = cse(Matrix([[1]]))
  84. assert isinstance(reduced[0], Matrix)
  85. subst42, (red42,) = cse(42) # issue 15082
  86. assert len(subst42) == 0 and red42 == 42
  87. subst_half, (red_half,) = cse(0.5) # issue 15082
  88. assert len(subst_half) == 0 and red_half == 0.5
  89. def test_cse_not_possible():
  90. # No substitution possible.
  91. e = Add(x, y)
  92. substs, reduced = cse([e])
  93. assert substs == []
  94. assert reduced == [x + y]
  95. # issue 6329
  96. eq = (meijerg((1, 2), (y, 4), (5,), [], x) +
  97. meijerg((1, 3), (y, 4), (5,), [], x))
  98. assert cse(eq) == ([], [eq])
  99. def test_nested_substitution():
  100. # Substitution within a substitution.
  101. e = Add(Pow(w*x + y, 2), sqrt(w*x + y))
  102. substs, reduced = cse([e])
  103. assert substs == [(x0, w*x + y)]
  104. assert reduced == [sqrt(x0) + x0**2]
  105. def test_subtraction_opt():
  106. # Make sure subtraction is optimized.
  107. e = (x - y)*(z - y) + exp((x - y)*(z - y))
  108. substs, reduced = cse(
  109. [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
  110. assert substs == [(x0, (x - y)*(y - z))]
  111. assert reduced == [-x0 + exp(-x0)]
  112. e = -(x - y)*(z - y) + exp(-(x - y)*(z - y))
  113. substs, reduced = cse(
  114. [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)])
  115. assert substs == [(x0, (x - y)*(y - z))]
  116. assert reduced == [x0 + exp(x0)]
  117. # issue 4077
  118. n = -1 + 1/x
  119. e = n/x/(-n)**2 - 1/n/x
  120. assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \
  121. ([], [0])
  122. assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \
  123. ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3])
  124. def test_multiple_expressions():
  125. e1 = (x + y)*z
  126. e2 = (x + y)*w
  127. substs, reduced = cse([e1, e2])
  128. assert substs == [(x0, x + y)]
  129. assert reduced == [x0*z, x0*w]
  130. l = [w*x*y + z, w*y]
  131. substs, reduced = cse(l)
  132. rsubsts, _ = cse(reversed(l))
  133. assert substs == rsubsts
  134. assert reduced == [z + x*x0, x0]
  135. l = [w*x*y, w*x*y + z, w*y]
  136. substs, reduced = cse(l)
  137. rsubsts, _ = cse(reversed(l))
  138. assert substs == rsubsts
  139. assert reduced == [x1, x1 + z, x0]
  140. l = [(x - z)*(y - z), x - z, y - z]
  141. substs, reduced = cse(l)
  142. rsubsts, _ = cse(reversed(l))
  143. assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)]
  144. assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)]
  145. assert reduced == [x1*x2, x1, x2]
  146. l = [w*y + w + x + y + z, w*x*y]
  147. assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0])
  148. assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0])
  149. assert cse([x + y, x + z]) == ([], [x + y, x + z])
  150. assert cse([x*y, z + x*y, x*y*z + 3]) == \
  151. ([(x0, x*y)], [x0, z + x0, 3 + x0*z])
  152. @XFAIL # CSE of non-commutative Mul terms is disabled
  153. def test_non_commutative_cse():
  154. A, B, C = symbols('A B C', commutative=False)
  155. l = [A*B*C, A*C]
  156. assert cse(l) == ([], l)
  157. l = [A*B*C, A*B]
  158. assert cse(l) == ([(x0, A*B)], [x0*C, x0])
  159. # Test if CSE of non-commutative Mul terms is disabled
  160. def test_bypass_non_commutatives():
  161. A, B, C = symbols('A B C', commutative=False)
  162. l = [A*B*C, A*C]
  163. assert cse(l) == ([], l)
  164. l = [A*B*C, A*B]
  165. assert cse(l) == ([], l)
  166. l = [B*C, A*B*C]
  167. assert cse(l) == ([], l)
  168. @XFAIL # CSE fails when replacing non-commutative sub-expressions
  169. def test_non_commutative_order():
  170. A, B, C = symbols('A B C', commutative=False)
  171. x0 = symbols('x0', commutative=False)
  172. l = [B+C, A*(B+C)]
  173. assert cse(l) == ([(x0, B+C)], [x0, A*x0])
  174. @XFAIL # Worked in gh-11232, but was reverted due to performance considerations
  175. def test_issue_10228():
  176. assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0])
  177. assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0])
  178. assert cse((w + 2*x + y + z, w + x + 1)) == (
  179. [(x0, w + x)], [x0 + x + y + z, x0 + 1])
  180. assert cse(((w + x + y + z)*(w - x))/(w + x)) == (
  181. [(x0, w + x)], [(x0 + y + z)*(w - x)/x0])
  182. a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m')
  183. exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2)
  184. assert cse(exprs) == (
  185. [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1]
  186. )
  187. @XFAIL
  188. def test_powers():
  189. assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0])
  190. def test_issue_4498():
  191. assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \
  192. ([], [(w - z)/(x - y)])
  193. def test_issue_4020():
  194. assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \
  195. == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)])
  196. def test_issue_4203():
  197. assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0])
  198. def test_issue_6263():
  199. e = Eq(x*(-x + 1) + x*(x - 1), 0)
  200. assert cse(e, optimizations='basic') == ([], [True])
  201. def test_dont_cse_tuples():
  202. from sympy.core.function import Subs
  203. f = Function("f")
  204. g = Function("g")
  205. name_val, (expr,) = cse(
  206. Subs(f(x, y), (x, y), (0, 1))
  207. + Subs(g(x, y), (x, y), (0, 1)))
  208. assert name_val == []
  209. assert expr == (Subs(f(x, y), (x, y), (0, 1))
  210. + Subs(g(x, y), (x, y), (0, 1)))
  211. name_val, (expr,) = cse(
  212. Subs(f(x, y), (x, y), (0, x + y))
  213. + Subs(g(x, y), (x, y), (0, x + y)))
  214. assert name_val == [(x0, x + y)]
  215. assert expr == Subs(f(x, y), (x, y), (0, x0)) + \
  216. Subs(g(x, y), (x, y), (0, x0))
  217. def test_pow_invpow():
  218. assert cse(1/x**2 + x**2) == \
  219. ([(x0, x**2)], [x0 + 1/x0])
  220. assert cse(x**2 + (1 + 1/x**2)/x**2) == \
  221. ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)])
  222. assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \
  223. ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1])
  224. assert cse(cos(1/x**2) + sin(1/x**2)) == \
  225. ([(x0, x**(-2))], [sin(x0) + cos(x0)])
  226. assert cse(cos(x**2) + sin(x**2)) == \
  227. ([(x0, x**2)], [sin(x0) + cos(x0)])
  228. assert cse(y/(2 + x**2) + z/x**2/y) == \
  229. ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)])
  230. assert cse(exp(x**2) + x**2*cos(1/x**2)) == \
  231. ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)])
  232. assert cse((1 + 1/x**2)/x**2) == \
  233. ([(x0, x**(-2))], [x0*(x0 + 1)])
  234. assert cse(x**(2*y) + x**(-2*y)) == \
  235. ([(x0, x**(2*y))], [x0 + 1/x0])
  236. def test_postprocess():
  237. eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1))
  238. assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)],
  239. postprocess=cse_main.cse_separate) == \
  240. [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)],
  241. [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]]
  242. def test_issue_4499():
  243. # previously, this gave 16 constants
  244. from sympy.abc import a, b
  245. B = Function('B')
  246. G = Function('G')
  247. t = Tuple(*
  248. (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a -
  249. b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1),
  250. sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b,
  251. sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1,
  252. sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1),
  253. (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1,
  254. sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b,
  255. -2*a))
  256. c = cse(t)
  257. ans = (
  258. [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)),
  259. (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)),
  260. (x8, B(b, x4)), (x9, x6*B(x2, x4))],
  261. [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9,
  262. 1, 0, S.Half, z/2, -x3, -x1, -x0)])
  263. assert ans == c
  264. def test_issue_6169():
  265. r = CRootOf(x**6 - 4*x**5 - 2, 1)
  266. assert cse(r) == ([], [r])
  267. # and a check that the right thing is done with the new
  268. # mechanism
  269. assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y
  270. def test_cse_Indexed():
  271. len_y = 5
  272. y = IndexedBase('y', shape=(len_y,))
  273. x = IndexedBase('x', shape=(len_y,))
  274. i = Idx('i', len_y-1)
  275. expr1 = (y[i+1]-y[i])/(x[i+1]-x[i])
  276. expr2 = 1/(x[i+1]-x[i])
  277. replacements, reduced_exprs = cse([expr1, expr2])
  278. assert len(replacements) > 0
  279. def test_cse_MatrixSymbol():
  280. # MatrixSymbols have non-Basic args, so make sure that works
  281. A = MatrixSymbol("A", 3, 3)
  282. assert cse(A) == ([], [A])
  283. n = symbols('n', integer=True)
  284. B = MatrixSymbol("B", n, n)
  285. assert cse(B) == ([], [B])
  286. assert cse(A[0] * A[0]) == ([], [A[0]*A[0]])
  287. assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0])
  288. def test_cse_MatrixExpr():
  289. A = MatrixSymbol('A', 3, 3)
  290. y = MatrixSymbol('y', 3, 1)
  291. expr1 = (A.T*A).I * A * y
  292. expr2 = (A.T*A) * A * y
  293. replacements, reduced_exprs = cse([expr1, expr2])
  294. assert len(replacements) > 0
  295. replacements, reduced_exprs = cse([expr1 + expr2, expr1])
  296. assert replacements
  297. replacements, reduced_exprs = cse([A**2, A + A**2])
  298. assert replacements
  299. def test_Piecewise():
  300. f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True))
  301. ans = cse(f)
  302. actual_ans = ([(x0, x*y)],
  303. [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))])
  304. assert ans == actual_ans
  305. def test_ignore_order_terms():
  306. eq = exp(x).series(x,0,3) + sin(y+x**3) - 1
  307. assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)])
  308. def test_name_conflict():
  309. z1 = x0 + y
  310. z2 = x2 + x3
  311. l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
  312. substs, reduced = cse(l)
  313. assert [e.subs(reversed(substs)) for e in reduced] == l
  314. def test_name_conflict_cust_symbols():
  315. z1 = x0 + y
  316. z2 = x2 + x3
  317. l = [cos(z1) + z1, cos(z2) + z2, x0 + x2]
  318. substs, reduced = cse(l, symbols("x:10"))
  319. assert [e.subs(reversed(substs)) for e in reduced] == l
  320. def test_symbols_exhausted_error():
  321. l = cos(x+y)+x+y+cos(w+y)+sin(w+y)
  322. sym = [x, y, z]
  323. with raises(ValueError):
  324. cse(l, symbols=sym)
  325. def test_issue_7840():
  326. # daveknippers' example
  327. C393 = sympify( \
  328. 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \
  329. C391 > 2.35), (C392, True)), True))'
  330. )
  331. C391 = sympify( \
  332. 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))'
  333. )
  334. C393 = C393.subs('C391',C391)
  335. # simple substitution
  336. sub = {}
  337. sub['C390'] = 0.703451854
  338. sub['C392'] = 1.01417794
  339. ss_answer = C393.subs(sub)
  340. # cse
  341. substitutions,new_eqn = cse(C393)
  342. for pair in substitutions:
  343. sub[pair[0].name] = pair[1].subs(sub)
  344. cse_answer = new_eqn[0].subs(sub)
  345. # both methods should be the same
  346. assert ss_answer == cse_answer
  347. # GitRay's example
  348. expr = sympify(
  349. "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \
  350. (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \
  351. Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \
  352. Symbol('AUTO'))), (Symbol('OFF'), true)), true))"
  353. )
  354. substitutions, new_eqn = cse(expr)
  355. # this Piecewise should be exactly the same
  356. assert new_eqn[0] == expr
  357. # there should not be any replacements
  358. assert len(substitutions) < 1
  359. def test_issue_8891():
  360. for cls in (MutableDenseMatrix, MutableSparseMatrix,
  361. ImmutableDenseMatrix, ImmutableSparseMatrix):
  362. m = cls(2, 2, [x + y, 0, 0, 0])
  363. res = cse([x + y, m])
  364. ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])])
  365. assert res == ans
  366. assert isinstance(res[1][-1], cls)
  367. def test_issue_11230():
  368. # a specific test that always failed
  369. a, b, f, k, l, i = symbols('a b f k l i')
  370. p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l]
  371. R, C = cse(p)
  372. assert not any(i.is_Mul for a in C for i in a.args)
  373. # random tests for the issue
  374. from sympy.core.random import choice
  375. from sympy.core.function import expand_mul
  376. s = symbols('a:m')
  377. # 35 Mul tests, none of which should ever fail
  378. ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)]
  379. for p in subsets(ex, 3):
  380. p = list(p)
  381. R, C = cse(p)
  382. assert not any(i.is_Mul for a in C for i in a.args)
  383. for ri in reversed(R):
  384. for i in range(len(C)):
  385. C[i] = C[i].subs(*ri)
  386. assert p == C
  387. # 35 Add tests, none of which should ever fail
  388. ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)]
  389. for p in subsets(ex, 3):
  390. p = list(p)
  391. R, C = cse(p)
  392. assert not any(i.is_Add for a in C for i in a.args)
  393. for ri in reversed(R):
  394. for i in range(len(C)):
  395. C[i] = C[i].subs(*ri)
  396. # use expand_mul to handle cases like this:
  397. # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g]
  398. # x0 = 2*(b + e) is identified giving a rebuilt p that
  399. # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]`
  400. assert p == [expand_mul(i) for i in C]
  401. @XFAIL
  402. def test_issue_11577():
  403. def check(eq):
  404. r, c = cse(eq)
  405. assert eq.count_ops() >= \
  406. len(r) + sum([i[1].count_ops() for i in r]) + \
  407. count_ops(c)
  408. eq = x**5*y**2 + x**5*y + x**5
  409. assert cse(eq) == (
  410. [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1])
  411. # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or
  412. # ([(x0, x**5)], [x0*y**2 + x0*y + x0])
  413. check(eq)
  414. eq = x**2/(y + 1)**2 + x/(y + 1)
  415. assert cse(eq) == (
  416. [(x0, y + 1)], [x**2/x0**2 + x/x0])
  417. # ([(x0, x/(y + 1))], [x0**2 + x0])
  418. check(eq)
  419. def test_hollow_rejection():
  420. eq = [x + 3, x + 4]
  421. assert cse(eq) == ([], eq)
  422. def test_cse_ignore():
  423. exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))]
  424. subst1, red1 = cse(exprs)
  425. assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y"
  426. subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions
  427. assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored"
  428. assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression"
  429. def test_cse_ignore_issue_15002():
  430. l = [
  431. w*exp(x)*exp(-z),
  432. exp(y)*exp(x)*exp(-z)
  433. ]
  434. substs, reduced = cse(l, ignore=(x,))
  435. rl = [e.subs(reversed(substs)) for e in reduced]
  436. assert rl == l
  437. def test_cse_unevaluated():
  438. xp1 = UnevaluatedExpr(x + 1)
  439. # This used to cause RecursionError
  440. [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)])
  441. if ue == xp1:
  442. assert red == (-1 - x0) / (1 - x0)
  443. elif ue == -xp1:
  444. assert red == (-1 + x0) / (1 + x0)
  445. else:
  446. msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}'
  447. assert False, msg
  448. def test_cse__performance():
  449. nexprs, nterms = 3, 20
  450. x = symbols('x:%d' % nterms)
  451. exprs = [
  452. reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)])
  453. for i in range(nexprs)
  454. ]
  455. assert (exprs[0] + exprs[1]).simplify() == 0
  456. subst, red = cse(exprs)
  457. assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE"
  458. for i, e in enumerate(red):
  459. assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0
  460. def test_issue_12070():
  461. exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z]
  462. subst, red = cse(exprs)
  463. assert 6 >= (len(subst) + sum([v.count_ops() for k, v in subst]) +
  464. count_ops(red))
  465. def test_issue_13000():
  466. eq = x/(-4*x**2 + y**2)
  467. cse_eq = cse(eq)[1][0]
  468. assert cse_eq == eq
  469. def test_issue_18203():
  470. eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1)
  471. assert cse(eq) == ([], [eq])
  472. def test_unevaluated_mul():
  473. eq = Mul(x + y, x + y, evaluate=False)
  474. assert cse(eq) == ([(x0, x + y)], [x0**2])
  475. def test_cse_release_variables():
  476. from sympy.simplify.cse_main import cse_release_variables
  477. _0, _1, _2, _3, _4 = symbols('_:5')
  478. eqs = [(x + y - 1)**2, x,
  479. x + y, (x + y)/(2*x + 1) + (x + y - 1)**2,
  480. (2*x + 1)**(x + y)]
  481. r, e = cse(eqs, postprocess=cse_release_variables)
  482. # this can change in keeping with the intention of the function
  483. assert r, e == ([
  484. (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1),
  485. (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1),
  486. (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4))
  487. r.reverse()
  488. r = [(s, v) for s, v in r if v is not None]
  489. assert eqs == [i.subs(r) for i in e]
  490. def test_cse_list():
  491. _cse = lambda x: cse(x, list=False)
  492. assert _cse(x) == ([], x)
  493. assert _cse('x') == ([], 'x')
  494. it = [x]
  495. for c in (list, tuple, set):
  496. assert _cse(c(it)) == ([], c(it))
  497. #Tuple works different from tuple:
  498. assert _cse(Tuple(*it)) == ([], Tuple(*it))
  499. d = {x: 1}
  500. assert _cse(d) == ([], d)
  501. def test_issue_18991():
  502. A = MatrixSymbol('A', 2, 2)
  503. assert signsimp(-A * A - A) == -A * A - A
  504. def test_unevaluated_Mul():
  505. m = [Mul(1, 2, evaluate=False)]
  506. assert cse(m) == ([], m)
  507. def test_cse_matrix_expression_inverse():
  508. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  509. x = Inverse(A)
  510. cse_expr = cse(x)
  511. assert cse_expr == ([], [Inverse(A)])
  512. def test_cse_matrix_expression_matmul_inverse():
  513. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  514. b = ImmutableDenseMatrix(symbols('b:2'))
  515. x = MatMul(Inverse(A), b)
  516. cse_expr = cse(x)
  517. assert cse_expr == ([], [x])
  518. def test_cse_matrix_negate_matrix():
  519. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  520. x = MatMul(S.NegativeOne, A)
  521. cse_expr = cse(x)
  522. assert cse_expr == ([], [x])
  523. def test_cse_matrix_negate_matmul_not_extracted():
  524. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  525. B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
  526. x = MatMul(S.NegativeOne, A, B)
  527. cse_expr = cse(x)
  528. assert cse_expr == ([], [x])
  529. @XFAIL # No simplification rule for nested associative operations
  530. def test_cse_matrix_nested_matmul_collapsed():
  531. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  532. B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2)
  533. x = MatMul(S.NegativeOne, MatMul(A, B))
  534. cse_expr = cse(x)
  535. assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)])
  536. def test_cse_matrix_optimize_out_single_argument_mul():
  537. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  538. x = MatMul(MatMul(MatMul(A)))
  539. cse_expr = cse(x)
  540. assert cse_expr == ([], [A])
  541. @XFAIL # Multiple simplification passed not supported in CSE
  542. def test_cse_matrix_optimize_out_single_argument_mul_combined():
  543. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  544. x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A)
  545. cse_expr = cse(x)
  546. assert cse_expr == ([], [MatMul(4, A)])
  547. def test_cse_matrix_optimize_out_single_argument_add():
  548. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  549. x = MatAdd(MatAdd(MatAdd(MatAdd(A))))
  550. cse_expr = cse(x)
  551. assert cse_expr == ([], [A])
  552. @XFAIL # Multiple simplification passed not supported in CSE
  553. def test_cse_matrix_optimize_out_single_argument_add_combined():
  554. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  555. x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A)
  556. cse_expr = cse(x)
  557. assert cse_expr == ([], [MatMul(4, A)])
  558. def test_cse_matrix_expression_matrix_solve():
  559. A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2)
  560. b = ImmutableDenseMatrix(symbols('b:2'))
  561. x = MatrixSolve(A, b)
  562. cse_expr = cse(x)
  563. assert cse_expr == ([], [x])
  564. def test_cse_matrix_matrix_expression():
  565. X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2)
  566. y = ImmutableDenseMatrix(symbols('y:2'))
  567. b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y)
  568. cse_expr = cse(b)
  569. x0 = MatrixSymbol('x0', 2, 2)
  570. reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y)
  571. assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected])
  572. def test_cse_matrix_kalman_filter():
  573. """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk.
  574. Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation"
  575. Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html
  576. Notes
  577. =====
  578. Equations are:
  579. new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data)
  580. = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
  581. new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma
  582. = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma))
  583. """
  584. N = 2
  585. mu = ImmutableDenseMatrix(symbols(f'mu:{N}'))
  586. Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N)
  587. H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N)
  588. R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N)
  589. data = ImmutableDenseMatrix(symbols(f'data:{N}'))
  590. new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data))))
  591. new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma))
  592. cse_expr = cse([new_mu, new_Sigma])
  593. x0 = MatrixSymbol('x0', N, N)
  594. x1 = MatrixSymbol('x1', N, N)
  595. replacements_expected = [
  596. (x0, Transpose(H)),
  597. (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))),
  598. ]
  599. reduced_exprs_expected = [
  600. MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))),
  601. MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)),
  602. ]
  603. assert cse_expr == (replacements_expected, reduced_exprs_expected)