# test_tree.py (2.3 KB)

  1. from sympy.strategies.tree import treeapply, greedy, allresults, brute
  2. from functools import partial, reduce
  3. def inc(x):
  4. return x + 1
  5. def dec(x):
  6. return x - 1
  7. def double(x):
  8. return 2 * x
  9. def square(x):
  10. return x**2
  11. def add(*args):
  12. return sum(args)
  13. def mul(*args):
  14. return reduce(lambda a, b: a * b, args, 1)
  15. def test_treeapply():
  16. tree = ([3, 3], [4, 1], 2)
  17. assert treeapply(tree, {list: min, tuple: max}) == 3
  18. assert treeapply(tree, {list: add, tuple: mul}) == 60
  19. def test_treeapply_leaf():
  20. assert treeapply(3, {}, leaf=lambda x: x**2) == 9
  21. tree = ([3, 3], [4, 1], 2)
  22. treep1 = ([4, 4], [5, 2], 3)
  23. assert treeapply(tree, {list: min, tuple: max}, leaf=lambda x: x + 1) == \
  24. treeapply(treep1, {list: min, tuple: max})
  25. def test_treeapply_strategies():
  26. from sympy.strategies import chain, minimize
  27. join = {list: chain, tuple: minimize}
  28. assert treeapply(inc, join) == inc
  29. assert treeapply((inc, dec), join)(5) == minimize(inc, dec)(5)
  30. assert treeapply([inc, dec], join)(5) == chain(inc, dec)(5)
  31. tree = (inc, [dec, double]) # either inc or dec-then-double
  32. assert treeapply(tree, join)(5) == 6
  33. assert treeapply(tree, join)(1) == 0
  34. maximize = partial(minimize, objective=lambda x: -x)
  35. join = {list: chain, tuple: maximize}
  36. fn = treeapply(tree, join)
  37. assert fn(4) == 6 # highest value comes from the dec then double
  38. assert fn(1) == 2 # highest value comes from the inc
  39. def test_greedy():
  40. tree = [inc, (dec, double)] # either inc or dec-then-double
  41. fn = greedy(tree, objective=lambda x: -x)
  42. assert fn(4) == 6 # highest value comes from the dec then double
  43. assert fn(1) == 2 # highest value comes from the inc
  44. tree = [inc, dec, [inc, dec, [(inc, inc), (dec, dec)]]]
  45. lowest = greedy(tree)
  46. assert lowest(10) == 8
  47. highest = greedy(tree, objective=lambda x: -x)
  48. assert highest(10) == 12
  49. def test_allresults():
  50. # square = lambda x: x**2
  51. assert set(allresults(inc)(3)) == {inc(3)}
  52. assert set(allresults([inc, dec])(3)) == {2, 4}
  53. assert set(allresults((inc, dec))(3)) == {3}
  54. assert set(allresults([inc, (dec, double)])(4)) == {5, 6}
  55. def test_brute():
  56. tree = ([inc, dec], square)
  57. fn = brute(tree, lambda x: -x)
  58. assert fn(2) == (2 + 1)**2
  59. assert fn(-2) == (-2 - 1)**2
  60. assert brute(inc)(1) == 2