test_traverse.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from sympy.core.basic import Basic
  2. from sympy.core.numbers import Integer
  3. from sympy.core.singleton import S
  4. from sympy.strategies.branch.traverse import top_down, sall
  5. from sympy.strategies.branch.core import do_one, identity
  def inc(x):
      """Branching rule that increments SymPy integers.

      Yields the single alternative ``x + 1`` when *x* is an ``Integer``;
      for any other argument the generator is empty, i.e. the rule offers
      no rewrite at all.
      """
      if not isinstance(x, Integer):
          return
      yield x + 1
  def test_top_down_easy():
      """``top_down(inc)`` increments each integer leaf of a flat ``Basic``."""
      strategy = top_down(inc)
      source = Basic(S(1), S(2))
      # Compared as a set: only the distinct results matter, not their order.
      assert set(strategy(source)) == {Basic(S(2), S(3))}
  def test_top_down_big_tree():
      """``top_down(inc)`` reaches integer leaves at every nesting depth."""
      strategy = top_down(inc)
      source = Basic(S(1), Basic(S(2)), Basic(S(3), Basic(S(4)), S(5)))
      target = Basic(S(2), Basic(S(3)), Basic(S(4), Basic(S(5)), S(6)))
      assert set(strategy(source)) == {target}
  def test_top_down_harder_function():
      """A rule with two alternatives makes ``top_down`` branch into two results."""

      def split5(x):
          # Values equal to 5 branch into 4 and 6; anything else yields nothing.
          if x == 5:
              yield from (x - 1, x + 1)

      strategy = top_down(split5)
      source = Basic(Basic(S(5), S(6)), S(1))
      # Only the 5 changes; the 6 and 1 leaves are identical in both branches.
      branches = {
          Basic(Basic(S(4), S(6)), S(1)),
          Basic(Basic(S(6), S(6)), S(1)),
      }
      assert set(strategy(source)) == branches
  def test_sall():
      """``sall`` applies a rule to each direct argument of an expression."""
      # Every argument is an Integer, so inc rewrites each one.
      flat = Basic(S(1), S(2))
      assert list(sall(inc)(flat)) == [Basic(S(2), S(3))]

      # The nested Basic is not an Integer, so inc yields nothing for it; with
      # do_one(inc, identity) it is presumably passed through by the identity
      # fallback, and its own Integer arguments are left untouched.
      nested = Basic(S(1), S(2), Basic(S(3), S(4)))
      strategy = sall(do_one(inc, identity))
      assert list(strategy(nested)) == [Basic(S(2), S(3), Basic(S(3), S(4)))]