123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495 |
- import decimal
- import operator
- import numpy as np
- import pytest
- import pandas as pd
- import pandas._testing as tm
- from pandas.api.types import infer_dtype
- from pandas.tests.extension import base
- from pandas.tests.extension.decimal.array import (
- DecimalArray,
- DecimalDtype,
- make_data,
- to_decimal,
- )
@pytest.fixture
def dtype():
    """Provide a newly constructed ``DecimalDtype`` instance."""
    decimal_dtype = DecimalDtype()
    return decimal_dtype
@pytest.fixture
def data():
    """Provide a ``DecimalArray`` built from the values ``make_data()`` returns."""
    values = make_data()
    return DecimalArray(values)
@pytest.fixture
def data_for_twos():
    """Provide a ``DecimalArray`` of one hundred ``Decimal(2)`` values."""
    # Build a separate Decimal object for each of the 100 slots.
    twos = list(map(decimal.Decimal, [2] * 100))
    return DecimalArray(twos)
@pytest.fixture
def data_missing():
    """Provide a two-element ``DecimalArray``: a NaN followed by ``Decimal(1)``."""
    missing = decimal.Decimal("NaN")
    present = decimal.Decimal(1)
    return DecimalArray([missing, present])
@pytest.fixture
def data_for_sorting():
    """Provide a ``DecimalArray`` holding 1, 2, 0 (in that order)."""
    unsorted = [decimal.Decimal(text) for text in ("1", "2", "0")]
    return DecimalArray(unsorted)
@pytest.fixture
def data_missing_for_sorting():
    """Provide a ``DecimalArray`` holding 1, NaN, 0 (in that order)."""
    unsorted = [decimal.Decimal(text) for text in ("1", "NaN", "0")]
    return DecimalArray(unsorted)
@pytest.fixture
def na_cmp():
    """Provide a binary predicate that holds when both arguments are NaN.

    The predicate calls ``is_nan()`` on each argument, so both must expose
    that method (as ``decimal.Decimal`` does).
    """

    def both_nan(x, y):
        return x.is_nan() and y.is_nan()

    return both_nan
@pytest.fixture
def na_value():
    """Provide the missing-value scalar used by these tests: ``Decimal("NaN")``."""
    missing = decimal.Decimal("NaN")
    return missing
@pytest.fixture
def data_for_grouping():
    """Provide an eight-element ``DecimalArray`` laid out as B, B, NA, NA, A, A, B, C.

    Here A is 0.0, B is 1.0, C is 2.0 and NA is a Decimal NaN.
    """
    low, mid, high = (decimal.Decimal(text) for text in ("0.0", "1.0", "2.0"))
    missing = decimal.Decimal("NaN")
    return DecimalArray([mid, mid, missing, missing, low, low, mid, high])
class TestDtype(base.BaseDtypeTests):
    """Dtype tests: the inherited suite plus two methods defined here."""

    def test_hashable(self, dtype):
        """Do nothing.

        NOTE(review): presumably this replaces a same-named test inherited
        from ``BaseDtypeTests`` -- confirm against the base class.
        """

    @pytest.mark.parametrize("skipna", [True, False])
    def test_infer_dtype(self, data, data_missing, skipna):
        """Check that ``infer_dtype`` reports ``"unknown-array"`` for both fixtures."""
        # here overriding base test to ensure we fall back to return
        # "unknown-array" for an EA pandas doesn't know
        for values in (data, data_missing):
            assert infer_dtype(values, skipna=skipna) == "unknown-array"
class TestInterface(base.BaseInterfaceTests):
    """Inherit ``BaseInterfaceTests`` without adding or overriding anything."""
class TestConstructors(base.BaseConstructorsTests):
    """Inherit ``BaseConstructorsTests`` without adding or overriding anything."""
class TestReshaping(base.BaseReshapingTests):
    """Inherit ``BaseReshapingTests`` without adding or overriding anything."""
class TestGetitem(base.BaseGetitemTests):
    """Getitem tests: the inherited suite plus a decimal-specific ``take`` check."""

    def test_take_na_value_other_decimal(self):
        """Check ``take`` with ``allow_fill=True`` and a Decimal ``fill_value``.

        Index ``-1`` is expected to produce the supplied fill value.
        """
        one, two, fill = (decimal.Decimal(text) for text in ("1.0", "2.0", "-1.0"))
        taken = DecimalArray([one, two]).take(
            [0, -1], allow_fill=True, fill_value=fill
        )
        self.assert_extension_array_equal(taken, DecimalArray([one, fill]))
class TestIndex(base.BaseIndexTests):
    """Inherit ``BaseIndexTests`` without adding or overriding anything."""
class TestMissing(base.BaseMissingTests):
    """Inherit ``BaseMissingTests`` without adding or overriding anything."""
class Reduce:
    """Mixin providing ``check_reduce`` for the reduction test classes."""

    def check_reduce(self, s, op_name, skipna):
        """Verify the reduction ``op_name`` on ``s``.

        * ``median``/``skew``/``kurt``/``sem`` must raise ``NotImplementedError``.
        * ``count`` must equal the length minus the number of missing entries.
        * anything else is compared with the same-named method of
          ``np.asarray(s)``.
        """
        reduction = getattr(s, op_name)

        if op_name in ["median", "skew", "kurt", "sem"]:
            msg = r"decimal does not support the .* operation"
            with pytest.raises(NotImplementedError, match=msg):
                reduction(skipna=skipna)
            return

        if op_name == "count":
            # count takes no skipna argument here
            result = reduction()
            expected = len(s) - s.isna().sum()
        else:
            result = reduction(skipna=skipna)
            expected = getattr(np.asarray(s), op_name)()

        tm.assert_almost_equal(result, expected)
class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
    """Combine ``Reduce`` with ``BaseNumericReduceTests``; nothing is added here."""
class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
    """Combine ``Reduce`` with ``BaseBooleanReduceTests``; nothing is added here."""
class TestMethods(base.BaseMethodsTests):
    """Methods tests: the inherited suite with a custom ``value_counts`` test."""

    @pytest.mark.parametrize("dropna", [True, False])
    def test_value_counts(self, all_data, dropna, request):
        """Compare ``value_counts`` of the array with that of a reference input.

        Only the first ten elements are used. With ``dropna`` the reference
        is an ndarray of the non-missing entries; otherwise it is the array
        itself.
        """
        head = all_data[:10]
        other = np.array(head[~head.isna()]) if dropna else head

        counts = pd.Series(head).value_counts(dropna=dropna)
        counts_ref = pd.Series(other).value_counts(dropna=dropna)

        with decimal.localcontext() as ctx:
            # avoid raising when comparing Decimal("NAN") < Decimal(2)
            ctx.traps[decimal.InvalidOperation] = False
            result = counts.sort_index()
            expected = counts_ref.sort_index()

        tm.assert_series_equal(result, expected)
class TestCasting(base.BaseCastingTests):
    """Inherit ``BaseCastingTests`` without adding or overriding anything."""
class TestGroupby(base.BaseGroupbyTests):
    """Inherit ``BaseGroupbyTests`` without adding or overriding anything."""
class TestSetitem(base.BaseSetitemTests):
    """Inherit ``BaseSetitemTests`` without adding or overriding anything."""
class TestPrinting(base.BasePrintingTests):
    """Printing tests: the inherited suite with a custom Series repr check."""

    def test_series_repr(self, data):
        """Check the Series repr shows the dtype name and the ``Decimal: `` prefix."""
        # Overriding this base test to explicitly test that
        # the custom _formatter is used
        rendered = repr(pd.Series(data))
        assert data.dtype.name in rendered
        assert "Decimal: " in rendered
@pytest.mark.xfail(
    reason=(
        # The literals are concatenated, so the separating space must be explicit.
        "DecimalArray constructor raises bc _from_sequence wants Decimals, not ints. "
        "Easy to fix, just need to do it."
    ),
    raises=TypeError,
)
def test_series_constructor_coerce_data_to_extension_dtype_raises():
    """Expect ``ValueError`` when building a decimal Series from plain ints.

    Marked as an expected failure: per the xfail reason, construction
    currently fails with ``TypeError`` instead.
    """
    xpr = (
        "Cannot cast data to extension dtype 'decimal'. Pass the "
        "extension array directly."
    )
    with pytest.raises(ValueError, match=xpr):
        pd.Series([0, 1, 2], dtype=DecimalDtype())
def test_series_constructor_with_dtype():
    """Check ``pd.Series`` construction from a ``DecimalArray`` with explicit dtypes."""
    arr = DecimalArray([decimal.Decimal("10.0")])

    # Passing the extension dtype should match construction without a dtype.
    with_ea_dtype = pd.Series(arr, dtype=DecimalDtype())
    tm.assert_series_equal(with_ea_dtype, pd.Series(arr))

    # Passing "int64" should match a Series built from the plain integer.
    with_int_dtype = pd.Series(arr, dtype="int64")
    tm.assert_series_equal(with_int_dtype, pd.Series([10]))
def test_dataframe_constructor_with_dtype():
    """Check ``pd.DataFrame`` construction from a ``DecimalArray`` with explicit dtypes."""
    # Passing the extension dtype should match construction without a dtype.
    arr = DecimalArray([decimal.Decimal("10.0")])
    with_ea_dtype = pd.DataFrame({"A": arr}, dtype=DecimalDtype())
    tm.assert_frame_equal(with_ea_dtype, pd.DataFrame({"A": arr}))

    # Passing "int64" should match a frame built from the plain integer.
    arr = DecimalArray([decimal.Decimal("10.0")])
    with_int_dtype = pd.DataFrame({"A": arr}, dtype="int64")
    tm.assert_frame_equal(with_int_dtype, pd.DataFrame({"A": [10]}))
@pytest.mark.parametrize("frame", [True, False])
def test_astype_dispatches(frame):
    """Check that ``astype`` with a custom-context dtype keeps that context.

    Runs against a Series and, when ``frame`` is true, a one-column frame.
    """
    # This is a dtype-specific test that ensures Series[decimal].astype
    # gets all the way through to ExtensionArray.astype
    # Designing a reliable smoke test that works for arbitrary data types
    # is difficult.
    ctx = decimal.Context(prec=5)
    ser = pd.Series(DecimalArray([decimal.Decimal(2)]), name="a")

    obj = ser.to_frame() if frame else ser
    converted = obj.astype(DecimalDtype(ctx))
    column = converted["a"] if frame else converted

    assert column.dtype.context.prec == ctx.prec
class TestArithmeticOps(base.BaseArithmeticOpsTests):
    """Arithmetic-operator tests: the inherited suite with decimal overrides."""

    def check_opname(self, s, op_name, other, exc=None):
        """Delegate to the base check, always passing ``exc=None``.

        The ``exc`` argument is accepted but deliberately not forwarded.
        """
        super().check_opname(s, op_name, other, exc=None)

    def test_arith_series_with_array(self, data, all_arithmetic_operators):
        """Check the operator against Series and scalar operands.

        ``DivisionByZero`` and ``InvalidOperation`` traps are switched off on
        the current decimal context while the checks run.
        """
        op_name = all_arithmetic_operators
        s = pd.Series(data)

        context = decimal.getcontext()
        divbyzerotrap = context.traps[decimal.DivisionByZero]
        invalidoptrap = context.traps[decimal.InvalidOperation]
        context.traps[decimal.DivisionByZero] = 0
        context.traps[decimal.InvalidOperation] = 0

        try:
            # Decimal supports ops with int, but not float
            other = pd.Series([int(d * 100) for d in data])
            self.check_opname(s, op_name, other)

            if "mod" not in op_name:
                self.check_opname(s, op_name, s * 2)

            self.check_opname(s, op_name, 0)
            self.check_opname(s, op_name, 5)
        finally:
            # Restore the traps even when a check fails, so the disabled
            # traps cannot leak into tests that run afterwards.
            context.traps[decimal.DivisionByZero] = divbyzerotrap
            context.traps[decimal.InvalidOperation] = invalidoptrap

    def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
        """Delegate to the base divmod check, always passing ``exc=None``."""
        # We implement divmod
        super()._check_divmod_op(s, op, other, exc=None)
class TestComparisonOps(base.BaseComparisonOpsTests):
    """Comparison-operator tests with decimal-specific operands."""

    def test_compare_scalar(self, data, comparison_op):
        """Compare a Series of ``data`` against the scalar ``0.5``."""
        ser = pd.Series(data)
        self._compare_other(ser, data, comparison_op, 0.5)

    def test_compare_array(self, data, comparison_op):
        """Compare a Series of ``data`` against a randomly rescaled copy."""
        ser = pd.Series(data)

        # Randomly double, halve or keep same value
        exponents = np.random.choice([-1, 0, 1], len(data))
        factors = [decimal.Decimal(pow(2.0, exponent)) for exponent in exponents]
        other = pd.Series(data) * factors

        self._compare_other(ser, data, comparison_op, other)
class DecimalArrayWithoutFromSequence(DecimalArray):
    """``DecimalArray`` variant whose ``_from_sequence`` always fails.

    Used to test how callers handle an error raised by ``_from_sequence``.
    """

    @classmethod
    def _from_sequence(cls, scalars, dtype=None, copy=False):
        """Raise ``KeyError`` unconditionally; every argument is ignored."""
        raise KeyError("For the test")
class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
    """Variant whose arithmetic methods are created with ``coerce_to_dtype=False``."""

    @classmethod
    def _create_arithmetic_method(cls, op):
        """Build the arithmetic method for ``op`` with dtype coercion disabled."""
        method = cls._create_method(op, coerce_to_dtype=False)
        return method


# Install the arithmetic operators produced by the override above.
DecimalArrayWithoutCoercion._add_arithmetic_ops()
def test_combine_from_sequence_raises(monkeypatch):
    """Check ``Series.combine`` yields object dtype when ``_from_sequence`` raises."""
    # https://github.com/pandas-dev/pandas/issues/22850
    array_cls = DecimalArrayWithoutFromSequence

    # Make the dtype hand out the failing array class for this test only.
    monkeypatch.setattr(
        DecimalDtype,
        "construct_array_type",
        classmethod(lambda cls: DecimalArrayWithoutFromSequence),
    )

    ser = pd.Series(array_cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")]))
    result = ser.combine(ser, operator.add)

    # note: object dtype
    doubled = [decimal.Decimal("2.0"), decimal.Decimal("4.0")]
    expected = pd.Series(doubled, dtype="object")
    tm.assert_series_equal(result, expected)
@pytest.mark.parametrize(
    "class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion]
)
def test_scalar_ops_from_sequence_raises(class_):
    """Check that adding two arrays of ``class_`` gives an object ndarray."""
    # op(EA, EA) should return an EA, or an ndarray if it's not possible
    # to return an EA with the return values.
    operand = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
    total = operand + operand

    doubled = [decimal.Decimal("2.0"), decimal.Decimal("4.0")]
    tm.assert_numpy_array_equal(total, np.array(doubled, dtype="object"))
@pytest.mark.parametrize(
    "reverse, expected_div, expected_mod",
    [(False, [0, 1, 1, 2], [1, 0, 1, 0]), (True, [2, 1, 0, 0], [0, 0, 2, 2])],
)
def test_divmod_array(reverse, expected_div, expected_mod):
    """Check ``divmod`` between the array and the scalar 2, in either order."""
    # https://github.com/pandas-dev/pandas/issues/22930
    arr = to_decimal([1, 2, 3, 4])

    # ``reverse`` puts the scalar on the left-hand side.
    operands = (2, arr) if reverse else (arr, 2)
    quotient, remainder = divmod(*operands)

    tm.assert_extension_array_equal(quotient, to_decimal(expected_div))
    tm.assert_extension_array_equal(remainder, to_decimal(expected_mod))
def test_ufunc_fallback(data):
    """Check ``np.abs`` on a Series matches ``np.abs`` on the underlying array."""
    head = data[:5]
    index = range(3, 8)

    result = np.abs(pd.Series(head, index=index))
    expected = pd.Series(np.abs(head), index=range(3, 8))

    tm.assert_series_equal(result, expected)
def test_array_ufunc():
    """Check ``np.exp`` on the array matches ``np.exp`` on its ``_data``."""
    arr = to_decimal([1, 2, 3])
    via_array = np.exp(arr)
    via_data = to_decimal(np.exp(arr._data))
    tm.assert_extension_array_equal(via_array, via_data)
def test_array_ufunc_series():
    """Check ``np.exp`` on a Series matches ``np.exp`` on the array's ``_data``."""
    arr = to_decimal([1, 2, 3])
    via_series = np.exp(pd.Series(arr))
    via_data = pd.Series(to_decimal(np.exp(arr._data)))
    tm.assert_series_equal(via_series, via_data)
def test_array_ufunc_series_scalar_other():
    """Check ``np.add(Series, Decimal)`` matches ``np.add(array, Decimal)``."""
    # check _HANDLED_TYPES
    arr = to_decimal([1, 2, 3])
    via_series = np.add(pd.Series(arr), decimal.Decimal(1))
    via_array = pd.Series(np.add(arr, decimal.Decimal(1)))
    tm.assert_series_equal(via_series, via_array)
def test_array_ufunc_series_defer():
    """Check ``np.add`` of a Series and an array gives a Series in either order."""
    arr = to_decimal([1, 2, 3])
    ser = pd.Series(arr)
    expected = pd.Series(to_decimal([2, 4, 6]))

    series_first = np.add(ser, arr)
    array_first = np.add(arr, ser)

    for outcome in (series_first, array_first):
        tm.assert_series_equal(outcome, expected)
def test_groupby_agg():
    """Check ``groupby(...).agg`` picking each group's first row.

    Covers a single key, multiple keys, and a whole-frame aggregation.
    """
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    def first_row(x):
        return x.iloc[0]

    data = make_data()[:5]
    df = pd.DataFrame(
        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
    )

    # single key, selected column
    expected = pd.Series(to_decimal([data[0], data[3]]))
    by_column = df.groupby("id1")["decimals"].agg(first_row)
    tm.assert_series_equal(by_column, expected, check_names=False)
    by_series = df["decimals"].groupby(df["id1"]).agg(first_row)
    tm.assert_series_equal(by_series, expected, check_names=False)

    # multiple keys, selected column
    expected = pd.Series(
        to_decimal([data[0], data[1], data[3]]),
        index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
    )
    by_column = df.groupby(["id1", "id2"])["decimals"].agg(first_row)
    tm.assert_series_equal(by_column, expected, check_names=False)
    by_series = df["decimals"].groupby([df["id1"], df["id2"]]).agg(first_row)
    tm.assert_series_equal(by_series, expected, check_names=False)

    # multiple columns
    expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
    whole_frame = df.groupby("id1").agg(first_row)
    tm.assert_frame_equal(whole_frame, expected, check_names=False)
def test_groupby_agg_ea_method(monkeypatch):
    """Check ``agg`` with a lambda that calls a method patched onto ``DecimalArray``."""
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    def _array_sum(self):
        return np.sum(np.array(self))

    # raising=False lets the attribute be added even if it does not exist yet.
    monkeypatch.setattr(DecimalArray, "my_sum", _array_sum, raising=False)

    def group_sum(x):
        return x.values.my_sum()

    data = make_data()[:5]
    expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))

    # Grouping by a frame column.
    df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
    from_frame = df.groupby("id")["decimals"].agg(group_sum)
    tm.assert_series_equal(from_frame, expected, check_names=False)

    # Grouping a Series by an ndarray of keys.
    ser = pd.Series(DecimalArray(data))
    keys = np.array([0, 0, 0, 1, 1], dtype=np.int64)
    from_series = ser.groupby(keys).agg(group_sum)
    tm.assert_series_equal(from_series, expected, check_names=False)
def test_indexing_no_materialize(monkeypatch):
    """Check that indexing works while ``DecimalArray.__array__`` always raises."""
    # See https://github.com/pandas-dev/pandas/issues/29708
    # Ensure that indexing operations do not materialize (convert to a numpy
    # array) the ExtensionArray unnecessary

    def _refuse_conversion(self, dtype=None):
        raise Exception("tried to convert a DecimalArray to a numpy array")

    monkeypatch.setattr(DecimalArray, "__array__", _refuse_conversion, raising=False)

    s = pd.Series(DecimalArray(make_data()))
    df = pd.DataFrame({"a": s, "b": range(len(s))})

    # ensure the following operations do not raise an error
    s[s > 0.5]
    df[s > 0.5]
    s.at[0]
    df.at[0, "a"]
def test_to_numpy_keyword():
    """Check that ``to_numpy`` accepts ``decimals`` on both the array and a Series."""
    # test the extra keyword
    arr = pd.array(
        [decimal.Decimal("1.1111"), decimal.Decimal("2.2222")], dtype="decimal"
    )
    expected = np.array(
        [decimal.Decimal("1.11"), decimal.Decimal("2.22")], dtype="object"
    )

    from_array = arr.to_numpy(decimals=2)
    tm.assert_numpy_array_equal(from_array, expected)

    from_series = pd.Series(arr).to_numpy(decimals=2)
    tm.assert_numpy_array_equal(from_series, expected)
def test_array_copy_on_write(using_copy_on_write):
    """Check an ``astype`` result keeps its values after the source is modified.

    The comparison is made only when ``using_copy_on_write`` is true.
    """
    originals = [decimal.Decimal(2), decimal.Decimal(3)]
    df = pd.DataFrame({"a": originals}, dtype="object")
    df2 = df.astype(DecimalDtype())

    # Modify the source frame after the conversion.
    df.iloc[0, 0] = 0

    if not using_copy_on_write:
        return

    expected = pd.DataFrame(
        {"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype()
    )
    tm.assert_equal(df2.values, expected.values)
|