test_numba.py 7.7 KB


  1. import pytest
  2. from pandas.errors import NumbaUtilError
  3. import pandas.util._test_decorators as td
  4. from pandas import (
  5. DataFrame,
  6. Series,
  7. option_context,
  8. )
  9. import pandas._testing as tm
  @td.skip_if_no("numba")
  def test_correct_function_signature():
      """Reject a numba UDF whose signature is not ``(values, index, ...)``.

      Both the frame-level and the column-level groupby must raise
      ``NumbaUtilError`` whose message matches "The first 2".
      """

      def incorrect_function(x):
          return x + 1

      frame = DataFrame(
          {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
          columns=["key", "data"],
      )
      by_key = frame.groupby("key")
      # DataFrameGroupBy first, then the SeriesGroupBy for the "data" column
      for target in (by_key, by_key["data"]):
          with pytest.raises(NumbaUtilError, match="The first 2"):
              target.transform(incorrect_function, engine="numba")
  @td.skip_if_no("numba")
  def test_check_nopython_kwargs():
      """Reject extra keyword arguments forwarded to a numba UDF.

      Passing ``a=1`` through ``transform(..., engine="numba")`` must raise
      ``NumbaUtilError`` matching "numba does not support", both for the
      frame-level and the column-level groupby.
      """

      def incorrect_function(values, index):
          return values + 1

      frame = DataFrame(
          {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
          columns=["key", "data"],
      )
      by_key = frame.groupby("key")
      for target in (by_key, by_key["data"]):
          with pytest.raises(NumbaUtilError, match="numba does not support"):
              target.transform(incorrect_function, engine="numba", a=1)
  @td.skip_if_no("numba")
  # Filter warnings when parallel=True and the function can't be parallelized by Numba
  @pytest.mark.filterwarnings("ignore")
  @pytest.mark.parametrize("jit", [True, False])
  @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
  @pytest.mark.parametrize("as_index", [True, False])
  def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
      """The numba transform of ``values + 1`` matches the cython ``x + 1`` result.

      The check runs for plain and pre-jitted UDFs, for DataFrame and Series
      groupbys, and for both ``as_index`` settings. ``nogil``, ``parallel`` and
      ``nopython`` are presumably pytest fixtures from a conftest not shown
      here (TODO confirm).
      """

      def func(values, index):
          return values + 1

      if jit:
          # Already-jitted UDFs must be accepted as well
          import numba

          func = numba.jit(func)

      frame = DataFrame(
          {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
      )
      engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
      by_first_col = frame.groupby(0, as_index=as_index)
      grouped = by_first_col[1] if pandas_obj == "Series" else by_first_col

      result = grouped.transform(func, engine="numba", engine_kwargs=engine_kwargs)
      expected = grouped.transform(lambda x: x + 1, engine="cython")
      tm.assert_equal(result, expected)
  @td.skip_if_no("numba")
  # Filter warnings when parallel=True and the function can't be parallelized by Numba
  @pytest.mark.filterwarnings("ignore")
  @pytest.mark.parametrize("jit", [True, False])
  @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
  def test_cache(jit, pandas_obj, nogil, parallel, nopython):
      """Switching UDFs between numba transforms returns each UDF's own result.

      The sequence func_1, func_2, func_1 is run and each result is compared
      with the equivalent cython transform. The final func_1 call is expected
      to be served from the numba cache. ``nogil``/``parallel``/``nopython``
      are presumably conftest fixtures (TODO confirm).
      """
      # Test that the functions are cached correctly if we switch functions
      def func_1(values, index):
          return values + 1

      def func_2(values, index):
          return values * 5

      if jit:
          import numba

          func_1 = numba.jit(func_1)
          func_2 = numba.jit(func_2)

      frame = DataFrame(
          {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
      )
      engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
      by_first_col = frame.groupby(0)
      grouped = by_first_col[1] if pandas_obj == "Series" else by_first_col

      # (numba UDF, equivalent cython callable), checked strictly in this order
      rounds = [
          (func_1, lambda x: x + 1),
          (func_2, lambda x: x * 5),
          # Retest func_1 which should use the cache
          (func_1, lambda x: x + 1),
      ]
      for numba_udf, cython_udf in rounds:
          result = grouped.transform(
              numba_udf, engine="numba", engine_kwargs=engine_kwargs
          )
          expected = grouped.transform(cython_udf, engine="cython")
          tm.assert_equal(result, expected)
  @td.skip_if_no("numba")
  def test_use_global_config():
      """Check ``compute.use_numba`` against an explicit numba engine.

      With the option set to True, ``engine=None`` must give the same frame
      as an explicit ``engine="numba"`` call.
      """

      def func_1(values, index):
          return values + 1

      data = DataFrame(
          {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
      )
      grouped = data.groupby(0)
      expected = grouped.transform(func_1, engine="numba")
      with option_context("compute.use_numba", True):
          result = grouped.transform(func_1, engine=None)
      # (result, expected) order, like every other assertion in this module, so
      # failure output labels the observed frame as "left"
      tm.assert_frame_equal(result, expected)
  @td.skip_if_no("numba")
  @pytest.mark.parametrize(
      "agg_func", [["min", "max"], "min", {"B": ["min", "max"], "C": "sum"}]
  )
  def test_multifunc_notimplimented(agg_func):
      """List, string and dict specs raise NotImplementedError with the numba engine.

      The check runs on both the frame-level groupby and the column-1 groupby.
      """
      # NOTE(review): the dict spec names columns "B"/"C", which this frame
      # (columns 0 and 1) does not have; this case presumably relies on the
      # NotImplementedError being raised before any column lookup.
      frame = DataFrame(
          {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
      )
      by_first_col = frame.groupby(0)
      for target in (by_first_col, by_first_col[1]):
          with pytest.raises(NotImplementedError, match="Numba engine can"):
              target.transform(agg_func, engine="numba")
  @td.skip_if_no("numba")
  def test_args_not_cached():
      """Extra positional UDF args are honoured on every call, not cached (GH 41647).

      Each group of ``x`` holds two ones, so summing the last ``n`` values
      gives ``n``. The transform broadcasts that value to every row.
      """
      # GH 41647
      def sum_last(values, index, n):
          return values[-n:].sum()

      df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]})
      grouped_x = df.groupby("id")["x"]
      for last_n in (1, 2):
          result = grouped_x.transform(sum_last, last_n, engine="numba")
          expected = Series([float(last_n)] * 4, name="x")
          tm.assert_series_equal(result, expected)
  @td.skip_if_no("numba")
  def test_index_data_correctly_passed():
      """The numba UDF's ``index`` argument carries the frame's index values (GH 43133)."""
      # GH 43133
      def f(values, index):
          return index - 1

      row_labels = [-1, -2, -3]
      df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=row_labels)
      result = df.groupby("group").transform(f, engine="numba")
      # NOTE(review): applying ``index - 1`` row by row to labels [-1, -2, -3]
      # would give [-2.0, -3.0, -4.0]; the values pinned here are the reverse.
      # Confirm this ordering is intended for the pandas version under test
      # rather than an artifact of how results are re-sorted after grouping.
      expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=row_labels)
      tm.assert_frame_equal(result, expected)
  @td.skip_if_no("numba")
  def test_engine_kwargs_not_cached():
      """A different ``engine_kwargs`` set must not reuse the previous jitted function.

      ``func_kwargs`` returns the sum of the enclosing-scope engine flags.
      Flipping ``nogil`` from True to False moves the expected value from 2.0
      to 1.0. A reused compiled function would presumably keep returning the
      old sum.
      """
      # If the user passes a different set of engine_kwargs don't return the same
      # jitted function
      parallel = False
      nopython = True

      def func_kwargs(values, index):
          # ``nogil`` is rebound by the loop below; it is read from this
          # enclosing scope, not passed in
          return nogil + parallel + nopython

      df = DataFrame({"value": [0, 0, 0]})
      for nogil, fill in ((True, 2.0), (False, 1.0)):
          engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
          result = df.groupby(level=0).transform(
              func_kwargs, engine="numba", engine_kwargs=engine_kwargs
          )
          expected = DataFrame({"value": [fill] * 3})
          tm.assert_frame_equal(result, expected)
  @td.skip_if_no("numba")
  @pytest.mark.filterwarnings("ignore")
  def test_multiindex_one_key(nogil, parallel, nopython):
      """Grouping a MultiIndexed frame by one level name works with numba.

      The constant UDF result 1 is broadcast into column ``C`` as a float.
      ``nogil``/``parallel``/``nopython`` are presumably conftest fixtures.
      """

      def numba_func(values, index):
          return 1

      levels = ["A", "B"]
      engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
      df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(levels)
      result = df.groupby("A").transform(
          numba_func, engine="numba", engine_kwargs=engine_kwargs
      )
      expected = DataFrame([{"A": 1, "B": 2, "C": 1.0}]).set_index(levels)
      tm.assert_frame_equal(result, expected)
  @td.skip_if_no("numba")
  def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
      """Grouping by two MultiIndex levels raises NotImplementedError with numba.

      ``nogil``/``parallel``/``nopython`` are presumably conftest fixtures.
      """

      def numba_func(values, index):
          return 1

      levels = ["A", "B"]
      engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
      df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(levels)
      by_both_levels = df.groupby(levels)
      with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
          by_both_levels.transform(
              numba_func, engine="numba", engine_kwargs=engine_kwargs
          )