utilities.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from contextlib import contextmanager
  2. from threading import local
  3. from sympy.core.function import expand_mul
  4. class DotProdSimpState(local):
  5. def __init__(self):
  6. self.state = None
  7. _dotprodsimp_state = DotProdSimpState()
  8. @contextmanager
  9. def dotprodsimp(x):
  10. old = _dotprodsimp_state.state
  11. try:
  12. _dotprodsimp_state.state = x
  13. yield
  14. finally:
  15. _dotprodsimp_state.state = old
  16. def _dotprodsimp(expr, withsimp=False):
  17. """Wrapper for simplify.dotprodsimp to avoid circular imports."""
  18. from sympy.simplify.simplify import dotprodsimp as dps
  19. return dps(expr, withsimp=withsimp)
  20. def _get_intermediate_simp(deffunc=lambda x: x, offfunc=lambda x: x,
  21. onfunc=_dotprodsimp, dotprodsimp=None):
  22. """Support function for controlling intermediate simplification. Returns a
  23. simplification function according to the global setting of dotprodsimp
  24. operation.
  25. ``deffunc`` - Function to be used by default.
  26. ``offfunc`` - Function to be used if dotprodsimp has been turned off.
  27. ``onfunc`` - Function to be used if dotprodsimp has been turned on.
  28. ``dotprodsimp`` - True, False or None. Will be overridden by global
  29. _dotprodsimp_state.state if that is not None.
  30. """
  31. if dotprodsimp is False or _dotprodsimp_state.state is False:
  32. return offfunc
  33. if dotprodsimp is True or _dotprodsimp_state.state is True:
  34. return onfunc
  35. return deffunc # None, None
  36. def _get_intermediate_simp_bool(default=False, dotprodsimp=None):
  37. """Same as ``_get_intermediate_simp`` but returns bools instead of functions
  38. by default."""
  39. return _get_intermediate_simp(default, False, True, dotprodsimp)
  40. def _iszero(x):
  41. """Returns True if x is zero."""
  42. return getattr(x, 'is_zero', None)
  43. def _is_zero_after_expand_mul(x):
  44. """Tests by expand_mul only, suitable for polynomials and rational
  45. functions."""
  46. return expand_mul(x) == 0
  47. def _simplify(expr):
  48. """ Wrapper to avoid circular imports. """
  49. from sympy.simplify.simplify import simplify
  50. return simplify(expr)