# cache.py (6.0 KB)
  1. """ Caching facility for SymPy """
  2. from importlib import import_module
  3. from typing import Callable
  4. class _cache(list):
  5. """ List of cached functions """
  6. def print_cache(self):
  7. """print cache info"""
  8. for item in self:
  9. name = item.__name__
  10. myfunc = item
  11. while hasattr(myfunc, '__wrapped__'):
  12. if hasattr(myfunc, 'cache_info'):
  13. info = myfunc.cache_info()
  14. break
  15. else:
  16. myfunc = myfunc.__wrapped__
  17. else:
  18. info = None
  19. print(name, info)
  20. def clear_cache(self):
  21. """clear cache content"""
  22. for item in self:
  23. myfunc = item
  24. while hasattr(myfunc, '__wrapped__'):
  25. if hasattr(myfunc, 'cache_clear'):
  26. myfunc.cache_clear()
  27. break
  28. else:
  29. myfunc = myfunc.__wrapped__
  30. # global cache registry:
  31. CACHE = _cache()
  32. # make clear and print methods available
  33. print_cache = CACHE.print_cache
  34. clear_cache = CACHE.clear_cache
  35. from functools import lru_cache, wraps
  36. def __cacheit(maxsize):
  37. """caching decorator.
  38. important: the result of cached function must be *immutable*
  39. Examples
  40. ========
  41. >>> from sympy import cacheit
  42. >>> @cacheit
  43. ... def f(a, b):
  44. ... return a+b
  45. >>> @cacheit
  46. ... def f(a, b): # noqa: F811
  47. ... return [a, b] # <-- WRONG, returns mutable object
  48. to force cacheit to check returned results mutability and consistency,
  49. set environment variable SYMPY_USE_CACHE to 'debug'
  50. """
  51. def func_wrapper(func):
  52. cfunc = lru_cache(maxsize, typed=True)(func)
  53. @wraps(func)
  54. def wrapper(*args, **kwargs):
  55. try:
  56. retval = cfunc(*args, **kwargs)
  57. except TypeError as e:
  58. if not e.args or not e.args[0].startswith('unhashable type:'):
  59. raise
  60. retval = func(*args, **kwargs)
  61. return retval
  62. wrapper.cache_info = cfunc.cache_info
  63. wrapper.cache_clear = cfunc.cache_clear
  64. CACHE.append(wrapper)
  65. return wrapper
  66. return func_wrapper
  67. ########################################
  68. def __cacheit_nocache(func):
  69. return func
  70. def __cacheit_debug(maxsize):
  71. """cacheit + code to check cache consistency"""
  72. def func_wrapper(func):
  73. cfunc = __cacheit(maxsize)(func)
  74. @wraps(func)
  75. def wrapper(*args, **kw_args):
  76. # always call function itself and compare it with cached version
  77. r1 = func(*args, **kw_args)
  78. r2 = cfunc(*args, **kw_args)
  79. # try to see if the result is immutable
  80. #
  81. # this works because:
  82. #
  83. # hash([1,2,3]) -> raise TypeError
  84. # hash({'a':1, 'b':2}) -> raise TypeError
  85. # hash((1,[2,3])) -> raise TypeError
  86. #
  87. # hash((1,2,3)) -> just computes the hash
  88. hash(r1), hash(r2)
  89. # also see if returned values are the same
  90. if r1 != r2:
  91. raise RuntimeError("Returned values are not the same")
  92. return r1
  93. return wrapper
  94. return func_wrapper
  95. def _getenv(key, default=None):
  96. from os import getenv
  97. return getenv(key, default)
  98. # SYMPY_USE_CACHE=yes/no/debug
  99. USE_CACHE = _getenv('SYMPY_USE_CACHE', 'yes').lower()
  100. # SYMPY_CACHE_SIZE=some_integer/None
  101. # special cases :
  102. # SYMPY_CACHE_SIZE=0 -> No caching
  103. # SYMPY_CACHE_SIZE=None -> Unbounded caching
  104. scs = _getenv('SYMPY_CACHE_SIZE', '1000')
  105. if scs.lower() == 'none':
  106. SYMPY_CACHE_SIZE = None
  107. else:
  108. try:
  109. SYMPY_CACHE_SIZE = int(scs)
  110. except ValueError:
  111. raise RuntimeError(
  112. 'SYMPY_CACHE_SIZE must be a valid integer or None. ' + \
  113. 'Got: %s' % SYMPY_CACHE_SIZE)
  114. if USE_CACHE == 'no':
  115. cacheit = __cacheit_nocache
  116. elif USE_CACHE == 'yes':
  117. cacheit = __cacheit(SYMPY_CACHE_SIZE)
  118. elif USE_CACHE == 'debug':
  119. cacheit = __cacheit_debug(SYMPY_CACHE_SIZE) # a lot slower
  120. else:
  121. raise RuntimeError(
  122. 'unrecognized value for SYMPY_USE_CACHE: %s' % USE_CACHE)
  123. def cached_property(func):
  124. '''Decorator to cache property method'''
  125. attrname = '__' + func.__name__
  126. _cached_property_sentinel = object()
  127. def propfunc(self):
  128. val = getattr(self, attrname, _cached_property_sentinel)
  129. if val is _cached_property_sentinel:
  130. val = func(self)
  131. setattr(self, attrname, val)
  132. return val
  133. return property(propfunc)
  134. def lazy_function(module : str, name : str) -> Callable:
  135. """Create a lazy proxy for a function in a module.
  136. The module containing the function is not imported until the function is used.
  137. """
  138. func = None
  139. def _get_function():
  140. nonlocal func
  141. if func is None:
  142. func = getattr(import_module(module), name)
  143. return func
  144. # The metaclass is needed so that help() shows the docstring
  145. class LazyFunctionMeta(type):
  146. @property
  147. def __doc__(self):
  148. docstring = _get_function().__doc__
  149. docstring += f"\n\nNote: this is a {self.__class__.__name__} wrapper of '{module}.{name}'"
  150. return docstring
  151. class LazyFunction(metaclass=LazyFunctionMeta):
  152. def __call__(self, *args, **kwargs):
  153. # inline get of function for performance gh-23832
  154. nonlocal func
  155. if func is None:
  156. func = getattr(import_module(module), name)
  157. return func(*args, **kwargs)
  158. @property
  159. def __doc__(self):
  160. docstring = _get_function().__doc__
  161. docstring += f"\n\nNote: this is a {self.__class__.__name__} wrapper of '{module}.{name}'"
  162. return docstring
  163. def __str__(self):
  164. return _get_function().__str__()
  165. def __repr__(self):
  166. return f"<{__class__.__name__} object at 0x{id(self):x}>: wrapping '{module}.{name}'"
  167. return LazyFunction()