# test_unify.py (3.0 KB)

  1. from sympy.unify.core import Compound, Variable, CondVariable, allcombinations
  2. from sympy.unify import core
  3. a,b,c = 'a', 'b', 'c'
  4. w,x,y,z = map(Variable, 'wxyz')
  5. C = Compound
  6. def is_associative(x):
  7. return isinstance(x, Compound) and (x.op in ('Add', 'Mul', 'CAdd', 'CMul'))
  8. def is_commutative(x):
  9. return isinstance(x, Compound) and (x.op in ('CAdd', 'CMul'))
  10. def unify(a, b, s={}):
  11. return core.unify(a, b, s=s, is_associative=is_associative,
  12. is_commutative=is_commutative)
  13. def test_basic():
  14. assert list(unify(a, x, {})) == [{x: a}]
  15. assert list(unify(a, x, {x: 10})) == []
  16. assert list(unify(1, x, {})) == [{x: 1}]
  17. assert list(unify(a, a, {})) == [{}]
  18. assert list(unify((w, x), (y, z), {})) == [{w: y, x: z}]
  19. assert list(unify(x, (a, b), {})) == [{x: (a, b)}]
  20. assert list(unify((a, b), (x, x), {})) == []
  21. assert list(unify((y, z), (x, x), {}))!= []
  22. assert list(unify((a, (b, c)), (a, (x, y)), {})) == [{x: b, y: c}]
  23. def test_ops():
  24. assert list(unify(C('Add', (a,b,c)), C('Add', (a,x,y)), {})) == \
  25. [{x:b, y:c}]
  26. assert list(unify(C('Add', (C('Mul', (1,2)), b,c)), C('Add', (x,y,c)), {})) == \
  27. [{x: C('Mul', (1,2)), y:b}]
  28. def test_associative():
  29. c1 = C('Add', (1,2,3))
  30. c2 = C('Add', (x,y))
  31. assert tuple(unify(c1, c2, {})) == ({x: 1, y: C('Add', (2, 3))},
  32. {x: C('Add', (1, 2)), y: 3})
  33. def test_commutative():
  34. c1 = C('CAdd', (1,2,3))
  35. c2 = C('CAdd', (x,y))
  36. result = list(unify(c1, c2, {}))
  37. assert {x: 1, y: C('CAdd', (2, 3))} in result
  38. assert ({x: 2, y: C('CAdd', (1, 3))} in result or
  39. {x: 2, y: C('CAdd', (3, 1))} in result)
  40. def _test_combinations_assoc():
  41. assert set(allcombinations((1,2,3), (a,b), True)) == \
  42. {(((1, 2), (3,)), (a, b)), (((1,), (2, 3)), (a, b))}
  43. def _test_combinations_comm():
  44. assert set(allcombinations((1,2,3), (a,b), None)) == \
  45. {(((1,), (2, 3)), ('a', 'b')), (((2,), (3, 1)), ('a', 'b')),
  46. (((3,), (1, 2)), ('a', 'b')), (((1, 2), (3,)), ('a', 'b')),
  47. (((2, 3), (1,)), ('a', 'b')), (((3, 1), (2,)), ('a', 'b'))}
  48. def test_allcombinations():
  49. assert set(allcombinations((1,2), (1,2), 'commutative')) ==\
  50. {(((1,),(2,)), ((1,),(2,))), (((1,),(2,)), ((2,),(1,)))}
  51. def test_commutativity():
  52. c1 = Compound('CAdd', (a, b))
  53. c2 = Compound('CAdd', (x, y))
  54. assert is_commutative(c1) and is_commutative(c2)
  55. assert len(list(unify(c1, c2, {}))) == 2
  56. def test_CondVariable():
  57. expr = C('CAdd', (1, 2))
  58. x = Variable('x')
  59. y = CondVariable('y', lambda a: a % 2 == 0)
  60. z = CondVariable('z', lambda a: a > 3)
  61. pattern = C('CAdd', (x, y))
  62. assert list(unify(expr, pattern, {})) == \
  63. [{x: 1, y: 2}]
  64. z = CondVariable('z', lambda a: a > 3)
  65. pattern = C('CAdd', (z, y))
  66. assert list(unify(expr, pattern, {})) == []
  67. def test_defaultdict():
  68. assert next(unify(Variable('x'), 'foo')) == {Variable('x'): 'foo'}