test_expressions.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. import operator
  2. import re
  3. import warnings
  4. import numpy as np
  5. import pytest
  6. from pandas import option_context
  7. import pandas._testing as tm
  8. from pandas.core.api import (
  9. DataFrame,
  10. Index,
  11. Series,
  12. )
  13. from pandas.core.computation import expressions as expr
  14. @pytest.fixture
  15. def _frame():
  16. return DataFrame(np.random.randn(10001, 4), columns=list("ABCD"), dtype="float64")
  17. @pytest.fixture
  18. def _frame2():
  19. return DataFrame(np.random.randn(100, 4), columns=list("ABCD"), dtype="float64")
  20. @pytest.fixture
  21. def _mixed(_frame):
  22. return DataFrame(
  23. {
  24. "A": _frame["A"].copy(),
  25. "B": _frame["B"].astype("float32"),
  26. "C": _frame["C"].astype("int64"),
  27. "D": _frame["D"].astype("int32"),
  28. }
  29. )
  30. @pytest.fixture
  31. def _mixed2(_frame2):
  32. return DataFrame(
  33. {
  34. "A": _frame2["A"].copy(),
  35. "B": _frame2["B"].astype("float32"),
  36. "C": _frame2["C"].astype("int64"),
  37. "D": _frame2["D"].astype("int32"),
  38. }
  39. )
  40. @pytest.fixture
  41. def _integer():
  42. return DataFrame(
  43. np.random.randint(1, 100, size=(10001, 4)), columns=list("ABCD"), dtype="int64"
  44. )
  45. @pytest.fixture
  46. def _integer_randint(_integer):
  47. # randint to get a case with zeros
  48. return _integer * np.random.randint(0, 2, size=np.shape(_integer))
  49. @pytest.fixture
  50. def _integer2():
  51. return DataFrame(
  52. np.random.randint(1, 100, size=(101, 4)), columns=list("ABCD"), dtype="int64"
  53. )
  54. @pytest.fixture
  55. def _array(_frame):
  56. return _frame["A"].values.copy()
  57. @pytest.fixture
  58. def _array2(_frame2):
  59. return _frame2["A"].values.copy()
  60. @pytest.fixture
  61. def _array_mixed(_mixed):
  62. return _mixed["D"].values.copy()
  63. @pytest.fixture
  64. def _array_mixed2(_mixed2):
  65. return _mixed2["D"].values.copy()
  66. @pytest.mark.skipif(not expr.USE_NUMEXPR, reason="not using numexpr")
  67. class TestExpressions:
  68. @pytest.fixture(autouse=True)
  69. def save_min_elements(self):
  70. min_elements = expr._MIN_ELEMENTS
  71. yield
  72. expr._MIN_ELEMENTS = min_elements
  73. @staticmethod
  74. def call_op(df, other, flex: bool, opname: str):
  75. if flex:
  76. op = lambda x, y: getattr(x, opname)(y)
  77. op.__name__ = opname
  78. else:
  79. op = getattr(operator, opname)
  80. with option_context("compute.use_numexpr", False):
  81. expected = op(df, other)
  82. expr.get_test_result()
  83. result = op(df, other)
  84. return result, expected
  85. @pytest.mark.parametrize(
  86. "fixture",
  87. [
  88. "_integer",
  89. "_integer2",
  90. "_integer_randint",
  91. "_frame",
  92. "_frame2",
  93. "_mixed",
  94. "_mixed2",
  95. ],
  96. )
  97. @pytest.mark.parametrize("flex", [True, False])
  98. @pytest.mark.parametrize(
  99. "arith", ["add", "sub", "mul", "mod", "truediv", "floordiv"]
  100. )
  101. def test_run_arithmetic(self, request, fixture, flex, arith):
  102. df = request.getfixturevalue(fixture)
  103. expr._MIN_ELEMENTS = 0
  104. result, expected = self.call_op(df, df, flex, arith)
  105. if arith == "truediv":
  106. assert all(x.kind == "f" for x in expected.dtypes.values)
  107. tm.assert_equal(expected, result)
  108. for i in range(len(df.columns)):
  109. result, expected = self.call_op(df.iloc[:, i], df.iloc[:, i], flex, arith)
  110. if arith == "truediv":
  111. assert expected.dtype.kind == "f"
  112. tm.assert_equal(expected, result)
  113. @pytest.mark.parametrize(
  114. "fixture",
  115. [
  116. "_integer",
  117. "_integer2",
  118. "_integer_randint",
  119. "_frame",
  120. "_frame2",
  121. "_mixed",
  122. "_mixed2",
  123. ],
  124. )
  125. @pytest.mark.parametrize("flex", [True, False])
  126. def test_run_binary(self, request, fixture, flex, comparison_op):
  127. """
  128. tests solely that the result is the same whether or not numexpr is
  129. enabled. Need to test whether the function does the correct thing
  130. elsewhere.
  131. """
  132. df = request.getfixturevalue(fixture)
  133. arith = comparison_op.__name__
  134. with option_context("compute.use_numexpr", False):
  135. other = df.copy() + 1
  136. expr._MIN_ELEMENTS = 0
  137. expr.set_test_mode(True)
  138. result, expected = self.call_op(df, other, flex, arith)
  139. used_numexpr = expr.get_test_result()
  140. assert used_numexpr, "Did not use numexpr as expected."
  141. tm.assert_equal(expected, result)
  142. # FIXME: dont leave commented-out
  143. # series doesn't uses vec_compare instead of numexpr...
  144. # for i in range(len(df.columns)):
  145. # binary_comp = other.iloc[:, i] + 1
  146. # self.run_binary(df.iloc[:, i], binary_comp, flex)
  147. def test_invalid(self):
  148. array = np.random.randn(1_000_001)
  149. array2 = np.random.randn(100)
  150. # no op
  151. result = expr._can_use_numexpr(operator.add, None, array, array, "evaluate")
  152. assert not result
  153. # min elements
  154. result = expr._can_use_numexpr(operator.add, "+", array2, array2, "evaluate")
  155. assert not result
  156. # ok, we only check on first part of expression
  157. result = expr._can_use_numexpr(operator.add, "+", array, array2, "evaluate")
  158. assert result
  159. @pytest.mark.filterwarnings(
  160. "ignore:invalid value encountered in true_divide:RuntimeWarning"
  161. )
  162. @pytest.mark.parametrize(
  163. "opname,op_str",
  164. [("add", "+"), ("sub", "-"), ("mul", "*"), ("truediv", "/"), ("pow", "**")],
  165. )
  166. @pytest.mark.parametrize(
  167. "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")]
  168. )
  169. def test_binary_ops(self, request, opname, op_str, left_fix, right_fix):
  170. left = request.getfixturevalue(left_fix)
  171. right = request.getfixturevalue(right_fix)
  172. def testit():
  173. if opname == "pow":
  174. # TODO: get this working
  175. return
  176. op = getattr(operator, opname)
  177. with warnings.catch_warnings():
  178. # array has 0s
  179. msg = "invalid value encountered in divide|true_divide"
  180. warnings.filterwarnings("ignore", msg, RuntimeWarning)
  181. result = expr.evaluate(op, left, left, use_numexpr=True)
  182. expected = expr.evaluate(op, left, left, use_numexpr=False)
  183. tm.assert_numpy_array_equal(result, expected)
  184. result = expr._can_use_numexpr(op, op_str, right, right, "evaluate")
  185. assert not result
  186. with option_context("compute.use_numexpr", False):
  187. testit()
  188. expr.set_numexpr_threads(1)
  189. testit()
  190. expr.set_numexpr_threads()
  191. testit()
  192. @pytest.mark.parametrize(
  193. "left_fix,right_fix", [("_array", "_array2"), ("_array_mixed", "_array_mixed2")]
  194. )
  195. def test_comparison_ops(self, request, comparison_op, left_fix, right_fix):
  196. left = request.getfixturevalue(left_fix)
  197. right = request.getfixturevalue(right_fix)
  198. def testit():
  199. f12 = left + 1
  200. f22 = right + 1
  201. op = comparison_op
  202. result = expr.evaluate(op, left, f12, use_numexpr=True)
  203. expected = expr.evaluate(op, left, f12, use_numexpr=False)
  204. tm.assert_numpy_array_equal(result, expected)
  205. result = expr._can_use_numexpr(op, op, right, f22, "evaluate")
  206. assert not result
  207. with option_context("compute.use_numexpr", False):
  208. testit()
  209. expr.set_numexpr_threads(1)
  210. testit()
  211. expr.set_numexpr_threads()
  212. testit()
  213. @pytest.mark.parametrize("cond", [True, False])
  214. @pytest.mark.parametrize("fixture", ["_frame", "_frame2", "_mixed", "_mixed2"])
  215. def test_where(self, request, cond, fixture):
  216. df = request.getfixturevalue(fixture)
  217. def testit():
  218. c = np.empty(df.shape, dtype=np.bool_)
  219. c.fill(cond)
  220. result = expr.where(c, df.values, df.values + 1)
  221. expected = np.where(c, df.values, df.values + 1)
  222. tm.assert_numpy_array_equal(result, expected)
  223. with option_context("compute.use_numexpr", False):
  224. testit()
  225. expr.set_numexpr_threads(1)
  226. testit()
  227. expr.set_numexpr_threads()
  228. testit()
  229. @pytest.mark.parametrize(
  230. "op_str,opname", [("/", "truediv"), ("//", "floordiv"), ("**", "pow")]
  231. )
  232. def test_bool_ops_raise_on_arithmetic(self, op_str, opname):
  233. df = DataFrame({"a": np.random.rand(10) > 0.5, "b": np.random.rand(10) > 0.5})
  234. msg = f"operator '{opname}' not implemented for bool dtypes"
  235. f = getattr(operator, opname)
  236. err_msg = re.escape(msg)
  237. with pytest.raises(NotImplementedError, match=err_msg):
  238. f(df, df)
  239. with pytest.raises(NotImplementedError, match=err_msg):
  240. f(df.a, df.b)
  241. with pytest.raises(NotImplementedError, match=err_msg):
  242. f(df.a, True)
  243. with pytest.raises(NotImplementedError, match=err_msg):
  244. f(False, df.a)
  245. with pytest.raises(NotImplementedError, match=err_msg):
  246. f(False, df)
  247. with pytest.raises(NotImplementedError, match=err_msg):
  248. f(df, True)
  249. @pytest.mark.parametrize(
  250. "op_str,opname", [("+", "add"), ("*", "mul"), ("-", "sub")]
  251. )
  252. def test_bool_ops_warn_on_arithmetic(self, op_str, opname):
  253. n = 10
  254. df = DataFrame({"a": np.random.rand(n) > 0.5, "b": np.random.rand(n) > 0.5})
  255. subs = {"+": "|", "*": "&", "-": "^"}
  256. sub_funcs = {"|": "or_", "&": "and_", "^": "xor"}
  257. f = getattr(operator, opname)
  258. fe = getattr(operator, sub_funcs[subs[op_str]])
  259. if op_str == "-":
  260. # raises TypeError
  261. return
  262. with tm.use_numexpr(True, min_elements=5):
  263. with tm.assert_produces_warning():
  264. r = f(df, df)
  265. e = fe(df, df)
  266. tm.assert_frame_equal(r, e)
  267. with tm.assert_produces_warning():
  268. r = f(df.a, df.b)
  269. e = fe(df.a, df.b)
  270. tm.assert_series_equal(r, e)
  271. with tm.assert_produces_warning():
  272. r = f(df.a, True)
  273. e = fe(df.a, True)
  274. tm.assert_series_equal(r, e)
  275. with tm.assert_produces_warning():
  276. r = f(False, df.a)
  277. e = fe(False, df.a)
  278. tm.assert_series_equal(r, e)
  279. with tm.assert_produces_warning():
  280. r = f(False, df)
  281. e = fe(False, df)
  282. tm.assert_frame_equal(r, e)
  283. with tm.assert_produces_warning():
  284. r = f(df, True)
  285. e = fe(df, True)
  286. tm.assert_frame_equal(r, e)
  287. @pytest.mark.parametrize(
  288. "test_input,expected",
  289. [
  290. (
  291. DataFrame(
  292. [[0, 1, 2, "aa"], [0, 1, 2, "aa"]], columns=["a", "b", "c", "dtype"]
  293. ),
  294. DataFrame([[False, False], [False, False]], columns=["a", "dtype"]),
  295. ),
  296. (
  297. DataFrame(
  298. [[0, 3, 2, "aa"], [0, 4, 2, "aa"], [0, 1, 1, "bb"]],
  299. columns=["a", "b", "c", "dtype"],
  300. ),
  301. DataFrame(
  302. [[False, False], [False, False], [False, False]],
  303. columns=["a", "dtype"],
  304. ),
  305. ),
  306. ],
  307. )
  308. def test_bool_ops_column_name_dtype(self, test_input, expected):
  309. # GH 22383 - .ne fails if columns containing column name 'dtype'
  310. result = test_input.loc[:, ["a", "dtype"]].ne(test_input.loc[:, ["a", "dtype"]])
  311. tm.assert_frame_equal(result, expected)
  312. @pytest.mark.parametrize(
  313. "arith", ("add", "sub", "mul", "mod", "truediv", "floordiv")
  314. )
  315. @pytest.mark.parametrize("axis", (0, 1))
  316. def test_frame_series_axis(self, axis, arith, _frame):
  317. # GH#26736 Dataframe.floordiv(Series, axis=1) fails
  318. df = _frame
  319. if axis == 1:
  320. other = df.iloc[0, :]
  321. else:
  322. other = df.iloc[:, 0]
  323. expr._MIN_ELEMENTS = 0
  324. op_func = getattr(df, arith)
  325. with option_context("compute.use_numexpr", False):
  326. expected = op_func(other, axis=axis)
  327. result = op_func(other, axis=axis)
  328. tm.assert_frame_equal(expected, result)
  329. @pytest.mark.parametrize(
  330. "op",
  331. [
  332. "__mod__",
  333. "__rmod__",
  334. "__floordiv__",
  335. "__rfloordiv__",
  336. ],
  337. )
  338. @pytest.mark.parametrize("box", [DataFrame, Series, Index])
  339. @pytest.mark.parametrize("scalar", [-5, 5])
  340. def test_python_semantics_with_numexpr_installed(self, op, box, scalar):
  341. # https://github.com/pandas-dev/pandas/issues/36047
  342. expr._MIN_ELEMENTS = 0
  343. data = np.arange(-50, 50)
  344. obj = box(data)
  345. method = getattr(obj, op)
  346. result = method(scalar)
  347. # compare result with numpy
  348. with option_context("compute.use_numexpr", False):
  349. expected = method(scalar)
  350. tm.assert_equal(result, expected)
  351. # compare result element-wise with Python
  352. for i, elem in enumerate(data):
  353. if box == DataFrame:
  354. scalar_result = result.iloc[i, 0]
  355. else:
  356. scalar_result = result[i]
  357. try:
  358. expected = getattr(int(elem), op)(scalar)
  359. except ZeroDivisionError:
  360. pass
  361. else:
  362. assert scalar_result == expected