# test_traversal.py: tests for sympy.core.traversal (preorder_traversal,
# postorder_traversal, use, iterargs, iterfreeargs) and its deprecated aliases.
  1. from sympy.core.basic import Basic
  2. from sympy.core.containers import Tuple
  3. from sympy.core.sorting import default_sort_key
  4. from sympy.core.symbol import symbols
  5. from sympy.core.singleton import S
  6. from sympy.core.function import expand, Function
  7. from sympy.core.numbers import I
  8. from sympy.integrals.integrals import Integral
  9. from sympy.polys.polytools import factor
  10. from sympy.core.traversal import preorder_traversal, use, postorder_traversal, iterargs, iterfreeargs
  11. from sympy.functions.elementary.piecewise import ExprCondPair, Piecewise
  12. from sympy.testing.pytest import warns_deprecated_sympy
  13. from sympy.utilities.iterables import capture
  # Hand-built trees of bare Basic nodes shared by the traversal tests.
  # b21 has two children (b2, then b1), so sibling order shows up in results.
  b1 = Basic()
  b2 = Basic(b1)
  b3, b21 = Basic(b2), Basic(b2, b1)
  def test_preorder_traversal():
      """preorder_traversal yields every node before any of its children."""
      tree = Basic(b21, b3)
      assert list(preorder_traversal(tree)) == [
          tree, b21, b2, b1, b1, b3, b2, b1]

      # Nested Python tuples are walked too; the strings come out whole.
      nested = ('abc', ('d', 'ef'))
      assert list(preorder_traversal(nested)) == [
          nested, 'abc', ('d', 'ef'), 'd', 'ef']

      # Calling skip() right after a node is yielded keeps the walk from
      # descending into that node's children (b2's child b1 never appears).
      visited = []
      walker = preorder_traversal(tree)
      for node in walker:
          visited.append(node)
          if node == b2:
              walker.skip()
      assert visited == [tree, b21, b2, b1, b3, b2]

      # With ``keys``, siblings are visited in the order given by that key.
      w, x, y, z = symbols('w:z')
      expr = z + w*(x + y)
      assert list(preorder_traversal([expr], keys=default_sort_key)) == [
          [w*(x + y) + z], w*(x + y) + z, z, w*(x + y), w, x + y, x, y]
      assert list(preorder_traversal((x + y)*z, keys=True)) == [
          z*(x + y), z, x + y, x, y]
  def test_use():
      """Check that ``use`` applies a function at a chosen depth of an expression.

      As the assertions below pin down, ``level=0`` rewrites the whole
      expression. Each higher level reaches smaller sub-terms, until the
      function has nothing left to change and the input comes back unchanged.
      Extra keyword arguments for the function are passed via ``kwargs``.
      """
      x, y = symbols('x y')
      assert use(0, expand) == 0

      f = (x + y)**2*x + 1
      # Expanding at level 0 (whole Add) or level 1 (the Mul term) gives the
      # same fully expanded polynomial.
      fully_expanded = x**3 + 2*x**2*y + x*y**2 + 1
      assert use(f, expand, level=0) == fully_expanded
      assert use(f, expand, level=1) == fully_expanded
      # Level 2 only expands the squared factor inside the product.
      assert use(f, expand, level=2) == 1 + x*(2*x*y + x**2 + y**2)
      # Level 3 reaches only atoms/leaf sums where expand is a no-op.
      assert use(f, expand, level=3) == (x + y)**2*x + 1

      f = (x**2 + 1)**2 - 1
      kwargs = {'gaussian': True}
      assert use(f, factor, level=0, kwargs=kwargs) == x**2*(x**2 + 2)
      assert use(f, factor, level=1, kwargs=kwargs) == (x + I)**2*(x - I)**2 - 1
      assert use(f, factor, level=2, kwargs=kwargs) == (x + I)**2*(x - I)**2 - 1
      assert use(f, factor, level=3, kwargs=kwargs) == (x**2 + 1)**2 - 1
  def test_postorder_traversal():
      """postorder_traversal yields every node after all of its children."""
      x, y, z, w = symbols('x y z w')

      sum_expr = z + w*(x + y)
      want = [z, w, x, y, x + y, w*(x + y), w*(x + y) + z]
      # A key function and ``keys=True`` must produce the same ordering here.
      for sort_keys in (default_sort_key, True):
          assert list(postorder_traversal(sum_expr, keys=sort_keys)) == want

      pw = Piecewise((x, x < 1), (x**2, True))
      want = [
          x, 1, x, x < 1, ExprCondPair(x, x < 1),
          2, x, x**2, S.true,
          ExprCondPair(x**2, True), pw,
      ]
      assert list(postorder_traversal(pw, keys=default_sort_key)) == want
      # A wrapping list is itself yielded last, after everything inside it.
      assert list(postorder_traversal(
          [pw], keys=default_sort_key)) == want + [[pw]]

      integral = Integral(x**2, (x, 0, 1))
      assert list(postorder_traversal(integral, keys=default_sort_key)) == [
          2, x, x**2, 0, 1, x, Tuple(x, 0, 1),
          Integral(x**2, Tuple(x, 0, 1)),
      ]

      # Plain nested tuples of strings are handled as well.
      nested = ('abc', ('d', 'ef'))
      assert list(postorder_traversal(nested)) == [
          'abc', 'd', 'ef', ('d', 'ef'), nested]
  def test_iterargs():
      """Contrast iterargs (every argument) with iterfreeargs (free ones only).

      For an integral over f(x), iterfreeargs omits the sub-expressions bound
      by the integration variable, while iterargs walks all of them.
      """
      f = Function('f')
      x = symbols('x')
      integral = Integral(f(x), (f(x), 1))
      assert list(iterfreeargs(integral)) == [integral, 1]
      assert list(iterargs(integral)) == [
          integral, f(x), (f(x), 1), x, f(x), 1, x]
  def test_deprecated_imports():
      """Importing and calling each legacy alias must raise a SymPy deprecation warning."""
      x = symbols('x')

      def _same(e):
          # Trivial callback for the walker-style helpers under test.
          return e

      with warns_deprecated_sympy():
          from sympy.core.basic import preorder_traversal as old_preorder
          old_preorder(x)
      with warns_deprecated_sympy():
          from sympy.simplify.simplify import bottom_up as old_bottom_up
          old_bottom_up(x, _same)
      with warns_deprecated_sympy():
          from sympy.simplify.simplify import walk as old_walk
          old_walk(x, _same)
      with warns_deprecated_sympy():
          from sympy.simplify.traversaltools import use as old_use
          old_use(x, _same)
      with warns_deprecated_sympy():
          from sympy.utilities.iterables import postorder_traversal as old_postorder
          old_postorder(x)
      with warns_deprecated_sympy():
          from sympy.utilities.iterables import interactive_traversal as old_interactive
          # capture() collects the printed output so the test stays quiet.
          capture(lambda: old_interactive(x))