test_numba.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. import pytest
  2. import pandas.util._test_decorators as td
  3. from pandas import (
  4. DataFrame,
  5. Series,
  6. )
  7. import pandas._testing as tm
  8. @td.skip_if_no("numba")
  9. @pytest.mark.filterwarnings("ignore")
  10. # Filter warnings when parallel=True and the function can't be parallelized by Numba
  11. class TestEngine:
  12. def test_cython_vs_numba_frame(
  13. self, sort, nogil, parallel, nopython, numba_supported_reductions
  14. ):
  15. func, kwargs = numba_supported_reductions
  16. df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
  17. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  18. gb = df.groupby("a", sort=sort)
  19. result = getattr(gb, func)(
  20. engine="numba", engine_kwargs=engine_kwargs, **kwargs
  21. )
  22. expected = getattr(gb, func)(**kwargs)
  23. # check_dtype can be removed if GH 44952 is addressed
  24. check_dtype = func not in ("sum", "min", "max")
  25. tm.assert_frame_equal(result, expected, check_dtype=check_dtype)
  26. def test_cython_vs_numba_getitem(
  27. self, sort, nogil, parallel, nopython, numba_supported_reductions
  28. ):
  29. func, kwargs = numba_supported_reductions
  30. df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
  31. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  32. gb = df.groupby("a", sort=sort)["c"]
  33. result = getattr(gb, func)(
  34. engine="numba", engine_kwargs=engine_kwargs, **kwargs
  35. )
  36. expected = getattr(gb, func)(**kwargs)
  37. # check_dtype can be removed if GH 44952 is addressed
  38. check_dtype = func not in ("sum", "min", "max")
  39. tm.assert_series_equal(result, expected, check_dtype=check_dtype)
  40. def test_cython_vs_numba_series(
  41. self, sort, nogil, parallel, nopython, numba_supported_reductions
  42. ):
  43. func, kwargs = numba_supported_reductions
  44. ser = Series(range(3), index=[1, 2, 1], name="foo")
  45. engine_kwargs = {"nogil": nogil, "parallel": parallel, "nopython": nopython}
  46. gb = ser.groupby(level=0, sort=sort)
  47. result = getattr(gb, func)(
  48. engine="numba", engine_kwargs=engine_kwargs, **kwargs
  49. )
  50. expected = getattr(gb, func)(**kwargs)
  51. # check_dtype can be removed if GH 44952 is addressed
  52. check_dtype = func not in ("sum", "min", "max")
  53. tm.assert_series_equal(result, expected, check_dtype=check_dtype)
  54. def test_as_index_false_unsupported(self, numba_supported_reductions):
  55. func, kwargs = numba_supported_reductions
  56. df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
  57. gb = df.groupby("a", as_index=False)
  58. with pytest.raises(NotImplementedError, match="as_index=False"):
  59. getattr(gb, func)(engine="numba", **kwargs)
  60. def test_axis_1_unsupported(self, numba_supported_reductions):
  61. func, kwargs = numba_supported_reductions
  62. df = DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})
  63. gb = df.groupby("a", axis=1)
  64. with pytest.raises(NotImplementedError, match="axis=1"):
  65. getattr(gb, func)(engine="numba", **kwargs)