# test_numba.py
  1. import numpy as np
  2. import pytest
  3. from pandas.errors import NumbaUtilError
  4. import pandas.util._test_decorators as td
  5. from pandas import (
  6. DataFrame,
  7. Index,
  8. NamedAgg,
  9. Series,
  10. option_context,
  11. )
  12. import pandas._testing as tm
  13. @td.skip_if_no("numba")
  14. def test_correct_function_signature():
  15. def incorrect_function(x):
  16. return sum(x) * 2.7
  17. data = DataFrame(
  18. {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
  19. columns=["key", "data"],
  20. )
  21. with pytest.raises(NumbaUtilError, match="The first 2"):
  22. data.groupby("key").agg(incorrect_function, engine="numba")
  23. with pytest.raises(NumbaUtilError, match="The first 2"):
  24. data.groupby("key")["data"].agg(incorrect_function, engine="numba")
  25. @td.skip_if_no("numba")
  26. def test_check_nopython_kwargs():
  27. def incorrect_function(values, index):
  28. return sum(values) * 2.7
  29. data = DataFrame(
  30. {"key": ["a", "a", "b", "b", "a"], "data": [1.0, 2.0, 3.0, 4.0, 5.0]},
  31. columns=["key", "data"],
  32. )
  33. with pytest.raises(NumbaUtilError, match="numba does not support"):
  34. data.groupby("key").agg(incorrect_function, engine="numba", a=1)
  35. with pytest.raises(NumbaUtilError, match="numba does not support"):
  36. data.groupby("key")["data"].agg(incorrect_function, engine="numba", a=1)
  37. @td.skip_if_no("numba")
  38. @pytest.mark.filterwarnings("ignore")
  39. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  40. @pytest.mark.parametrize("jit", [True, False])
  41. @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
  42. @pytest.mark.parametrize("as_index", [True, False])
  43. def test_numba_vs_cython(jit, pandas_obj, nogil, parallel, nopython, as_index):
  44. def func_numba(values, index):
  45. return np.mean(values) * 2.7
  46. if jit:
  47. # Test accepted jitted functions
  48. import numba
  49. func_numba = numba.jit(func_numba)
  50. data = DataFrame(
  51. {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
  52. )
  53. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  54. grouped = data.groupby(0, as_index=as_index)
  55. if pandas_obj == "Series":
  56. grouped = grouped[1]
  57. result = grouped.agg(func_numba, engine="numba", engine_kwargs=engine_kwargs)
  58. expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")
  59. tm.assert_equal(result, expected)
  60. @td.skip_if_no("numba")
  61. @pytest.mark.filterwarnings("ignore")
  62. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  63. @pytest.mark.parametrize("jit", [True, False])
  64. @pytest.mark.parametrize("pandas_obj", ["Series", "DataFrame"])
  65. def test_cache(jit, pandas_obj, nogil, parallel, nopython):
  66. # Test that the functions are cached correctly if we switch functions
  67. def func_1(values, index):
  68. return np.mean(values) - 3.4
  69. def func_2(values, index):
  70. return np.mean(values) * 2.7
  71. if jit:
  72. import numba
  73. func_1 = numba.jit(func_1)
  74. func_2 = numba.jit(func_2)
  75. data = DataFrame(
  76. {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
  77. )
  78. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  79. grouped = data.groupby(0)
  80. if pandas_obj == "Series":
  81. grouped = grouped[1]
  82. result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
  83. expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
  84. tm.assert_equal(result, expected)
  85. # Add func_2 to the cache
  86. result = grouped.agg(func_2, engine="numba", engine_kwargs=engine_kwargs)
  87. expected = grouped.agg(lambda x: np.mean(x) * 2.7, engine="cython")
  88. tm.assert_equal(result, expected)
  89. # Retest func_1 which should use the cache
  90. result = grouped.agg(func_1, engine="numba", engine_kwargs=engine_kwargs)
  91. expected = grouped.agg(lambda x: np.mean(x) - 3.4, engine="cython")
  92. tm.assert_equal(result, expected)
  93. @td.skip_if_no("numba")
  94. def test_use_global_config():
  95. def func_1(values, index):
  96. return np.mean(values) - 3.4
  97. data = DataFrame(
  98. {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
  99. )
  100. grouped = data.groupby(0)
  101. expected = grouped.agg(func_1, engine="numba")
  102. with option_context("compute.use_numba", True):
  103. result = grouped.agg(func_1, engine=None)
  104. tm.assert_frame_equal(expected, result)
  105. @td.skip_if_no("numba")
  106. @pytest.mark.parametrize(
  107. "agg_func",
  108. [
  109. ["min", "max"],
  110. "min",
  111. {"B": ["min", "max"], "C": "sum"},
  112. NamedAgg(column="B", aggfunc="min"),
  113. ],
  114. )
  115. def test_multifunc_notimplimented(agg_func):
  116. data = DataFrame(
  117. {0: ["a", "a", "b", "b", "a"], 1: [1.0, 2.0, 3.0, 4.0, 5.0]}, columns=[0, 1]
  118. )
  119. grouped = data.groupby(0)
  120. with pytest.raises(NotImplementedError, match="Numba engine can"):
  121. grouped.agg(agg_func, engine="numba")
  122. with pytest.raises(NotImplementedError, match="Numba engine can"):
  123. grouped[1].agg(agg_func, engine="numba")
  124. @td.skip_if_no("numba")
  125. def test_args_not_cached():
  126. # GH 41647
  127. def sum_last(values, index, n):
  128. return values[-n:].sum()
  129. df = DataFrame({"id": [0, 0, 1, 1], "x": [1, 1, 1, 1]})
  130. grouped_x = df.groupby("id")["x"]
  131. result = grouped_x.agg(sum_last, 1, engine="numba")
  132. expected = Series([1.0] * 2, name="x", index=Index([0, 1], name="id"))
  133. tm.assert_series_equal(result, expected)
  134. result = grouped_x.agg(sum_last, 2, engine="numba")
  135. expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id"))
  136. tm.assert_series_equal(result, expected)
  137. @td.skip_if_no("numba")
  138. def test_index_data_correctly_passed():
  139. # GH 43133
  140. def f(values, index):
  141. return np.mean(index)
  142. df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
  143. result = df.groupby("group").aggregate(f, engine="numba")
  144. expected = DataFrame(
  145. [-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group")
  146. )
  147. tm.assert_frame_equal(result, expected)
  148. @td.skip_if_no("numba")
  149. def test_engine_kwargs_not_cached():
  150. # If the user passes a different set of engine_kwargs don't return the same
  151. # jitted function
  152. nogil = True
  153. parallel = False
  154. nopython = True
  155. def func_kwargs(values, index):
  156. return nogil + parallel + nopython
  157. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  158. df = DataFrame({"value": [0, 0, 0]})
  159. result = df.groupby(level=0).aggregate(
  160. func_kwargs, engine="numba", engine_kwargs=engine_kwargs
  161. )
  162. expected = DataFrame({"value": [2.0, 2.0, 2.0]})
  163. tm.assert_frame_equal(result, expected)
  164. nogil = False
  165. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  166. result = df.groupby(level=0).aggregate(
  167. func_kwargs, engine="numba", engine_kwargs=engine_kwargs
  168. )
  169. expected = DataFrame({"value": [1.0, 1.0, 1.0]})
  170. tm.assert_frame_equal(result, expected)
  171. @td.skip_if_no("numba")
  172. @pytest.mark.filterwarnings("ignore")
  173. def test_multiindex_one_key(nogil, parallel, nopython):
  174. def numba_func(values, index):
  175. return 1
  176. df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
  177. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  178. result = df.groupby("A").agg(
  179. numba_func, engine="numba", engine_kwargs=engine_kwargs
  180. )
  181. expected = DataFrame([1.0], index=Index([1], name="A"), columns=["C"])
  182. tm.assert_frame_equal(result, expected)
  183. @td.skip_if_no("numba")
  184. def test_multiindex_multi_key_not_supported(nogil, parallel, nopython):
  185. def numba_func(values, index):
  186. return 1
  187. df = DataFrame([{"A": 1, "B": 2, "C": 3}]).set_index(["A", "B"])
  188. engine_kwargs = {"nopython": nopython, "nogil": nogil, "parallel": parallel}
  189. with pytest.raises(NotImplementedError, match="More than 1 grouping labels"):
  190. df.groupby(["A", "B"]).agg(
  191. numba_func, engine="numba", engine_kwargs=engine_kwargs
  192. )