expressions.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. """
  2. Expressions
  3. -----------
  4. Offer fast expression evaluation through numexpr
  5. """
  6. from __future__ import annotations
  7. import operator
  8. import warnings
  9. import numpy as np
  10. from pandas._config import get_option
  11. from pandas._typing import FuncType
  12. from pandas.util._exceptions import find_stack_level
  13. from pandas.core.computation.check import NUMEXPR_INSTALLED
  14. from pandas.core.ops import roperator
  15. if NUMEXPR_INSTALLED:
  16. import numexpr as ne
# Test-mode switch: when truthy, _evaluate_standard / _evaluate_numexpr report
# whether numexpr was used via _store_test_result (see set_test_mode).
_TEST_MODE: bool | None = None
# Flags recorded while in test mode; drained by get_test_result().
_TEST_RESULT: list[bool] = []
# Whether numexpr should be used. Starts as NUMEXPR_INSTALLED; set_use_numexpr
# only changes it when numexpr is installed.
USE_NUMEXPR = NUMEXPR_INSTALLED
# Dispatch targets rebound by set_use_numexpr() to either the numexpr or the
# standard implementation; None until that function has run.
_evaluate: FuncType | None = None
_where: FuncType | None = None
# the set of dtypes that we will allow pass to numexpr, keyed by the
# ``dtype_check`` argument of _can_use_numexpr ("evaluate" or "where")
_ALLOWED_DTYPES = {
    "evaluate": {"int64", "int32", "float64", "float32", "bool"},
    "where": {"int64", "float64", "bool"},
}
# numexpr is only used when the left operand has strictly MORE than this many
# elements (``a.size > _MIN_ELEMENTS`` in _can_use_numexpr); below that the
# numexpr call would only add overhead.
_MIN_ELEMENTS = 1_000_000
  29. def set_use_numexpr(v: bool = True) -> None:
  30. # set/unset to use numexpr
  31. global USE_NUMEXPR
  32. if NUMEXPR_INSTALLED:
  33. USE_NUMEXPR = v
  34. # choose what we are going to do
  35. global _evaluate, _where
  36. _evaluate = _evaluate_numexpr if USE_NUMEXPR else _evaluate_standard
  37. _where = _where_numexpr if USE_NUMEXPR else _where_standard
  38. def set_numexpr_threads(n=None) -> None:
  39. # if we are using numexpr, set the threads to n
  40. # otherwise reset
  41. if NUMEXPR_INSTALLED and USE_NUMEXPR:
  42. if n is None:
  43. n = ne.detect_number_of_cores()
  44. ne.set_num_threads(n)
  45. def _evaluate_standard(op, op_str, a, b):
  46. """
  47. Standard evaluation.
  48. """
  49. if _TEST_MODE:
  50. _store_test_result(False)
  51. return op(a, b)
  52. def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
  53. """return a boolean if we WILL be using numexpr"""
  54. if op_str is not None:
  55. # required min elements (otherwise we are adding overhead)
  56. if a.size > _MIN_ELEMENTS:
  57. # check for dtype compatibility
  58. dtypes: set[str] = set()
  59. for o in [a, b]:
  60. # ndarray and Series Case
  61. if hasattr(o, "dtype"):
  62. dtypes |= {o.dtype.name}
  63. # allowed are a superset
  64. if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
  65. return True
  66. return False
  67. def _evaluate_numexpr(op, op_str, a, b):
  68. result = None
  69. if _can_use_numexpr(op, op_str, a, b, "evaluate"):
  70. is_reversed = op.__name__.strip("_").startswith("r")
  71. if is_reversed:
  72. # we were originally called by a reversed op method
  73. a, b = b, a
  74. a_value = a
  75. b_value = b
  76. try:
  77. result = ne.evaluate(
  78. f"a_value {op_str} b_value",
  79. local_dict={"a_value": a_value, "b_value": b_value},
  80. casting="safe",
  81. )
  82. except TypeError:
  83. # numexpr raises eg for array ** array with integers
  84. # (https://github.com/pydata/numexpr/issues/379)
  85. pass
  86. except NotImplementedError:
  87. if _bool_arith_fallback(op_str, a, b):
  88. pass
  89. else:
  90. raise
  91. if is_reversed:
  92. # reverse order to original for fallback
  93. a, b = b, a
  94. if _TEST_MODE:
  95. _store_test_result(result is not None)
  96. if result is None:
  97. result = _evaluate_standard(op, op_str, a, b)
  98. return result
# Maps each supported binary operator callable to the operator symbol used in
# the numexpr expression string. A value of None means there is no numexpr
# spelling, and evaluate() always uses _evaluate_standard for that op.
# Reversed variants (roperator.r*) share their forward op's symbol;
# _evaluate_numexpr swaps the operands for them.
_op_str_mapping = {
    operator.add: "+",
    roperator.radd: "+",
    operator.mul: "*",
    roperator.rmul: "*",
    operator.sub: "-",
    roperator.rsub: "-",
    operator.truediv: "/",
    roperator.rtruediv: "/",
    # floordiv not supported by numexpr 2.x
    operator.floordiv: None,
    roperator.rfloordiv: None,
    # we require Python semantics for mod of negative for backwards compatibility
    # see https://github.com/pydata/numexpr/issues/365
    # so sticking with unaccelerated for now GH#36552
    operator.mod: None,
    roperator.rmod: None,
    operator.pow: "**",
    roperator.rpow: "**",
    operator.eq: "==",
    operator.ne: "!=",
    operator.le: "<=",
    operator.lt: "<",
    operator.ge: ">=",
    operator.gt: ">",
    operator.and_: "&",
    roperator.rand_: "&",
    operator.or_: "|",
    roperator.ror_: "|",
    operator.xor: "^",
    roperator.rxor: "^",
    divmod: None,
    roperator.rdivmod: None,
}
  133. def _where_standard(cond, a, b):
  134. # Caller is responsible for extracting ndarray if necessary
  135. return np.where(cond, a, b)
  136. def _where_numexpr(cond, a, b):
  137. # Caller is responsible for extracting ndarray if necessary
  138. result = None
  139. if _can_use_numexpr(None, "where", a, b, "where"):
  140. result = ne.evaluate(
  141. "where(cond_value, a_value, b_value)",
  142. local_dict={"cond_value": cond, "a_value": a, "b_value": b},
  143. casting="safe",
  144. )
  145. if result is None:
  146. result = _where_standard(cond, a, b)
  147. return result
# turn myself on: at import time, bind USE_NUMEXPR and the _evaluate/_where
# dispatchers according to the "compute.use_numexpr" option.
set_use_numexpr(get_option("compute.use_numexpr"))
  150. def _has_bool_dtype(x):
  151. try:
  152. return x.dtype == bool
  153. except AttributeError:
  154. return isinstance(x, (bool, np.bool_))
# Arithmetic operator symbols that numexpr does not support for bool operands,
# mapped to the logical operator suggested in the fallback warning issued by
# _bool_arith_fallback.
_BOOL_OP_UNSUPPORTED = {"+": "|", "*": "&", "-": "^"}
  156. def _bool_arith_fallback(op_str, a, b) -> bool:
  157. """
  158. Check if we should fallback to the python `_evaluate_standard` in case
  159. of an unsupported operation by numexpr, which is the case for some
  160. boolean ops.
  161. """
  162. if _has_bool_dtype(a) and _has_bool_dtype(b):
  163. if op_str in _BOOL_OP_UNSUPPORTED:
  164. warnings.warn(
  165. f"evaluating in Python space because the {repr(op_str)} "
  166. "operator is not supported by numexpr for the bool dtype, "
  167. f"use {repr(_BOOL_OP_UNSUPPORTED[op_str])} instead.",
  168. stacklevel=find_stack_level(),
  169. )
  170. return True
  171. return False
  172. def evaluate(op, a, b, use_numexpr: bool = True):
  173. """
  174. Evaluate and return the expression of the op on a and b.
  175. Parameters
  176. ----------
  177. op : the actual operand
  178. a : left operand
  179. b : right operand
  180. use_numexpr : bool, default True
  181. Whether to try to use numexpr.
  182. """
  183. op_str = _op_str_mapping[op]
  184. if op_str is not None:
  185. if use_numexpr:
  186. # error: "None" not callable
  187. return _evaluate(op, op_str, a, b) # type: ignore[misc]
  188. return _evaluate_standard(op, op_str, a, b)
  189. def where(cond, a, b, use_numexpr: bool = True):
  190. """
  191. Evaluate the where condition cond on a and b.
  192. Parameters
  193. ----------
  194. cond : np.ndarray[bool]
  195. a : return if cond is True
  196. b : return if cond is False
  197. use_numexpr : bool, default True
  198. Whether to try to use numexpr.
  199. """
  200. assert _where is not None
  201. return _where(cond, a, b) if use_numexpr else _where_standard(cond, a, b)
  202. def set_test_mode(v: bool = True) -> None:
  203. """
  204. Keeps track of whether numexpr was used.
  205. Stores an additional ``True`` for every successful use of evaluate with
  206. numexpr since the last ``get_test_result``.
  207. """
  208. global _TEST_MODE, _TEST_RESULT
  209. _TEST_MODE = v
  210. _TEST_RESULT = []
  211. def _store_test_result(used_numexpr: bool) -> None:
  212. if used_numexpr:
  213. _TEST_RESULT.append(used_numexpr)
  214. def get_test_result() -> list[bool]:
  215. """
  216. Get test result and reset test_results.
  217. """
  218. global _TEST_RESULT
  219. res = _TEST_RESULT
  220. _TEST_RESULT = []
  221. return res