# test_matchpy_connector.py (4.4 KB)
  1. import pickle
  2. from sympy.core.relational import (Eq, Ne)
  3. from sympy.core.singleton import S
  4. from sympy.core.symbol import symbols
  5. from sympy.functions.elementary.miscellaneous import sqrt
  6. from sympy.functions.elementary.trigonometric import (cos, sin)
  7. from sympy.external import import_module
  8. from sympy.testing.pytest import skip
  9. from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer
  10. matchpy = import_module("matchpy")
  11. x, y, z = symbols("x y z")
  12. def _get_first_match(expr, pattern):
  13. from matchpy import ManyToOneMatcher, Pattern
  14. matcher = ManyToOneMatcher()
  15. matcher.add(Pattern(pattern))
  16. return next(iter(matcher.match(expr)))
  17. def test_matchpy_connector():
  18. if matchpy is None:
  19. skip("matchpy not installed")
  20. from multiset import Multiset
  21. from matchpy import Pattern, Substitution
  22. w_ = WildDot("w_")
  23. w__ = WildPlus("w__")
  24. w___ = WildStar("w___")
  25. expr = x + y
  26. pattern = x + w_
  27. p, subst = _get_first_match(expr, pattern)
  28. assert p == Pattern(pattern)
  29. assert subst == Substitution({'w_': y})
  30. expr = x + y + z
  31. pattern = x + w__
  32. p, subst = _get_first_match(expr, pattern)
  33. assert p == Pattern(pattern)
  34. assert subst == Substitution({'w__': Multiset([y, z])})
  35. expr = x + y + z
  36. pattern = x + y + z + w___
  37. p, subst = _get_first_match(expr, pattern)
  38. assert p == Pattern(pattern)
  39. assert subst == Substitution({'w___': Multiset()})
  40. def test_matchpy_optional():
  41. if matchpy is None:
  42. skip("matchpy not installed")
  43. from matchpy import Pattern, Substitution
  44. from matchpy import ManyToOneReplacer, ReplacementRule
  45. p = WildDot("p", optional=1)
  46. q = WildDot("q", optional=0)
  47. pattern = p*x + q
  48. expr1 = 2*x
  49. pa, subst = _get_first_match(expr1, pattern)
  50. assert pa == Pattern(pattern)
  51. assert subst == Substitution({'p': 2, 'q': 0})
  52. expr2 = x + 3
  53. pa, subst = _get_first_match(expr2, pattern)
  54. assert pa == Pattern(pattern)
  55. assert subst == Substitution({'p': 1, 'q': 3})
  56. expr3 = x
  57. pa, subst = _get_first_match(expr3, pattern)
  58. assert pa == Pattern(pattern)
  59. assert subst == Substitution({'p': 1, 'q': 0})
  60. expr4 = x*y + z
  61. pa, subst = _get_first_match(expr4, pattern)
  62. assert pa == Pattern(pattern)
  63. assert subst == Substitution({'p': y, 'q': z})
  64. replacer = ManyToOneReplacer()
  65. replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q)))
  66. assert replacer.replace(expr1) == sin(2)*cos(0)
  67. assert replacer.replace(expr2) == sin(1)*cos(3)
  68. assert replacer.replace(expr3) == sin(1)*cos(0)
  69. assert replacer.replace(expr4) == sin(y)*cos(z)
  70. def test_replacer():
  71. if matchpy is None:
  72. skip("matchpy not installed")
  73. x1_ = WildDot("x1_")
  74. x2_ = WildDot("x2_")
  75. a_ = WildDot("a_", optional=S.One)
  76. b_ = WildDot("b_", optional=S.One)
  77. c_ = WildDot("c_", optional=S.Zero)
  78. replacer = Replacer(common_constraints=[
  79. matchpy.CustomConstraint(lambda a_: not a_.has(x)),
  80. matchpy.CustomConstraint(lambda b_: not b_.has(x)),
  81. matchpy.CustomConstraint(lambda c_: not c_.has(x)),
  82. ])
  83. # Rewrite the equation into implicit form, unless it's already solved:
  84. replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)])
  85. # Simple equation solver for real numbers:
  86. replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_))
  87. disc = b_**2 - 4*a_*c_
  88. replacer.add(
  89. Eq(a_*x**2 + b_*x + c_, 0),
  90. Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)),
  91. conditions_nonfalse=[disc >= 0]
  92. )
  93. replacer.add(
  94. Eq(a_*x**2 + c_, 0),
  95. Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)),
  96. conditions_nonfalse=[-c_*a_ > 0]
  97. )
  98. assert replacer.replace(Eq(3*x, y)) == Eq(x, y/3)
  99. assert replacer.replace(Eq(x**2 + 1, 0)) == Eq(x**2 + 1, 0)
  100. assert replacer.replace(Eq(x**2, 4)) == (Eq(x, 2) | Eq(x, -2))
  101. assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == Eq(x, -2*y)
  102. def test_matchpy_object_pickle():
  103. if matchpy is None:
  104. return
  105. a1 = WildDot("a")
  106. a2 = pickle.loads(pickle.dumps(a1))
  107. assert a1 == a2
  108. a1 = WildDot("a", S(1))
  109. a2 = pickle.loads(pickle.dumps(a1))
  110. assert a1 == a2
  111. a1 = WildPlus("a", S(1))
  112. a2 = pickle.loads(pickle.dumps(a1))
  113. assert a1 == a2
  114. a1 = WildStar("a", S(1))
  115. a2 = pickle.loads(pickle.dumps(a1))
  116. assert a1 == a2