pytables.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641
  1. """ manage PyTables query interface via Expressions """
  2. from __future__ import annotations
  3. import ast
  4. from functools import partial
  5. from typing import Any
  6. import numpy as np
  7. from pandas._libs.tslibs import (
  8. Timedelta,
  9. Timestamp,
  10. )
  11. from pandas._typing import npt
  12. from pandas.errors import UndefinedVariableError
  13. from pandas.core.dtypes.common import is_list_like
  14. import pandas.core.common as com
  15. from pandas.core.computation import (
  16. expr,
  17. ops,
  18. scope as _scope,
  19. )
  20. from pandas.core.computation.common import ensure_decoded
  21. from pandas.core.computation.expr import BaseExprVisitor
  22. from pandas.core.computation.ops import is_term
  23. from pandas.core.construction import extract_array
  24. from pandas.core.indexes.base import Index
  25. from pandas.io.formats.printing import (
  26. pprint_thing,
  27. pprint_thing_encoded,
  28. )
  29. class PyTablesScope(_scope.Scope):
  30. __slots__ = ("queryables",)
  31. queryables: dict[str, Any]
  32. def __init__(
  33. self,
  34. level: int,
  35. global_dict=None,
  36. local_dict=None,
  37. queryables: dict[str, Any] | None = None,
  38. ) -> None:
  39. super().__init__(level + 1, global_dict=global_dict, local_dict=local_dict)
  40. self.queryables = queryables or {}
  41. class Term(ops.Term):
  42. env: PyTablesScope
  43. def __new__(cls, name, env, side=None, encoding=None):
  44. if isinstance(name, str):
  45. klass = cls
  46. else:
  47. klass = Constant
  48. return object.__new__(klass)
  49. def __init__(self, name, env: PyTablesScope, side=None, encoding=None) -> None:
  50. super().__init__(name, env, side=side, encoding=encoding)
  51. def _resolve_name(self):
  52. # must be a queryables
  53. if self.side == "left":
  54. # Note: The behavior of __new__ ensures that self.name is a str here
  55. if self.name not in self.env.queryables:
  56. raise NameError(f"name {repr(self.name)} is not defined")
  57. return self.name
  58. # resolve the rhs (and allow it to be None)
  59. try:
  60. return self.env.resolve(self.name, is_local=False)
  61. except UndefinedVariableError:
  62. return self.name
  63. # read-only property overwriting read/write property
  64. @property # type: ignore[misc]
  65. def value(self):
  66. return self._value
  67. class Constant(Term):
  68. def __init__(self, value, env: PyTablesScope, side=None, encoding=None) -> None:
  69. assert isinstance(env, PyTablesScope), type(env)
  70. super().__init__(value, env, side=side, encoding=encoding)
  71. def _resolve_name(self):
  72. return self._name
  73. class BinOp(ops.BinOp):
  74. _max_selectors = 31
  75. op: str
  76. queryables: dict[str, Any]
  77. condition: str | None
  78. def __init__(self, op: str, lhs, rhs, queryables: dict[str, Any], encoding) -> None:
  79. super().__init__(op, lhs, rhs)
  80. self.queryables = queryables
  81. self.encoding = encoding
  82. self.condition = None
  83. def _disallow_scalar_only_bool_ops(self) -> None:
  84. pass
  85. def prune(self, klass):
  86. def pr(left, right):
  87. """create and return a new specialized BinOp from myself"""
  88. if left is None:
  89. return right
  90. elif right is None:
  91. return left
  92. k = klass
  93. if isinstance(left, ConditionBinOp):
  94. if isinstance(right, ConditionBinOp):
  95. k = JointConditionBinOp
  96. elif isinstance(left, k):
  97. return left
  98. elif isinstance(right, k):
  99. return right
  100. elif isinstance(left, FilterBinOp):
  101. if isinstance(right, FilterBinOp):
  102. k = JointFilterBinOp
  103. elif isinstance(left, k):
  104. return left
  105. elif isinstance(right, k):
  106. return right
  107. return k(
  108. self.op, left, right, queryables=self.queryables, encoding=self.encoding
  109. ).evaluate()
  110. left, right = self.lhs, self.rhs
  111. if is_term(left) and is_term(right):
  112. res = pr(left.value, right.value)
  113. elif not is_term(left) and is_term(right):
  114. res = pr(left.prune(klass), right.value)
  115. elif is_term(left) and not is_term(right):
  116. res = pr(left.value, right.prune(klass))
  117. elif not (is_term(left) or is_term(right)):
  118. res = pr(left.prune(klass), right.prune(klass))
  119. return res
  120. def conform(self, rhs):
  121. """inplace conform rhs"""
  122. if not is_list_like(rhs):
  123. rhs = [rhs]
  124. if isinstance(rhs, np.ndarray):
  125. rhs = rhs.ravel()
  126. return rhs
  127. @property
  128. def is_valid(self) -> bool:
  129. """return True if this is a valid field"""
  130. return self.lhs in self.queryables
  131. @property
  132. def is_in_table(self) -> bool:
  133. """
  134. return True if this is a valid column name for generation (e.g. an
  135. actual column in the table)
  136. """
  137. return self.queryables.get(self.lhs) is not None
  138. @property
  139. def kind(self):
  140. """the kind of my field"""
  141. return getattr(self.queryables.get(self.lhs), "kind", None)
  142. @property
  143. def meta(self):
  144. """the meta of my field"""
  145. return getattr(self.queryables.get(self.lhs), "meta", None)
  146. @property
  147. def metadata(self):
  148. """the metadata of my field"""
  149. return getattr(self.queryables.get(self.lhs), "metadata", None)
  150. def generate(self, v) -> str:
  151. """create and return the op string for this TermValue"""
  152. val = v.tostring(self.encoding)
  153. return f"({self.lhs} {self.op} {val})"
  154. def convert_value(self, v) -> TermValue:
  155. """
  156. convert the expression that is in the term to something that is
  157. accepted by pytables
  158. """
  159. def stringify(value):
  160. if self.encoding is not None:
  161. return pprint_thing_encoded(value, encoding=self.encoding)
  162. return pprint_thing(value)
  163. kind = ensure_decoded(self.kind)
  164. meta = ensure_decoded(self.meta)
  165. if kind in ("datetime64", "datetime"):
  166. if isinstance(v, (int, float)):
  167. v = stringify(v)
  168. v = ensure_decoded(v)
  169. v = Timestamp(v).as_unit("ns")
  170. if v.tz is not None:
  171. v = v.tz_convert("UTC")
  172. return TermValue(v, v._value, kind)
  173. elif kind in ("timedelta64", "timedelta"):
  174. if isinstance(v, str):
  175. v = Timedelta(v)
  176. else:
  177. v = Timedelta(v, unit="s")
  178. v = v.as_unit("ns")._value
  179. return TermValue(int(v), v, kind)
  180. elif meta == "category":
  181. metadata = extract_array(self.metadata, extract_numpy=True)
  182. result: npt.NDArray[np.intp] | np.intp | int
  183. if v not in metadata:
  184. result = -1
  185. else:
  186. result = metadata.searchsorted(v, side="left")
  187. return TermValue(result, result, "integer")
  188. elif kind == "integer":
  189. v = int(float(v))
  190. return TermValue(v, v, kind)
  191. elif kind == "float":
  192. v = float(v)
  193. return TermValue(v, v, kind)
  194. elif kind == "bool":
  195. if isinstance(v, str):
  196. v = v.strip().lower() not in [
  197. "false",
  198. "f",
  199. "no",
  200. "n",
  201. "none",
  202. "0",
  203. "[]",
  204. "{}",
  205. "",
  206. ]
  207. else:
  208. v = bool(v)
  209. return TermValue(v, v, kind)
  210. elif isinstance(v, str):
  211. # string quoting
  212. return TermValue(v, stringify(v), "string")
  213. else:
  214. raise TypeError(f"Cannot compare {v} of type {type(v)} to {kind} column")
  215. def convert_values(self) -> None:
  216. pass
  217. class FilterBinOp(BinOp):
  218. filter: tuple[Any, Any, Index] | None = None
  219. def __repr__(self) -> str:
  220. if self.filter is None:
  221. return "Filter: Not Initialized"
  222. return pprint_thing(f"[Filter : [{self.filter[0]}] -> [{self.filter[1]}]")
  223. def invert(self):
  224. """invert the filter"""
  225. if self.filter is not None:
  226. self.filter = (
  227. self.filter[0],
  228. self.generate_filter_op(invert=True),
  229. self.filter[2],
  230. )
  231. return self
  232. def format(self):
  233. """return the actual filter format"""
  234. return [self.filter]
  235. def evaluate(self):
  236. if not self.is_valid:
  237. raise ValueError(f"query term is not valid [{self}]")
  238. rhs = self.conform(self.rhs)
  239. values = list(rhs)
  240. if self.is_in_table:
  241. # if too many values to create the expression, use a filter instead
  242. if self.op in ["==", "!="] and len(values) > self._max_selectors:
  243. filter_op = self.generate_filter_op()
  244. self.filter = (self.lhs, filter_op, Index(values))
  245. return self
  246. return None
  247. # equality conditions
  248. if self.op in ["==", "!="]:
  249. filter_op = self.generate_filter_op()
  250. self.filter = (self.lhs, filter_op, Index(values))
  251. else:
  252. raise TypeError(
  253. f"passing a filterable condition to a non-table indexer [{self}]"
  254. )
  255. return self
  256. def generate_filter_op(self, invert: bool = False):
  257. if (self.op == "!=" and not invert) or (self.op == "==" and invert):
  258. return lambda axis, vals: ~axis.isin(vals)
  259. else:
  260. return lambda axis, vals: axis.isin(vals)
  261. class JointFilterBinOp(FilterBinOp):
  262. def format(self):
  263. raise NotImplementedError("unable to collapse Joint Filters")
  264. def evaluate(self):
  265. return self
  266. class ConditionBinOp(BinOp):
  267. def __repr__(self) -> str:
  268. return pprint_thing(f"[Condition : [{self.condition}]]")
  269. def invert(self):
  270. """invert the condition"""
  271. # if self.condition is not None:
  272. # self.condition = "~(%s)" % self.condition
  273. # return self
  274. raise NotImplementedError(
  275. "cannot use an invert condition when passing to numexpr"
  276. )
  277. def format(self):
  278. """return the actual ne format"""
  279. return self.condition
  280. def evaluate(self):
  281. if not self.is_valid:
  282. raise ValueError(f"query term is not valid [{self}]")
  283. # convert values if we are in the table
  284. if not self.is_in_table:
  285. return None
  286. rhs = self.conform(self.rhs)
  287. values = [self.convert_value(v) for v in rhs]
  288. # equality conditions
  289. if self.op in ["==", "!="]:
  290. # too many values to create the expression?
  291. if len(values) <= self._max_selectors:
  292. vs = [self.generate(v) for v in values]
  293. self.condition = f"({' | '.join(vs)})"
  294. # use a filter after reading
  295. else:
  296. return None
  297. else:
  298. self.condition = self.generate(values[0])
  299. return self
  300. class JointConditionBinOp(ConditionBinOp):
  301. def evaluate(self):
  302. self.condition = f"({self.lhs.condition} {self.op} {self.rhs.condition})"
  303. return self
  304. class UnaryOp(ops.UnaryOp):
  305. def prune(self, klass):
  306. if self.op != "~":
  307. raise NotImplementedError("UnaryOp only support invert type ops")
  308. operand = self.operand
  309. operand = operand.prune(klass)
  310. if operand is not None and (
  311. issubclass(klass, ConditionBinOp)
  312. and operand.condition is not None
  313. or not issubclass(klass, ConditionBinOp)
  314. and issubclass(klass, FilterBinOp)
  315. and operand.filter is not None
  316. ):
  317. return operand.invert()
  318. return None
  319. class PyTablesExprVisitor(BaseExprVisitor):
  320. const_type = Constant
  321. term_type = Term
  322. def __init__(self, env, engine, parser, **kwargs) -> None:
  323. super().__init__(env, engine, parser)
  324. for bin_op in self.binary_ops:
  325. bin_node = self.binary_op_nodes_map[bin_op]
  326. setattr(
  327. self,
  328. f"visit_{bin_node}",
  329. lambda node, bin_op=bin_op: partial(BinOp, bin_op, **kwargs),
  330. )
  331. def visit_UnaryOp(self, node, **kwargs):
  332. if isinstance(node.op, (ast.Not, ast.Invert)):
  333. return UnaryOp("~", self.visit(node.operand))
  334. elif isinstance(node.op, ast.USub):
  335. return self.const_type(-self.visit(node.operand).value, self.env)
  336. elif isinstance(node.op, ast.UAdd):
  337. raise NotImplementedError("Unary addition not supported")
  338. def visit_Index(self, node, **kwargs):
  339. return self.visit(node.value).value
  340. def visit_Assign(self, node, **kwargs):
  341. cmpr = ast.Compare(
  342. ops=[ast.Eq()], left=node.targets[0], comparators=[node.value]
  343. )
  344. return self.visit(cmpr)
  345. def visit_Subscript(self, node, **kwargs):
  346. # only allow simple subscripts
  347. value = self.visit(node.value)
  348. slobj = self.visit(node.slice)
  349. try:
  350. value = value.value
  351. except AttributeError:
  352. pass
  353. if isinstance(slobj, Term):
  354. # In py39 np.ndarray lookups with Term containing int raise
  355. slobj = slobj.value
  356. try:
  357. return self.const_type(value[slobj], self.env)
  358. except TypeError as err:
  359. raise ValueError(
  360. f"cannot subscript {repr(value)} with {repr(slobj)}"
  361. ) from err
  362. def visit_Attribute(self, node, **kwargs):
  363. attr = node.attr
  364. value = node.value
  365. ctx = type(node.ctx)
  366. if ctx == ast.Load:
  367. # resolve the value
  368. resolved = self.visit(value)
  369. # try to get the value to see if we are another expression
  370. try:
  371. resolved = resolved.value
  372. except AttributeError:
  373. pass
  374. try:
  375. return self.term_type(getattr(resolved, attr), self.env)
  376. except AttributeError:
  377. # something like datetime.datetime where scope is overridden
  378. if isinstance(value, ast.Name) and value.id == attr:
  379. return resolved
  380. raise ValueError(f"Invalid Attribute context {ctx.__name__}")
  381. def translate_In(self, op):
  382. return ast.Eq() if isinstance(op, ast.In) else op
  383. def _rewrite_membership_op(self, node, left, right):
  384. return self.visit(node.op), node.op, left, right
  385. def _validate_where(w):
  386. """
  387. Validate that the where statement is of the right type.
  388. The type may either be String, Expr, or list-like of Exprs.
  389. Parameters
  390. ----------
  391. w : String term expression, Expr, or list-like of Exprs.
  392. Returns
  393. -------
  394. where : The original where clause if the check was successful.
  395. Raises
  396. ------
  397. TypeError : An invalid data type was passed in for w (e.g. dict).
  398. """
  399. if not (isinstance(w, (PyTablesExpr, str)) or is_list_like(w)):
  400. raise TypeError(
  401. "where must be passed as a string, PyTablesExpr, "
  402. "or list-like of PyTablesExpr"
  403. )
  404. return w
  405. class PyTablesExpr(expr.Expr):
  406. """
  407. Hold a pytables-like expression, comprised of possibly multiple 'terms'.
  408. Parameters
  409. ----------
  410. where : string term expression, PyTablesExpr, or list-like of PyTablesExprs
  411. queryables : a "kinds" map (dict of column name -> kind), or None if column
  412. is non-indexable
  413. encoding : an encoding that will encode the query terms
  414. Returns
  415. -------
  416. a PyTablesExpr object
  417. Examples
  418. --------
  419. 'index>=date'
  420. "columns=['A', 'D']"
  421. 'columns=A'
  422. 'columns==A'
  423. "~(columns=['A','B'])"
  424. 'index>df.index[3] & string="bar"'
  425. '(index>df.index[3] & index<=df.index[6]) | string="bar"'
  426. "ts>=Timestamp('2012-02-01')"
  427. "major_axis>=20130101"
  428. """
  429. _visitor: PyTablesExprVisitor | None
  430. env: PyTablesScope
  431. expr: str
  432. def __init__(
  433. self,
  434. where,
  435. queryables: dict[str, Any] | None = None,
  436. encoding=None,
  437. scope_level: int = 0,
  438. ) -> None:
  439. where = _validate_where(where)
  440. self.encoding = encoding
  441. self.condition = None
  442. self.filter = None
  443. self.terms = None
  444. self._visitor = None
  445. # capture the environment if needed
  446. local_dict: _scope.DeepChainMap[Any, Any] | None = None
  447. if isinstance(where, PyTablesExpr):
  448. local_dict = where.env.scope
  449. _where = where.expr
  450. elif is_list_like(where):
  451. where = list(where)
  452. for idx, w in enumerate(where):
  453. if isinstance(w, PyTablesExpr):
  454. local_dict = w.env.scope
  455. else:
  456. w = _validate_where(w)
  457. where[idx] = w
  458. _where = " & ".join([f"({w})" for w in com.flatten(where)])
  459. else:
  460. # _validate_where ensures we otherwise have a string
  461. _where = where
  462. self.expr = _where
  463. self.env = PyTablesScope(scope_level + 1, local_dict=local_dict)
  464. if queryables is not None and isinstance(self.expr, str):
  465. self.env.queryables.update(queryables)
  466. self._visitor = PyTablesExprVisitor(
  467. self.env,
  468. queryables=queryables,
  469. parser="pytables",
  470. engine="pytables",
  471. encoding=encoding,
  472. )
  473. self.terms = self.parse()
  474. def __repr__(self) -> str:
  475. if self.terms is not None:
  476. return pprint_thing(self.terms)
  477. return pprint_thing(self.expr)
  478. def evaluate(self):
  479. """create and return the numexpr condition and filter"""
  480. try:
  481. self.condition = self.terms.prune(ConditionBinOp)
  482. except AttributeError as err:
  483. raise ValueError(
  484. f"cannot process expression [{self.expr}], [{self}] "
  485. "is not a valid condition"
  486. ) from err
  487. try:
  488. self.filter = self.terms.prune(FilterBinOp)
  489. except AttributeError as err:
  490. raise ValueError(
  491. f"cannot process expression [{self.expr}], [{self}] "
  492. "is not a valid filter"
  493. ) from err
  494. return self.condition, self.filter
  495. class TermValue:
  496. """hold a term value the we use to construct a condition/filter"""
  497. def __init__(self, value, converted, kind: str) -> None:
  498. assert isinstance(kind, str), kind
  499. self.value = value
  500. self.converted = converted
  501. self.kind = kind
  502. def tostring(self, encoding) -> str:
  503. """quote the string if not encoded else encode and return"""
  504. if self.kind == "string":
  505. if encoding is not None:
  506. return str(self.converted)
  507. return f'"{self.converted}"'
  508. elif self.kind == "float":
  509. # python 2 str(float) is not always
  510. # round-trippable so use repr()
  511. return repr(self.converted)
  512. return str(self.converted)
  513. def maybe_expression(s) -> bool:
  514. """loose checking if s is a pytables-acceptable expression"""
  515. if not isinstance(s, str):
  516. return False
  517. operations = PyTablesExprVisitor.binary_ops + PyTablesExprVisitor.unary_ops + ("=",)
  518. # make sure we have an op at least
  519. return any(op in s for op in operations)