matchpy_connector.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. """
  2. The objects in this module allow the usage of the MatchPy pattern matching
  3. library on SymPy expressions.
  4. """
import re
from typing import List, Callable, Optional

from sympy.core.sympify import _sympify
from sympy.external import import_module
from sympy.functions import (log, sin, cos, tan, cot, csc, sec, erf, gamma, uppergamma)
from sympy.functions.elementary.hyperbolic import acosh, asinh, atanh, acoth, acsch, asech, cosh, sinh, tanh, coth, sech, csch
from sympy.functions.elementary.trigonometric import atan, acsc, asin, acot, acos, asec
from sympy.functions.special.error_functions import fresnelc, fresnels, erfc, erfi, Ei
from sympy.core.add import Add
from sympy.core.basic import Basic
from sympy.core.expr import Expr
from sympy.core.mul import Mul
from sympy.core.power import Pow
from sympy.core.relational import (Equality, Unequality)
from sympy.core.symbol import Symbol
from sympy.functions.elementary.exponential import exp
from sympy.integrals.integrals import Integral
from sympy.printing.repr import srepr
from sympy.utilities.decorator import doctest_depends_on
  24. matchpy = import_module("matchpy")
  25. if matchpy:
  26. from matchpy import Operation, CommutativeOperation, AssociativeOperation, OneIdentityOperation
  27. from matchpy.expressions.functions import op_iter, create_operation_expression, op_len
  28. Operation.register(Integral)
  29. Operation.register(Pow)
  30. OneIdentityOperation.register(Pow)
  31. Operation.register(Add)
  32. OneIdentityOperation.register(Add)
  33. CommutativeOperation.register(Add)
  34. AssociativeOperation.register(Add)
  35. Operation.register(Mul)
  36. OneIdentityOperation.register(Mul)
  37. CommutativeOperation.register(Mul)
  38. AssociativeOperation.register(Mul)
  39. Operation.register(Equality)
  40. CommutativeOperation.register(Equality)
  41. Operation.register(Unequality)
  42. CommutativeOperation.register(Unequality)
  43. Operation.register(exp)
  44. Operation.register(log)
  45. Operation.register(gamma)
  46. Operation.register(uppergamma)
  47. Operation.register(fresnels)
  48. Operation.register(fresnelc)
  49. Operation.register(erf)
  50. Operation.register(Ei)
  51. Operation.register(erfc)
  52. Operation.register(erfi)
  53. Operation.register(sin)
  54. Operation.register(cos)
  55. Operation.register(tan)
  56. Operation.register(cot)
  57. Operation.register(csc)
  58. Operation.register(sec)
  59. Operation.register(sinh)
  60. Operation.register(cosh)
  61. Operation.register(tanh)
  62. Operation.register(coth)
  63. Operation.register(csch)
  64. Operation.register(sech)
  65. Operation.register(asin)
  66. Operation.register(acos)
  67. Operation.register(atan)
  68. Operation.register(acot)
  69. Operation.register(acsc)
  70. Operation.register(asec)
  71. Operation.register(asinh)
  72. Operation.register(acosh)
  73. Operation.register(atanh)
  74. Operation.register(acoth)
  75. Operation.register(acsch)
  76. Operation.register(asech)
  77. @op_iter.register(Integral) # type: ignore
  78. def _(operation):
  79. return iter((operation._args[0],) + operation._args[1])
  80. @op_iter.register(Basic) # type: ignore
  81. def _(operation):
  82. return iter(operation._args)
  83. @op_len.register(Integral) # type: ignore
  84. def _(operation):
  85. return 1 + len(operation._args[1])
  86. @op_len.register(Basic) # type: ignore
  87. def _(operation):
  88. return len(operation._args)
  89. @create_operation_expression.register(Basic)
  90. def sympy_op_factory(old_operation, new_operands, variable_name=True):
  91. return type(old_operation)(*new_operands)
  92. if matchpy:
  93. from matchpy import Wildcard
  94. else:
  95. class Wildcard: # type: ignore
  96. def __init__(self, min_length, fixed_size, variable_name, optional):
  97. self.min_count = min_length
  98. self.fixed_size = fixed_size
  99. self.variable_name = variable_name
  100. self.optional = optional
  101. @doctest_depends_on(modules=('matchpy',))
  102. class _WildAbstract(Wildcard, Symbol):
  103. min_length: int # abstract field required in subclasses
  104. fixed_size: bool # abstract field required in subclasses
  105. def __init__(self, variable_name=None, optional=None, **assumptions):
  106. min_length = self.min_length
  107. fixed_size = self.fixed_size
  108. if optional is not None:
  109. optional = _sympify(optional)
  110. Wildcard.__init__(self, min_length, fixed_size, str(variable_name), optional)
  111. def __getstate__(self):
  112. return {
  113. "min_length": self.min_length,
  114. "fixed_size": self.fixed_size,
  115. "min_count": self.min_count,
  116. "variable_name": self.variable_name,
  117. "optional": self.optional,
  118. }
  119. def __new__(cls, variable_name=None, optional=None, **assumptions):
  120. cls._sanitize(assumptions, cls)
  121. return _WildAbstract.__xnew__(cls, variable_name, optional, **assumptions)
  122. def __getnewargs__(self):
  123. return self.variable_name, self.optional
  124. @staticmethod
  125. def __xnew__(cls, variable_name=None, optional=None, **assumptions):
  126. obj = Symbol.__xnew__(cls, variable_name, **assumptions)
  127. return obj
  128. def _hashable_content(self):
  129. if self.optional:
  130. return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name, self.optional)
  131. else:
  132. return super()._hashable_content() + (self.min_count, self.fixed_size, self.variable_name)
  133. def __copy__(self) -> '_WildAbstract':
  134. return type(self)(variable_name=self.variable_name, optional=self.optional)
  135. def __repr__(self):
  136. return str(self)
  137. def __str__(self):
  138. return self.name
  139. @doctest_depends_on(modules=('matchpy',))
  140. class WildDot(_WildAbstract):
  141. min_length = 1
  142. fixed_size = True
  143. @doctest_depends_on(modules=('matchpy',))
  144. class WildPlus(_WildAbstract):
  145. min_length = 1
  146. fixed_size = False
  147. @doctest_depends_on(modules=('matchpy',))
  148. class WildStar(_WildAbstract):
  149. min_length = 0
  150. fixed_size = False
  151. def _get_srepr(expr):
  152. s = srepr(expr)
  153. s = re.sub(r"WildDot\('(\w+)'\)", r"\1", s)
  154. s = re.sub(r"WildPlus\('(\w+)'\)", r"*\1", s)
  155. s = re.sub(r"WildStar\('(\w+)'\)", r"*\1", s)
  156. return s
  157. @doctest_depends_on(modules=('matchpy',))
  158. class Replacer:
  159. """
  160. Replacer object to perform multiple pattern matching and subexpression
  161. replacements in SymPy expressions.
  162. Examples
  163. ========
  164. Example to construct a simple first degree equation solver:
  165. >>> from sympy.utilities.matchpy_connector import WildDot, Replacer
  166. >>> from sympy import Equality, Symbol
  167. >>> x = Symbol("x")
  168. >>> a_ = WildDot("a_", optional=1)
  169. >>> b_ = WildDot("b_", optional=0)
  170. The lines above have defined two wildcards, ``a_`` and ``b_``, the
  171. coefficients of the equation `a x + b = 0`. The optional values specified
  172. indicate which expression to return in case no match is found, they are
  173. necessary in equations like `a x = 0` and `x + b = 0`.
  174. Create two constraints to make sure that ``a_`` and ``b_`` will not match
  175. any expression containing ``x``:
  176. >>> from matchpy import CustomConstraint
  177. >>> free_x_a = CustomConstraint(lambda a_: not a_.has(x))
  178. >>> free_x_b = CustomConstraint(lambda b_: not b_.has(x))
  179. Now create the rule replacer with the constraints:
  180. >>> replacer = Replacer(common_constraints=[free_x_a, free_x_b])
  181. Add the matching rule:
  182. >>> replacer.add(Equality(a_*x + b_, 0), -b_/a_)
  183. Let's try it:
  184. >>> replacer.replace(Equality(3*x + 4, 0))
  185. -4/3
  186. Notice that it will not match equations expressed with other patterns:
  187. >>> eq = Equality(3*x, 4)
  188. >>> replacer.replace(eq)
  189. Eq(3*x, 4)
  190. In order to extend the matching patterns, define another one (we also need
  191. to clear the cache, because the previous result has already been memorized
  192. and the pattern matcher will not iterate again if given the same expression)
  193. >>> replacer.add(Equality(a_*x, b_), b_/a_)
  194. >>> replacer._replacer.matcher.clear()
  195. >>> replacer.replace(eq)
  196. 4/3
  197. """
  198. def __init__(self, common_constraints: list = []):
  199. self._replacer = matchpy.ManyToOneReplacer()
  200. self._common_constraint = common_constraints
  201. def _get_lambda(self, lambda_str: str) -> Callable[..., Expr]:
  202. exec("from sympy import *")
  203. return eval(lambda_str, locals())
  204. def _get_custom_constraint(self, constraint_expr: Expr, condition_template: str) -> Callable[..., Expr]:
  205. wilds = [x.name for x in constraint_expr.atoms(_WildAbstract)]
  206. lambdaargs = ', '.join(wilds)
  207. fullexpr = _get_srepr(constraint_expr)
  208. condition = condition_template.format(fullexpr)
  209. return matchpy.CustomConstraint(
  210. self._get_lambda(f"lambda {lambdaargs}: ({condition})"))
  211. def _get_custom_constraint_nonfalse(self, constraint_expr: Expr) -> Callable[..., Expr]:
  212. return self._get_custom_constraint(constraint_expr, "({}) != False")
  213. def _get_custom_constraint_true(self, constraint_expr: Expr) -> Callable[..., Expr]:
  214. return self._get_custom_constraint(constraint_expr, "({}) == True")
  215. def add(self, expr: Expr, result: Expr, conditions_true: List[Expr] = [], conditions_nonfalse: List[Expr] = []) -> None:
  216. expr = _sympify(expr)
  217. result = _sympify(result)
  218. lambda_str = f"lambda {', '.join((x.name for x in expr.atoms(_WildAbstract)))}: {_get_srepr(result)}"
  219. lambda_expr = self._get_lambda(lambda_str)
  220. constraints = self._common_constraint[:]
  221. constraint_conditions_true = [
  222. self._get_custom_constraint_true(cond) for cond in conditions_true]
  223. constraint_conditions_nonfalse = [
  224. self._get_custom_constraint_nonfalse(cond) for cond in conditions_nonfalse]
  225. constraints.extend(constraint_conditions_true)
  226. constraints.extend(constraint_conditions_nonfalse)
  227. self._replacer.add(
  228. matchpy.ReplacementRule(matchpy.Pattern(expr, *constraints), lambda_expr))
  229. def replace(self, expr: Expr) -> Expr:
  230. return self._replacer.replace(expr)