123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283 |
- """
- Expressions
- -----------
- Offer fast expression evaluation through numexpr
- """
- from __future__ import annotations
- import operator
- import warnings
- import numpy as np
- from pandas._config import get_option
- from pandas._typing import FuncType
- from pandas.util._exceptions import find_stack_level
- from pandas.core.computation.check import NUMEXPR_INSTALLED
- from pandas.core.ops import roperator
- if NUMEXPR_INSTALLED:
- import numexpr as ne
# Whether numexpr should be attempted. Starts out mirroring availability and
# is changed through ``set_use_numexpr``.
USE_NUMEXPR = NUMEXPR_INSTALLED

# Implementations selected by ``set_use_numexpr``; None until it has run.
_evaluate: FuncType | None = None
_where: FuncType | None = None

# Test-mode bookkeeping, managed by ``set_test_mode`` / ``get_test_result``.
_TEST_MODE: bool | None = None
_TEST_RESULT: list[bool] = []
- # the set of dtypes that we will allow pass to numexpr
- _ALLOWED_DTYPES = {
- "evaluate": {"int64", "int32", "float64", "float32", "bool"},
- "where": {"int64", "float64", "bool"},
- }
- # the minimum prod shape that we will use numexpr
- _MIN_ELEMENTS = 1_000_000
def set_use_numexpr(v: bool = True) -> None:
    """
    Switch numexpr usage on or off and rebind the evaluation back ends.

    Parameters
    ----------
    v : bool, default True
        Requested setting. It is ignored when numexpr is not installed.
    """
    global USE_NUMEXPR, _evaluate, _where

    # the flag can only be changed when numexpr is actually available
    if NUMEXPR_INSTALLED:
        USE_NUMEXPR = v

    # bind the implementations that match the (possibly unchanged) flag
    if USE_NUMEXPR:
        _evaluate, _where = _evaluate_numexpr, _where_numexpr
    else:
        _evaluate, _where = _evaluate_standard, _where_standard
def set_numexpr_threads(n=None) -> None:
    """
    Set the number of threads used by numexpr.

    Does nothing unless numexpr is installed and currently enabled.

    Parameters
    ----------
    n : optional
        Thread count handed to ``ne.set_num_threads``. When None, the value
        returned by ``ne.detect_number_of_cores()`` is used instead.
    """
    if not (NUMEXPR_INSTALLED and USE_NUMEXPR):
        return

    threads = ne.detect_number_of_cores() if n is None else n
    ne.set_num_threads(threads)
def _evaluate_standard(op, op_str, a, b):
    """
    Evaluate ``op(a, b)`` directly, without numexpr.

    ``op_str`` is not used here; the signature mirrors ``_evaluate_numexpr``.
    """
    if _TEST_MODE:
        # report that numexpr was not used for this call
        _store_test_result(False)

    outcome = op(a, b)
    return outcome
- def _can_use_numexpr(op, op_str, a, b, dtype_check) -> bool:
- """return a boolean if we WILL be using numexpr"""
- if op_str is not None:
- # required min elements (otherwise we are adding overhead)
- if a.size > _MIN_ELEMENTS:
- # check for dtype compatibility
- dtypes: set[str] = set()
- for o in [a, b]:
- # ndarray and Series Case
- if hasattr(o, "dtype"):
- dtypes |= {o.dtype.name}
- # allowed are a superset
- if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
- return True
- return False
def _evaluate_numexpr(op, op_str, a, b):
    """
    Evaluate ``op`` on ``a`` and ``b`` through numexpr when possible.

    Falls back to ``_evaluate_standard`` when numexpr is not applicable,
    when it raises TypeError, or when it raises NotImplementedError for a
    boolean operation that ``_bool_arith_fallback`` accepts. Any other
    NotImplementedError is re-raised.
    """
    result = None

    if _can_use_numexpr(op, op_str, a, b, "evaluate"):
        # reversed op methods pass their operands swapped; put them back in
        # expression order for numexpr while leaving ``a``/``b`` untouched
        # so the fallback below still sees the original order
        is_reversed = op.__name__.strip("_").startswith("r")
        lhs, rhs = (b, a) if is_reversed else (a, b)

        try:
            result = ne.evaluate(
                f"a_value {op_str} b_value",
                local_dict={"a_value": lhs, "b_value": rhs},
                casting="safe",
            )
        except TypeError:
            # numexpr raises eg for array ** array with integers
            # (https://github.com/pydata/numexpr/issues/379)
            pass
        except NotImplementedError:
            if not _bool_arith_fallback(op_str, lhs, rhs):
                raise

    if _TEST_MODE:
        _store_test_result(result is not None)

    if result is None:
        # numexpr was skipped or gave up: evaluate in plain Python
        result = _evaluate_standard(op, op_str, a, b)

    return result
- _op_str_mapping = {
- operator.add: "+",
- roperator.radd: "+",
- operator.mul: "*",
- roperator.rmul: "*",
- operator.sub: "-",
- roperator.rsub: "-",
- operator.truediv: "/",
- roperator.rtruediv: "/",
- # floordiv not supported by numexpr 2.x
- operator.floordiv: None,
- roperator.rfloordiv: None,
- # we require Python semantics for mod of negative for backwards compatibility
- # see https://github.com/pydata/numexpr/issues/365
- # so sticking with unaccelerated for now GH#36552
- operator.mod: None,
- roperator.rmod: None,
- operator.pow: "**",
- roperator.rpow: "**",
- operator.eq: "==",
- operator.ne: "!=",
- operator.le: "<=",
- operator.lt: "<",
- operator.ge: ">=",
- operator.gt: ">",
- operator.and_: "&",
- roperator.rand_: "&",
- operator.or_: "|",
- roperator.ror_: "|",
- operator.xor: "^",
- roperator.rxor: "^",
- divmod: None,
- roperator.rdivmod: None,
- }
- def _where_standard(cond, a, b):
- # Caller is responsible for extracting ndarray if necessary
- return np.where(cond, a, b)
def _where_numexpr(cond, a, b):
    """
    numexpr-backed ``where``; uses ``_where_standard`` when numexpr is not
    applicable to the operands.

    The caller is responsible for extracting ndarrays if necessary.
    """
    if not _can_use_numexpr(None, "where", a, b, "where"):
        return _where_standard(cond, a, b)

    operands = {"cond_value": cond, "a_value": a, "b_value": b}
    result = ne.evaluate(
        "where(cond_value, a_value, b_value)",
        local_dict=operands,
        casting="safe",
    )
    if result is None:
        return _where_standard(cond, a, b)
    return result
- # at import time, select the numexpr or standard implementations according to the "compute.use_numexpr" option
- set_use_numexpr(get_option("compute.use_numexpr"))
- def _has_bool_dtype(x):
- try:
- return x.dtype == bool
- except AttributeError:
- return isinstance(x, (bool, np.bool_))
# arithmetic operators numexpr rejects for bool operands, mapped to the
# operator suggested in the warning
_BOOL_OP_UNSUPPORTED = {"+": "|", "*": "&", "-": "^"}


def _bool_arith_fallback(op_str, a, b) -> bool:
    """
    Decide whether to fall back to the python `_evaluate_standard` after
    numexpr refused an operation.

    Returns True, after emitting a warning, when both operands are boolean
    and ``op_str`` is one of the operators in ``_BOOL_OP_UNSUPPORTED``;
    returns False otherwise.
    """
    if not (_has_bool_dtype(a) and _has_bool_dtype(b)):
        return False
    if op_str not in _BOOL_OP_UNSUPPORTED:
        return False

    message = (
        f"evaluating in Python space because the {repr(op_str)} "
        "operator is not supported by numexpr for the bool dtype, "
        f"use {repr(_BOOL_OP_UNSUPPORTED[op_str])} instead."
    )
    warnings.warn(message, stacklevel=find_stack_level())
    return True
def evaluate(op, a, b, use_numexpr: bool = True):
    """
    Evaluate and return the expression of the op on a and b.

    Parameters
    ----------
    op : the actual operator
        Must be a key of ``_op_str_mapping``.
    a : left operand
    b : right operand
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    op_str = _op_str_mapping[op]

    # operators without a numexpr spelling always take the standard path
    if op_str is not None and use_numexpr:
        # error: "None" not callable
        return _evaluate(op, op_str, a, b)  # type: ignore[misc]

    return _evaluate_standard(op, op_str, a, b)
def where(cond, a, b, use_numexpr: bool = True):
    """
    Evaluate the where condition cond on a and b.

    Parameters
    ----------
    cond : np.ndarray[bool]
    a : return if cond is True
    b : return if cond is False
    use_numexpr : bool, default True
        Whether to try to use numexpr.
    """
    assert _where is not None

    if use_numexpr:
        return _where(cond, a, b)
    return _where_standard(cond, a, b)
def set_test_mode(v: bool = True) -> None:
    """
    Enable or disable tracking of whether numexpr was used.

    While enabled, an additional ``True`` is stored for every successful
    use of evaluate with numexpr since the last ``get_test_result``.
    Calling this also clears any results collected so far.
    """
    global _TEST_MODE, _TEST_RESULT
    _TEST_RESULT = []
    _TEST_MODE = v
- def _store_test_result(used_numexpr: bool) -> None:
- if used_numexpr:
- _TEST_RESULT.append(used_numexpr)
def get_test_result() -> list[bool]:
    """
    Return the collected test results and start over with an empty list.
    """
    global _TEST_RESULT
    collected, _TEST_RESULT = _TEST_RESULT, []
    return collected
|