core.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. """ Generic SymPy-Independent Strategies """
  2. from __future__ import annotations
  3. from collections.abc import Callable, Mapping
  4. from typing import TypeVar
  5. from sys import stdout
  6. _S = TypeVar('_S')
  7. _T = TypeVar('_T')
  8. def identity(x: _T) -> _T:
  9. return x
  10. def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
  11. """ Apply a rule repeatedly until it has no effect """
  12. def exhaustive_rl(expr: _T) -> _T:
  13. new, old = rule(expr), expr
  14. while new != old:
  15. new, old = rule(new), new
  16. return new
  17. return exhaustive_rl
  18. def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
  19. """Memoized version of a rule
  20. Notes
  21. =====
  22. This cache can grow infinitely, so it is not recommended to use this
  23. than ``functools.lru_cache`` unless you need very heavy computation.
  24. """
  25. cache: dict[_S, _T] = {}
  26. def memoized_rl(expr: _S) -> _T:
  27. if expr in cache:
  28. return cache[expr]
  29. else:
  30. result = rule(expr)
  31. cache[expr] = result
  32. return result
  33. return memoized_rl
  34. def condition(
  35. cond: Callable[[_T], bool], rule: Callable[[_T], _T]
  36. ) -> Callable[[_T], _T]:
  37. """ Only apply rule if condition is true """
  38. def conditioned_rl(expr: _T) -> _T:
  39. if cond(expr):
  40. return rule(expr)
  41. return expr
  42. return conditioned_rl
  43. def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
  44. """
  45. Compose a sequence of rules so that they apply to the expr sequentially
  46. """
  47. def chain_rl(expr: _T) -> _T:
  48. for rule in rules:
  49. expr = rule(expr)
  50. return expr
  51. return chain_rl
  52. def debug(rule, file=None):
  53. """ Print out before and after expressions each time rule is used """
  54. if file is None:
  55. file = stdout
  56. def debug_rl(*args, **kwargs):
  57. expr = args[0]
  58. result = rule(*args, **kwargs)
  59. if result != expr:
  60. file.write("Rule: %s\n" % rule.__name__)
  61. file.write("In: %s\nOut: %s\n\n" % (expr, result))
  62. return result
  63. return debug_rl
  64. def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
  65. """ Return original expr if rule returns None """
  66. def null_safe_rl(expr: _T) -> _T:
  67. result = rule(expr)
  68. if result is None:
  69. return expr
  70. return result
  71. return null_safe_rl
  72. def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
  73. """ Return original expr if rule raises exception """
  74. def try_rl(expr: _T) -> _T:
  75. try:
  76. return rule(expr)
  77. except exception:
  78. return expr
  79. return try_rl
  80. def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
  81. """ Try each of the rules until one works. Then stop. """
  82. def do_one_rl(expr: _T) -> _T:
  83. for rl in rules:
  84. result = rl(expr)
  85. if result != expr:
  86. return result
  87. return expr
  88. return do_one_rl
  89. def switch(
  90. key: Callable[[_S], _T],
  91. ruledict: Mapping[_T, Callable[[_S], _S]]
  92. ) -> Callable[[_S], _S]:
  93. """ Select a rule based on the result of key called on the function """
  94. def switch_rl(expr: _S) -> _S:
  95. rl = ruledict.get(key(expr), identity)
  96. return rl(expr)
  97. return switch_rl
  98. # XXX Untyped default argument for minimize function
  99. # where python requires SupportsRichComparison type
  100. def _identity(x):
  101. return x
  102. def minimize(
  103. *rules: Callable[[_S], _T],
  104. objective=_identity
  105. ) -> Callable[[_S], _T]:
  106. """ Select result of rules that minimizes objective
  107. >>> from sympy.strategies import minimize
  108. >>> inc = lambda x: x + 1
  109. >>> dec = lambda x: x - 1
  110. >>> rl = minimize(inc, dec)
  111. >>> rl(4)
  112. 3
  113. >>> rl = minimize(inc, dec, objective=lambda x: -x) # maximize
  114. >>> rl(4)
  115. 5
  116. """
  117. def minrule(expr: _S) -> _T:
  118. return min([rule(expr) for rule in rules], key=objective)
  119. return minrule