ops.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620
  1. """
  2. Operator classes for eval.
  3. """
  4. from __future__ import annotations
  5. from datetime import datetime
  6. from functools import partial
  7. import operator
  8. from typing import (
  9. Callable,
  10. Iterable,
  11. Iterator,
  12. Literal,
  13. )
  14. import numpy as np
  15. from pandas._libs.tslibs import Timestamp
  16. from pandas.core.dtypes.common import (
  17. is_list_like,
  18. is_scalar,
  19. )
  20. import pandas.core.common as com
  21. from pandas.core.computation.common import (
  22. ensure_decoded,
  23. result_type_many,
  24. )
  25. from pandas.core.computation.scope import DEFAULT_GLOBALS
  26. from pandas.io.formats.printing import (
  27. pprint_thing,
  28. pprint_thing_encoded,
  29. )
  30. REDUCTIONS = ("sum", "prod", "min", "max")
  31. _unary_math_ops = (
  32. "sin",
  33. "cos",
  34. "exp",
  35. "log",
  36. "expm1",
  37. "log1p",
  38. "sqrt",
  39. "sinh",
  40. "cosh",
  41. "tanh",
  42. "arcsin",
  43. "arccos",
  44. "arctan",
  45. "arccosh",
  46. "arcsinh",
  47. "arctanh",
  48. "abs",
  49. "log10",
  50. "floor",
  51. "ceil",
  52. )
  53. _binary_math_ops = ("arctan2",)
  54. MATHOPS = _unary_math_ops + _binary_math_ops
  55. LOCAL_TAG = "__pd_eval_local_"
class Term:
    """
    A named operand in an eval expression.

    The name is looked up in ``env`` when the term is constructed, and the
    resolved object becomes the term's ``value``. A non-string name yields a
    ``Constant`` instance instead (see ``__new__``).
    """

    def __new__(cls, name, env, side=None, encoding=None):
        # A non-str "name" is a literal value, so instantiate Constant
        # rather than the requested class.
        klass = Constant if not isinstance(name, str) else cls
        # error: Argument 2 for "super" not an instance of argument 1
        supr_new = super(Term, klass).__new__  # type: ignore[misc]
        return supr_new(klass)

    # Set in __init__: True when the name starts with LOCAL_TAG or is one of
    # DEFAULT_GLOBALS.
    is_local: bool

    def __init__(self, name, env, side=None, encoding=None) -> None:
        # name is a str for Term, but may be something else for subclasses
        self._name = name
        self.env = env
        self.side = side
        tname = str(name)
        self.is_local = tname.startswith(LOCAL_TAG) or tname in DEFAULT_GLOBALS
        # _resolve_name reads self.env and self.is_local, so it must run
        # after both are assigned.
        self._value = self._resolve_name()
        self.encoding = encoding

    @property
    def local_name(self) -> str:
        # The name with every occurrence of LOCAL_TAG removed.
        return self.name.replace(LOCAL_TAG, "")

    def __repr__(self) -> str:
        return pprint_thing(self.name)

    def __call__(self, *args, **kwargs):
        # Calling a term (as BinOp.__call__ does with its operands) just
        # yields the stored value; all arguments are ignored.
        return self.value

    def evaluate(self, *args, **kwargs) -> Term:
        # A term needs no further evaluation; it is returned unchanged.
        return self

    def _resolve_name(self):
        """
        Look up ``local_name`` in ``self.env`` and return the resolved object.

        The resolved object is also passed to ``update`` before being returned.

        Raises
        ------
        NotImplementedError
            If the resolved object has an ``ndim`` greater than 2.
        """
        local_name = str(self.local_name)
        is_local = self.is_local
        # A name that env.scope binds to a class is resolved as non-local.
        if local_name in self.env.scope and isinstance(
            self.env.scope[local_name], type
        ):
            is_local = False
        res = self.env.resolve(local_name, is_local=is_local)
        self.update(res)

        if hasattr(res, "ndim") and res.ndim > 2:
            raise NotImplementedError(
                "N-dimensional objects, where N > 2, are not supported with eval"
            )
        return res

    def update(self, value) -> None:
        """
        Store ``value`` as this term's value.

        When the term's name is a string, ``self.env.swapkey(self.local_name,
        name, new_value=value)`` is called first -- presumably to re-register
        the value in the scope under the term's key (confirm against
        ``Scope.swapkey``).

        search order for local (i.e., @variable) variables:

        scope, key_variable
        [('locals', 'local_name'),
         ('globals', 'local_name'),
         ('locals', 'key'),
         ('globals', 'key')]
        """
        key = self.name

        # if it's a variable name (otherwise a constant)
        if isinstance(key, str):
            self.env.swapkey(self.local_name, key, new_value=value)

        self.value = value

    @property
    def is_scalar(self) -> bool:
        return is_scalar(self._value)

    @property
    def type(self):
        # Tries ``value.values.dtype``, then ``value.dtype``, and finally
        # falls back to the Python type of the value.
        try:
            # potentially very slow for large, mixed dtype frames
            return self._value.values.dtype
        except AttributeError:
            try:
                # ndarray
                return self._value.dtype
            except AttributeError:
                # scalar
                return type(self._value)

    # Alias: the very same property object as ``type`` above.
    return_type = type

    @property
    def raw(self) -> str:
        return f"{type(self).__name__}(name={repr(self.name)}, type={self.type})"

    @property
    def is_datetime(self) -> bool:
        # Unwrap a dtype to its scalar type when possible, then test whether
        # it is a datetime or numpy datetime64 type.
        try:
            t = self.type.type
        except AttributeError:
            t = self.type

        return issubclass(t, (datetime, np.datetime64))

    @property
    def value(self):
        return self._value

    @value.setter
    def value(self, new_value) -> None:
        self._value = new_value

    @property
    def name(self):
        return self._name

    @property
    def ndim(self) -> int:
        return self._value.ndim
  147. class Constant(Term):
  148. def __init__(self, value, env, side=None, encoding=None) -> None:
  149. super().__init__(value, env, side=side, encoding=encoding)
  150. def _resolve_name(self):
  151. return self._name
  152. @property
  153. def name(self):
  154. return self.value
  155. def __repr__(self) -> str:
  156. # in python 2 str() of float
  157. # can truncate shorter than repr()
  158. return repr(self.name)
  159. _bool_op_map = {"not": "~", "and": "&", "or": "|"}
  160. class Op:
  161. """
  162. Hold an operator of arbitrary arity.
  163. """
  164. op: str
  165. def __init__(self, op: str, operands: Iterable[Term | Op], encoding=None) -> None:
  166. self.op = _bool_op_map.get(op, op)
  167. self.operands = operands
  168. self.encoding = encoding
  169. def __iter__(self) -> Iterator:
  170. return iter(self.operands)
  171. def __repr__(self) -> str:
  172. """
  173. Print a generic n-ary operator and its operands using infix notation.
  174. """
  175. # recurse over the operands
  176. parened = (f"({pprint_thing(opr)})" for opr in self.operands)
  177. return pprint_thing(f" {self.op} ".join(parened))
  178. @property
  179. def return_type(self):
  180. # clobber types to bool if the op is a boolean operator
  181. if self.op in (CMP_OPS_SYMS + BOOL_OPS_SYMS):
  182. return np.bool_
  183. return result_type_many(*(term.type for term in com.flatten(self)))
  184. @property
  185. def has_invalid_return_type(self) -> bool:
  186. types = self.operand_types
  187. obj_dtype_set = frozenset([np.dtype("object")])
  188. return self.return_type == object and types - obj_dtype_set
  189. @property
  190. def operand_types(self):
  191. return frozenset(term.type for term in com.flatten(self))
  192. @property
  193. def is_scalar(self) -> bool:
  194. return all(operand.is_scalar for operand in self.operands)
  195. @property
  196. def is_datetime(self) -> bool:
  197. try:
  198. t = self.return_type.type
  199. except AttributeError:
  200. t = self.return_type
  201. return issubclass(t, (datetime, np.datetime64))
  202. def _in(x, y):
  203. """
  204. Compute the vectorized membership of ``x in y`` if possible, otherwise
  205. use Python.
  206. """
  207. try:
  208. return x.isin(y)
  209. except AttributeError:
  210. if is_list_like(x):
  211. try:
  212. return y.isin(x)
  213. except AttributeError:
  214. pass
  215. return x in y
  216. def _not_in(x, y):
  217. """
  218. Compute the vectorized membership of ``x not in y`` if possible,
  219. otherwise use Python.
  220. """
  221. try:
  222. return ~x.isin(y)
  223. except AttributeError:
  224. if is_list_like(x):
  225. try:
  226. return ~y.isin(x)
  227. except AttributeError:
  228. pass
  229. return x not in y
  230. CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=", "in", "not in")
  231. _cmp_ops_funcs = (
  232. operator.gt,
  233. operator.lt,
  234. operator.ge,
  235. operator.le,
  236. operator.eq,
  237. operator.ne,
  238. _in,
  239. _not_in,
  240. )
  241. _cmp_ops_dict = dict(zip(CMP_OPS_SYMS, _cmp_ops_funcs))
  242. BOOL_OPS_SYMS = ("&", "|", "and", "or")
  243. _bool_ops_funcs = (operator.and_, operator.or_, operator.and_, operator.or_)
  244. _bool_ops_dict = dict(zip(BOOL_OPS_SYMS, _bool_ops_funcs))
  245. ARITH_OPS_SYMS = ("+", "-", "*", "/", "**", "//", "%")
  246. _arith_ops_funcs = (
  247. operator.add,
  248. operator.sub,
  249. operator.mul,
  250. operator.truediv,
  251. operator.pow,
  252. operator.floordiv,
  253. operator.mod,
  254. )
  255. _arith_ops_dict = dict(zip(ARITH_OPS_SYMS, _arith_ops_funcs))
  256. SPECIAL_CASE_ARITH_OPS_SYMS = ("**", "//", "%")
  257. _special_case_arith_ops_funcs = (operator.pow, operator.floordiv, operator.mod)
  258. _special_case_arith_ops_dict = dict(
  259. zip(SPECIAL_CASE_ARITH_OPS_SYMS, _special_case_arith_ops_funcs)
  260. )
  261. _binary_ops_dict = {}
  262. for d in (_cmp_ops_dict, _bool_ops_dict, _arith_ops_dict):
  263. _binary_ops_dict.update(d)
  264. def _cast_inplace(terms, acceptable_dtypes, dtype) -> None:
  265. """
  266. Cast an expression inplace.
  267. Parameters
  268. ----------
  269. terms : Op
  270. The expression that should cast.
  271. acceptable_dtypes : list of acceptable numpy.dtype
  272. Will not cast if term's dtype in this list.
  273. dtype : str or numpy.dtype
  274. The dtype to cast to.
  275. """
  276. dt = np.dtype(dtype)
  277. for term in terms:
  278. if term.type in acceptable_dtypes:
  279. continue
  280. try:
  281. new_value = term.value.astype(dt)
  282. except AttributeError:
  283. new_value = dt.type(term.value)
  284. term.update(new_value)
  285. def is_term(obj) -> bool:
  286. return isinstance(obj, Term)
class BinOp(Op):
    """
    Hold a binary operator and its operands.

    Construction rejects unsupported scalar boolean operations, replaces a
    scalar operand that faces a datetime operand with a ``Timestamp``, and
    looks up the Python implementation of ``op``.

    Parameters
    ----------
    op : str
    lhs : Term or Op
    rhs : Term or Op

    Raises
    ------
    NotImplementedError
        If ``op`` is a boolean operator, at least one operand is scalar, and
        the operands are not both boolean.
    ValueError
        If ``op`` has no entry in ``_binary_ops_dict``.
    """

    def __init__(self, op: str, lhs, rhs) -> None:
        super().__init__(op, (lhs, rhs))
        self.lhs = lhs
        self.rhs = rhs

        self._disallow_scalar_only_bool_ops()

        # May replace operand values in place (see convert_values).
        self.convert_values()

        try:
            # Looked up by the raw token; the table also holds "and"/"or".
            self.func = _binary_ops_dict[op]
        except KeyError as err:
            # has to be made a list for python3
            keys = list(_binary_ops_dict.keys())
            raise ValueError(
                f"Invalid binary operator {repr(op)}, valid operators are {keys}"
            ) from err

    def __call__(self, env):
        """
        Recursively evaluate an expression in Python space.

        Parameters
        ----------
        env : Scope

        Returns
        -------
        object
            The result of an evaluated expression.
        """
        # recurse over the left/right nodes
        left = self.lhs(env)
        right = self.rhs(env)

        return self.func(left, right)

    def evaluate(self, env, engine: str, parser, term_type, eval_in_python):
        """
        Evaluate a binary operation *before* being passed to the engine.

        Parameters
        ----------
        env : Scope
        engine : str
        parser : str
        term_type : type
        eval_in_python : list

        Returns
        -------
        term_type
            The "pre-evaluated" expression as an instance of ``term_type``
        """
        if engine == "python":
            # Pure Python evaluation of the whole subtree.
            res = self(env)
        else:
            # recurse over the left/right nodes

            left = self.lhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            right = self.rhs.evaluate(
                env,
                engine=engine,
                parser=parser,
                term_type=term_type,
                eval_in_python=eval_in_python,
            )

            # base cases
            if self.op in eval_in_python:
                # Operators listed by the caller are applied directly in Python.
                res = self.func(left.value, right.value)
            else:
                # Function-level import, presumably to avoid a circular import
                # with pandas.core.computation.eval -- confirm.
                from pandas.core.computation.eval import eval

                res = eval(self, local_dict=env, engine=engine, parser=parser)

        # Register the result in env and wrap its temporary name in a term.
        name = env.add_tmp(res)
        return term_type(name, env=env)

    def convert_values(self) -> None:
        """
        Convert datetimes to a comparable value in an expression.

        When one side is a datetime-typed Term and the other side is a scalar
        Term, the scalar is replaced (via ``update``) by a ``Timestamp``;
        tz-aware values are converted to UTC.
        """

        def stringify(value):
            # Render ``value`` as text, honoring self.encoding when it is set.
            encoder: Callable
            if self.encoding is not None:
                encoder = partial(pprint_thing_encoded, encoding=self.encoding)
            else:
                encoder = pprint_thing
            return encoder(value)

        lhs, rhs = self.lhs, self.rhs

        if is_term(lhs) and lhs.is_datetime and is_term(rhs) and rhs.is_scalar:
            v = rhs.value
            # ints/floats are stringified before being parsed by Timestamp.
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = Timestamp(ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            self.rhs.update(v)

        # Runs after the branch above, so it observes any update made there.
        if is_term(rhs) and rhs.is_datetime and is_term(lhs) and lhs.is_scalar:
            v = lhs.value
            if isinstance(v, (int, float)):
                v = stringify(v)
            v = Timestamp(ensure_decoded(v))
            if v.tz is not None:
                v = v.tz_convert("UTC")
            self.lhs.update(v)

    def _disallow_scalar_only_bool_ops(self):
        # Raise NotImplementedError for a boolean op with a scalar operand
        # unless both operand return types are bool / np.bool_.
        rhs = self.rhs
        lhs = self.lhs

        # GH#24883 unwrap dtype if necessary to ensure we have a type object
        rhs_rt = rhs.return_type
        rhs_rt = getattr(rhs_rt, "type", rhs_rt)
        lhs_rt = lhs.return_type
        lhs_rt = getattr(lhs_rt, "type", lhs_rt)
        if (
            (lhs.is_scalar or rhs.is_scalar)
            and self.op in _bool_ops_dict
            and (
                not (
                    issubclass(rhs_rt, (bool, np.bool_))
                    and issubclass(lhs_rt, (bool, np.bool_))
                )
            )
        ):
            raise NotImplementedError("cannot evaluate scalar only bool ops")
  413. def isnumeric(dtype) -> bool:
  414. return issubclass(np.dtype(dtype).type, np.number)
  415. class Div(BinOp):
  416. """
  417. Div operator to special case casting.
  418. Parameters
  419. ----------
  420. lhs, rhs : Term or Op
  421. The Terms or Ops in the ``/`` expression.
  422. """
  423. def __init__(self, lhs, rhs) -> None:
  424. super().__init__("/", lhs, rhs)
  425. if not isnumeric(lhs.return_type) or not isnumeric(rhs.return_type):
  426. raise TypeError(
  427. f"unsupported operand type(s) for {self.op}: "
  428. f"'{lhs.return_type}' and '{rhs.return_type}'"
  429. )
  430. # do not upcast float32s to float64 un-necessarily
  431. acceptable_dtypes = [np.float32, np.float_]
  432. _cast_inplace(com.flatten(self), acceptable_dtypes, np.float_)
  433. UNARY_OPS_SYMS = ("+", "-", "~", "not")
  434. _unary_ops_funcs = (operator.pos, operator.neg, operator.invert, operator.invert)
  435. _unary_ops_dict = dict(zip(UNARY_OPS_SYMS, _unary_ops_funcs))
  436. class UnaryOp(Op):
  437. """
  438. Hold a unary operator and its operands.
  439. Parameters
  440. ----------
  441. op : str
  442. The token used to represent the operator.
  443. operand : Term or Op
  444. The Term or Op operand to the operator.
  445. Raises
  446. ------
  447. ValueError
  448. * If no function associated with the passed operator token is found.
  449. """
  450. def __init__(self, op: Literal["+", "-", "~", "not"], operand) -> None:
  451. super().__init__(op, (operand,))
  452. self.operand = operand
  453. try:
  454. self.func = _unary_ops_dict[op]
  455. except KeyError as err:
  456. raise ValueError(
  457. f"Invalid unary operator {repr(op)}, "
  458. f"valid operators are {UNARY_OPS_SYMS}"
  459. ) from err
  460. def __call__(self, env) -> MathCall:
  461. operand = self.operand(env)
  462. # error: Cannot call function of unknown type
  463. return self.func(operand) # type: ignore[operator]
  464. def __repr__(self) -> str:
  465. return pprint_thing(f"{self.op}({self.operand})")
  466. @property
  467. def return_type(self) -> np.dtype:
  468. operand = self.operand
  469. if operand.return_type == np.dtype("bool"):
  470. return np.dtype("bool")
  471. if isinstance(operand, Op) and (
  472. operand.op in _cmp_ops_dict or operand.op in _bool_ops_dict
  473. ):
  474. return np.dtype("bool")
  475. return np.dtype("int")
  476. class MathCall(Op):
  477. def __init__(self, func, args) -> None:
  478. super().__init__(func.name, args)
  479. self.func = func
  480. def __call__(self, env):
  481. # error: "Op" not callable
  482. operands = [op(env) for op in self.operands] # type: ignore[operator]
  483. with np.errstate(all="ignore"):
  484. return self.func.func(*operands)
  485. def __repr__(self) -> str:
  486. operands = map(str, self.operands)
  487. return pprint_thing(f"{self.op}({','.join(operands)})")
  488. class FuncNode:
  489. def __init__(self, name: str) -> None:
  490. if name not in MATHOPS:
  491. raise ValueError(f'"{name}" is not a supported function')
  492. self.name = name
  493. self.func = getattr(np, name)
  494. def __call__(self, *args):
  495. return MathCall(self, args)