test_sympy.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from sympy.core.add import Add
  2. from sympy.core.basic import Basic
  3. from sympy.core.containers import Tuple
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import (Symbol, symbols)
  6. from sympy.logic.boolalg import And
  7. from sympy.core.symbol import Str
  8. from sympy.unify.core import Compound, Variable
  9. from sympy.unify.usympy import (deconstruct, construct, unify, is_associative,
  10. is_commutative)
  11. from sympy.abc import x, y, z, n
  def test_deconstruct():
      """deconstruct maps SymPy trees onto Compound/Variable trees."""
      # A plain Basic node becomes a Compound keyed on its type.
      assert deconstruct(Basic(S(1), S(2), S(3))) == Compound(Basic, (1, 2, 3))
      # Plain numbers and symbols not listed as variables pass through.
      assert deconstruct(1) == 1
      assert deconstruct(x) == x
      # Symbols named in ``variables`` are wrapped as unification variables.
      assert deconstruct(x, variables=(x,)) == Variable(x)
      unevaluated_sum = Add(1, x, evaluate=False)
      assert deconstruct(unevaluated_sum) == Compound(Add, (1, x))
      assert deconstruct(unevaluated_sum, variables=(x,)) == \
          Compound(Add, (1, Variable(x)))
  def test_construct():
      """construct rebuilds the SymPy object a Compound describes."""
      compound = Compound(Basic, (S(1), S(2), S(3)))
      assert construct(compound) == Basic(S(1), S(2), S(3))
  def test_nested():
      """deconstruct and construct round-trip a Basic nested in a Basic."""
      tree = Basic(S(1), Basic(S(2)), S(3))
      inner = Compound(Basic, Tuple(2))
      compound = Compound(Basic, (S(1), inner, S(3)))
      assert deconstruct(tree) == compound
      assert construct(compound) == tree
  def test_unify():
      """unify binds each pattern symbol to the matching argument."""
      target = Basic(S(1), S(2), S(3))
      a, b, c = (Symbol(name) for name in 'abc')
      pattern = Basic(a, b, c)
      wanted = [{a: 1, b: 2, c: 3}]
      # Passing ``variables`` positionally or by keyword must agree.
      assert list(unify(target, pattern, {}, (a, b, c))) == wanted
      assert list(unify(target, pattern, variables=(a, b, c))) == wanted
  def test_unify_variables():
      """Only symbols named in ``variables`` are free to bind."""
      matches = unify(Basic(S(1), S(2)), Basic(S(1), x), {}, variables=(x,))
      assert list(matches) == [{x: 2}]
  def test_s_input():
      """A pre-seeded substitution dict constrains the possible matches."""
      target = Basic(S(1), S(2))
      a, b = (Symbol(name) for name in 'ab')
      pattern = Basic(a, b)
      assert list(unify(target, pattern, {}, (a, b))) == [{a: 1, b: 2}]
      # Seeding ``a`` with a value that conflicts with the target yields nothing.
      assert list(unify(target, pattern, {a: 5}, (a, b))) == []
  def iterdicteq(a, b):
      """Return True if iterables *a* and *b* hold the same items in any order.

      Items are compared with ``==`` only. Unification results are dicts,
      which are unhashable, so a set or Counter cannot be used. Duplicates are
      matched one-for-one, so ``({x: 1}, {x: 1})`` is NOT equal to
      ``({x: 1}, {x: 2})``. The previous containment-only check accepted
      that pair.

      Parameters
      ----------
      a, b : iterable
          Any iterables (generators are fine); each is consumed once.

      Returns
      -------
      bool
      """
      a = tuple(a)
      remaining = list(b)
      if len(a) != len(remaining):
          return False
      for item in a:
          # Consume one matching element so that several equal items in *a*
          # cannot all be satisfied by a single element of *b*.
          try:
              remaining.remove(item)
          except ValueError:
              return False
      return True
  def test_unify_commutative():
      """A commutative Add pattern matches every argument permutation."""
      target = Add(1, 2, 3, evaluate=False)
      a, b, c = (Symbol(name) for name in 'abc')
      pattern = Add(a, b, c, evaluate=False)
      result = tuple(unify(target, pattern, {}, (a, b, c)))
      orderings = ((1, 2, 3), (1, 3, 2), (2, 1, 3),
                   (2, 3, 1), (3, 1, 2), (3, 2, 1))
      expected = tuple({a: first, b: second, c: third}
                       for first, second, third in orderings)
      assert iterdicteq(result, expected)
  def test_unify_iter():
      """An associative-commutative Add lets one variable absorb two terms."""
      target = Add(1, 2, 3, evaluate=False)
      a, c = Symbol('a'), Symbol('c')
      pattern = Add(a, c, evaluate=False)
      compound = deconstruct(pattern)
      assert is_associative(compound)
      assert is_commutative(compound)
      result = list(unify(target, pattern, {}, (a, c)))
      # One variable takes a single term; the other takes an unevaluated Add
      # of the remaining two terms, in either order.
      lone_first = [(1, (2, 3)), (1, (3, 2)), (2, (1, 3)),
                    (2, (3, 1)), (3, (1, 2)), (3, (2, 1))]
      lone_second = [((1, 2), 3), ((2, 1), 3), ((1, 3), 2),
                     ((3, 1), 2), ((2, 3), 1), ((3, 2), 1)]
      expected = [{a: lone, c: Add(*pair, evaluate=False)}
                  for lone, pair in lone_first]
      expected += [{a: Add(*pair, evaluate=False), c: lone}
                   for pair, lone in lone_second]
      assert iterdicteq(result, expected)
  def test_hard_match():
      """A variable repeated in the pattern must bind to one consistent value."""
      from sympy.functions.elementary.trigonometric import (cos, sin)
      target = sin(x) + cos(x)**2
      p, q = (Symbol(name) for name in 'pq')
      # q is declared as a variable but never appears in the pattern.
      matches = unify(target, sin(p) + cos(p)**2, {}, (p, q))
      assert list(matches) == [{p: x}]
  def test_matrix():
      """MatrixSymbols unify on name and shape; a shape mismatch gives no match."""
      from sympy.matrices.expressions.matexpr import MatrixSymbol
      X = MatrixSymbol('X', n, n)
      square = MatrixSymbol('Y', 2, 2)
      oblong = MatrixSymbol('Z', 2, 3)
      free = [n, Str('X')]
      assert list(unify(X, square, {}, variables=free)) == \
          [{Str('X'): Str('Y'), n: 2}]
      assert list(unify(X, oblong, {}, variables=free)) == []
  def test_non_frankenAdds():
      """A deconstruct/construct round trip must produce a usable Add.

      The is_commutative property used to fail because of Basic.__new__,
      which broke both is_commutative and str() on the rebuilt object.
      """
      rebuilt = construct(deconstruct(x + y*2))
      # Only the absence of an exception matters; the values are discarded.
      _ = str(rebuilt)
      _ = rebuilt.is_commutative
  def test_FiniteSet_commutivity():
      """FiniteSet unification may group several elements under one variable."""
      from sympy.sets.sets import FiniteSet
      a, b, c, x, y = symbols('a,b,c,x,y')
      wanted = {x: FiniteSet(a, c), y: b}
      matches = tuple(unify(FiniteSet(a, b, c), FiniteSet(x, y),
                            variables=(x, y)))
      assert wanted in matches
  def test_FiniteSet_complex():
      """Basic elements nested inside a FiniteSet unify element by element."""
      from sympy.sets.sets import FiniteSet
      a, b, c, x, y, z = symbols('a,b,c,x,y,z')
      target = FiniteSet(Basic(S(1), x), y, Basic(x, z))
      pattern = FiniteSet(a, Basic(x, b))
      # Basic(x, b) may pair with either Basic element; ``a`` collects the rest.
      expected = (
          {b: 1, a: FiniteSet(y, Basic(x, z))},
          {b: z, a: FiniteSet(y, Basic(S(1), x))},
      )
      assert iterdicteq(unify(target, pattern, variables=(a, b)), expected)
  def test_and():
      """And(x, y) unifies with a conjunction of two relations."""
      expected = ({x: z > 0, y: n < 3},)
      matches = unify((z > 0) & (n < 3), And(x, y), variables=(x, y))
      assert iterdicteq(matches, expected)
  def test_Union():
      """Unions of Intervals unify when an Interval is declared a variable."""
      from sympy.sets.sets import Interval
      left = Interval(0, 1) + Interval(10, 11)
      right = Interval(0, 1) + Interval(12, 13)
      matches = list(unify(left, right, variables=(Interval(12, 13),)))
      # At least one unifier must be produced.
      assert matches
  def test_is_commutative():
      """Sums and products deconstruct as commutative; powers do not."""
      for commutative_expr in (x + y, x*y):
          assert is_commutative(deconstruct(commutative_expr))
      assert not is_commutative(deconstruct(x**y))
  def test_commutative_in_commutative():
      """A commutative pattern nested inside another commutative one matches.

      ``next`` is given a ``None`` default so that "no unifier at all" shows
      up as a plain assertion failure. A bare ``next`` would let it escape as
      a StopIteration error.
      """
      from sympy.abc import a, b, c, d
      from sympy.functions.elementary.trigonometric import (cos, sin)
      eq = sin(3)*sin(4)*sin(5) + 4*cos(3)*cos(4)
      pat = a*cos(b)*cos(c) + d*sin(b)*sin(c)
      match = next(unify(eq, pat, variables=(a, b, c, d)), None)
      assert match, "expected unify to produce at least one match"