test_rl.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from sympy.core.singleton import S
  2. from sympy.strategies.rl import (
  3. rm_id, glom, flatten, unpack, sort, distribute, subs, rebuild)
  4. from sympy.core.basic import Basic
  5. from sympy.core.add import Add
  6. from sympy.core.mul import Mul
  7. from sympy.core.symbol import symbols
  8. from sympy.abc import x
  9. def test_rm_id():
  10. rmzeros = rm_id(lambda x: x == 0)
  11. assert rmzeros(Basic(S(0), S(1))) == Basic(S(1))
  12. assert rmzeros(Basic(S(0), S(0))) == Basic(S(0))
  13. assert rmzeros(Basic(S(2), S(1))) == Basic(S(2), S(1))
  14. def test_glom():
  15. def key(x):
  16. return x.as_coeff_Mul()[1]
  17. def count(x):
  18. return x.as_coeff_Mul()[0]
  19. def newargs(cnt, arg):
  20. return cnt * arg
  21. rl = glom(key, count, newargs)
  22. result = rl(Add(x, -x, 3 * x, 2, 3, evaluate=False))
  23. expected = Add(3 * x, 5)
  24. assert set(result.args) == set(expected.args)
  25. def test_flatten():
  26. assert flatten(Basic(S(1), S(2), Basic(S(3), S(4)))) == \
  27. Basic(S(1), S(2), S(3), S(4))
  28. def test_unpack():
  29. assert unpack(Basic(S(2))) == 2
  30. assert unpack(Basic(S(2), S(3))) == Basic(S(2), S(3))
  31. def test_sort():
  32. assert sort(str)(Basic(S(3), S(1), S(2))) == Basic(S(1), S(2), S(3))
  33. def test_distribute():
  34. class T1(Basic):
  35. pass
  36. class T2(Basic):
  37. pass
  38. distribute_t12 = distribute(T1, T2)
  39. assert distribute_t12(T1(S(1), S(2), T2(S(3), S(4)), S(5))) == \
  40. T2(T1(S(1), S(2), S(3), S(5)), T1(S(1), S(2), S(4), S(5)))
  41. assert distribute_t12(T1(S(1), S(2), S(3))) == T1(S(1), S(2), S(3))
  42. def test_distribute_add_mul():
  43. x, y = symbols('x, y')
  44. expr = Mul(2, Add(x, y), evaluate=False)
  45. expected = Add(Mul(2, x), Mul(2, y))
  46. distribute_mul = distribute(Mul, Add)
  47. assert distribute_mul(expr) == expected
  48. def test_subs():
  49. rl = subs(1, 2)
  50. assert rl(1) == 2
  51. assert rl(3) == 3
  52. def test_rebuild():
  53. expr = Basic.__new__(Add, S(1), S(2))
  54. assert rebuild(expr) == 3