test_rewrite.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. from sympy.unify.rewrite import rewriterule
  2. from sympy.core.basic import Basic
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import Symbol
  5. from sympy.functions.elementary.trigonometric import sin
  6. from sympy.abc import x, y
  7. from sympy.strategies.rl import rebuild
  8. from sympy.assumptions import Q
  9. p, q = Symbol('p'), Symbol('q')
  10. def test_simple():
  11. rl = rewriterule(Basic(p, S(1)), Basic(p, S(2)), variables=(p,))
  12. assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
  13. p1 = p**2
  14. p2 = p**3
  15. rl = rewriterule(p1, p2, variables=(p,))
  16. expr = x**2
  17. assert list(rl(expr)) == [x**3]
  18. def test_simple_variables():
  19. rl = rewriterule(Basic(x, S(1)), Basic(x, S(2)), variables=(x,))
  20. assert list(rl(Basic(S(3), S(1)))) == [Basic(S(3), S(2))]
  21. rl = rewriterule(x**2, x**3, variables=(x,))
  22. assert list(rl(y**2)) == [y**3]
  23. def test_moderate():
  24. p1 = p**2 + q**3
  25. p2 = (p*q)**4
  26. rl = rewriterule(p1, p2, (p, q))
  27. expr = x**2 + y**3
  28. assert list(rl(expr)) == [(x*y)**4]
  29. def test_sincos():
  30. p1 = sin(p)**2 + sin(p)**2
  31. p2 = 1
  32. rl = rewriterule(p1, p2, (p, q))
  33. assert list(rl(sin(x)**2 + sin(x)**2)) == [1]
  34. assert list(rl(sin(y)**2 + sin(y)**2)) == [1]
  35. def test_Exprs_ok():
  36. rl = rewriterule(p+q, q+p, (p, q))
  37. next(rl(x+y)).is_commutative
  38. str(next(rl(x+y)))
  39. def test_condition_simple():
  40. rl = rewriterule(x, x+1, [x], lambda x: x < 10)
  41. assert not list(rl(S(15)))
  42. assert rebuild(next(rl(S(5)))) == 6
  43. def test_condition_multiple():
  44. rl = rewriterule(x + y, x**y, [x,y], lambda x, y: x.is_integer)
  45. a = Symbol('a')
  46. b = Symbol('b', integer=True)
  47. expr = a + b
  48. assert list(rl(expr)) == [b**a]
  49. c = Symbol('c', integer=True)
  50. d = Symbol('d', integer=True)
  51. assert set(rl(c + d)) == {c**d, d**c}
  52. def test_assumptions():
  53. rl = rewriterule(x + y, x**y, [x, y], assume=Q.integer(x))
  54. a, b = map(Symbol, 'ab')
  55. expr = a + b
  56. assert list(rl(expr, Q.integer(b))) == [b**a]