123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- """ Generic SymPy-Independent Strategies """
- from __future__ import annotations
- from collections.abc import Callable, Mapping
- from typing import TypeVar
- from sys import stdout
- _S = TypeVar('_S')
- _T = TypeVar('_T')
def identity(x: _T) -> _T:
    """ Return the argument unchanged (the rule that does nothing) """
    return x
def exhaust(rule: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Build a rule that re-applies ``rule`` until the result stops changing

    The returned rule calls ``rule`` at least once and keeps feeding each
    result back in until two consecutive values compare equal (``!=`` is
    False); that final value is returned.
    """
    def exhaustive_rl(expr: _T) -> _T:
        previous, current = expr, rule(expr)
        # Stop at the first fixed point: rule(previous) equal to previous.
        while current != previous:
            previous, current = current, rule(current)
        return current

    return exhaustive_rl
def memoize(rule: Callable[[_S], _T]) -> Callable[[_S], _T]:
    """Memoized version of a rule

    Results are stored in a dict keyed by the input expression, so the
    expression must be hashable. Each distinct input is passed to ``rule``
    only once; later calls return the stored result.

    Notes
    =====

    The cache is never cleared and can grow without bound. Prefer
    ``functools.lru_cache`` unless the rule is very expensive to recompute.
    """
    cache: dict[_S, _T] = {}

    def memoized_rl(expr: _S) -> _T:
        if expr not in cache:
            # First time this input is seen: compute and remember it.
            result = cache[expr] = rule(expr)
            return result
        return cache[expr]

    return memoized_rl
def condition(
    cond: Callable[[_T], bool], rule: Callable[[_T], _T]
) -> Callable[[_T], _T]:
    """ Guard a rule with a predicate

    The returned rule applies ``rule`` to the expression when ``cond``
    returns a true value for it, and otherwise returns the expression
    unchanged.
    """
    def conditioned_rl(expr: _T) -> _T:
        return rule(expr) if cond(expr) else expr

    return conditioned_rl
def chain(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """
    Pipe an expression through each rule in turn

    The output of one rule is the input of the next, in the order the rules
    were given. With no rules the expression is returned unchanged.
    """
    def chain_rl(expr: _T) -> _T:
        current = expr
        for step in rules:
            current = step(current)
        return current

    return chain_rl
def debug(rule, file=None):
    """ Print out before and after expressions each time rule is used

    Parameters
    ==========

    rule : callable
        The rule to wrap. Its first positional argument is treated as the
        input expression.
    file : object with a ``write`` method, optional
        Destination of the log. Defaults to the ``stdout`` object imported
        at module level.

    Returns
    =======

    callable
        A rule that forwards all arguments to ``rule`` and returns its
        result. Whenever the result differs from the input expression, the
        rule name and the before/after values are written to ``file``.
    """
    if file is None:
        file = stdout

    def debug_rl(*args, **kwargs):
        expr = args[0]
        result = rule(*args, **kwargs)
        if result != expr:
            # Not every callable has ``__name__`` (e.g. functools.partial
            # objects or instances defining ``__call__``); fall back to the
            # repr so logging never breaks the rule being debugged.
            name = getattr(rule, '__name__', None)
            if name is None:
                name = repr(rule)
            file.write("Rule: %s\n" % name)
            file.write("In: %s\nOut: %s\n\n" % (expr, result))
        return result

    return debug_rl
def null_safe(rule: Callable[[_T], _T | None]) -> Callable[[_T], _T]:
    """ Treat a ``None`` result from rule as "no change"

    The returned rule gives back the original expression when ``rule``
    returns ``None`` and the rule's result in every other case (including
    other falsy values such as ``0``).
    """
    def null_safe_rl(expr: _T) -> _T:
        outcome = rule(expr)
        return expr if outcome is None else outcome

    return null_safe_rl
def tryit(rule: Callable[[_T], _T], exception) -> Callable[[_T], _T]:
    """ Fall back to the original expr when rule raises exception

    ``exception`` is used directly in an ``except`` clause, so it may be an
    exception class or a tuple of classes. Any other exception raised by
    ``rule`` propagates to the caller.
    """
    def try_rl(expr: _T) -> _T:
        try:
            outcome = rule(expr)
        except exception:
            return expr
        return outcome

    return try_rl
def do_one(*rules: Callable[[_T], _T]) -> Callable[[_T], _T]:
    """ Apply the first rule that changes the expression

    Rules are tried in the order given, each on the original expression.
    The first result that compares unequal to the input is returned and the
    remaining rules are not called. If no rule changes the expression, it
    is returned as is.
    """
    def do_one_rl(expr: _T) -> _T:
        for attempt in rules:
            outcome = attempt(expr)
            if outcome != expr:
                return outcome
        return expr

    return do_one_rl
def switch(
    key: Callable[[_S], _T],
    ruledict: Mapping[_T, Callable[[_S], _S]]
) -> Callable[[_S], _S]:
    """ Select a rule based on the result of key called on the expression

    ``key(expr)`` is looked up in ``ruledict`` with ``get``; the rule found
    there is applied to the expression. When the lookup finds nothing,
    ``identity`` is used, so the expression is returned unchanged.
    """
    def switch_rl(expr: _S) -> _S:
        selector = key(expr)
        chosen = ruledict.get(selector, identity)
        return chosen(expr)

    return switch_rl
# XXX The default ``objective`` of ``minimize`` is deliberately left
# untyped: ``min`` needs the key result to be a SupportsRichComparison
# type, which the generic ``identity`` signature cannot express.
def _identity(x):
    """ Untyped pass-through used as the default objective of minimize """
    return x


def minimize(
    *rules: Callable[[_S], _T],
    objective=_identity
) -> Callable[[_S], _T]:
    """ Select result of rules that minimizes objective

    Every rule is applied to the expression first; the results are then
    compared with ``min`` using ``objective`` as the key. With the default
    objective the results themselves are compared.

    >>> from sympy.strategies import minimize
    >>> inc = lambda x: x + 1
    >>> dec = lambda x: x - 1
    >>> rl = minimize(inc, dec)
    >>> rl(4)
    3

    >>> rl = minimize(inc, dec, objective=lambda x: -x)  # maximize
    >>> rl(4)
    5
    """
    def minrule(expr: _S) -> _T:
        candidates = [rule(expr) for rule in rules]
        return min(candidates, key=objective)

    return minrule
|