12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- import pytest
- import pandas.util._test_decorators as td
- from pandas import (
- DataFrame,
- Series,
- )
- import pandas._testing as tm
# The blanket warning filter below is for the case where parallel=True is
# requested but Numba cannot parallelize the function.
@td.skip_if_no("numba")
@pytest.mark.filterwarnings("ignore")
class TestEngine:
    """Check groupby reductions run with ``engine="numba"``.

    The ``test_cython_vs_numba_*`` tests call the same reduction twice on one
    groupby object -- once with ``engine="numba"`` and once without an
    ``engine`` argument -- and require the two outputs to match.  The
    remaining tests assert that unsupported groupby configurations raise
    ``NotImplementedError`` under the numba engine.
    """

    @staticmethod
    def _sample_frame():
        """Return the small three-column frame shared by the frame tests."""
        return DataFrame({"a": [3, 2, 3, 2], "b": range(4), "c": range(1, 5)})

    @staticmethod
    def _run_both_engines(grouped, reduction, engine_kwargs):
        """Run ``reduction`` on ``grouped`` with numba, then without ``engine``.

        Parameters
        ----------
        grouped : groupby object
            Object exposing the reduction as a method.
        reduction : tuple
            ``(method name, keyword arguments)`` pair, as supplied by the
            ``numba_supported_reductions`` fixture.
        engine_kwargs : dict
            Forwarded as ``engine_kwargs`` to the numba call only.

        Returns
        -------
        tuple
            ``(numba result, default result, check_dtype)``.
        """
        name, extra = reduction
        reducer = getattr(grouped, name)
        via_numba = reducer(engine="numba", engine_kwargs=engine_kwargs, **extra)
        via_default = reducer(**extra)
        # check_dtype can be removed if GH 44952 is addressed
        strict_dtype = name not in ("sum", "min", "max")
        return via_numba, via_default, strict_dtype

    def test_cython_vs_numba_frame(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """Reduce every column of a frame grouped by column ``a``."""
        options = dict(nogil=nogil, parallel=parallel, nopython=nopython)
        grouped = self._sample_frame().groupby("a", sort=sort)
        via_numba, via_default, strict_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, options
        )
        tm.assert_frame_equal(via_numba, via_default, check_dtype=strict_dtype)

    def test_cython_vs_numba_getitem(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """Reduce the single column ``c`` selected from a frame groupby."""
        options = dict(nogil=nogil, parallel=parallel, nopython=nopython)
        grouped = self._sample_frame().groupby("a", sort=sort)["c"]
        via_numba, via_default, strict_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, options
        )
        tm.assert_series_equal(via_numba, via_default, check_dtype=strict_dtype)

    def test_cython_vs_numba_series(
        self, sort, nogil, parallel, nopython, numba_supported_reductions
    ):
        """Reduce a Series grouped by its (duplicated) index level 0."""
        options = dict(nogil=nogil, parallel=parallel, nopython=nopython)
        values = Series(range(3), index=[1, 2, 1], name="foo")
        grouped = values.groupby(level=0, sort=sort)
        via_numba, via_default, strict_dtype = self._run_both_engines(
            grouped, numba_supported_reductions, options
        )
        tm.assert_series_equal(via_numba, via_default, check_dtype=strict_dtype)

    def test_as_index_false_unsupported(self, numba_supported_reductions):
        """``as_index=False`` must raise under the numba engine."""
        name, extra = numba_supported_reductions
        grouped = self._sample_frame().groupby("a", as_index=False)
        with pytest.raises(NotImplementedError, match="as_index=False"):
            getattr(grouped, name)(engine="numba", **extra)

    def test_axis_1_unsupported(self, numba_supported_reductions):
        """``axis=1`` grouping must raise under the numba engine."""
        name, extra = numba_supported_reductions
        grouped = self._sample_frame().groupby("a", axis=1)
        with pytest.raises(NotImplementedError, match="axis=1"):
            getattr(grouped, name)(engine="numba", **extra)
|