test_decimal.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import decimal
  2. import operator
  3. import numpy as np
  4. import pytest
  5. import pandas as pd
  6. import pandas._testing as tm
  7. from pandas.api.types import infer_dtype
  8. from pandas.tests.extension import base
  9. from pandas.tests.extension.decimal.array import (
  10. DecimalArray,
  11. DecimalDtype,
  12. make_data,
  13. to_decimal,
  14. )
  15. @pytest.fixture
  16. def dtype():
  17. return DecimalDtype()
  18. @pytest.fixture
  19. def data():
  20. return DecimalArray(make_data())
  21. @pytest.fixture
  22. def data_for_twos():
  23. return DecimalArray([decimal.Decimal(2) for _ in range(100)])
  24. @pytest.fixture
  25. def data_missing():
  26. return DecimalArray([decimal.Decimal("NaN"), decimal.Decimal(1)])
  27. @pytest.fixture
  28. def data_for_sorting():
  29. return DecimalArray(
  30. [decimal.Decimal("1"), decimal.Decimal("2"), decimal.Decimal("0")]
  31. )
  32. @pytest.fixture
  33. def data_missing_for_sorting():
  34. return DecimalArray(
  35. [decimal.Decimal("1"), decimal.Decimal("NaN"), decimal.Decimal("0")]
  36. )
  37. @pytest.fixture
  38. def na_cmp():
  39. return lambda x, y: x.is_nan() and y.is_nan()
  40. @pytest.fixture
  41. def na_value():
  42. return decimal.Decimal("NaN")
  43. @pytest.fixture
  44. def data_for_grouping():
  45. b = decimal.Decimal("1.0")
  46. a = decimal.Decimal("0.0")
  47. c = decimal.Decimal("2.0")
  48. na = decimal.Decimal("NaN")
  49. return DecimalArray([b, b, na, na, a, a, b, c])
  50. class TestDtype(base.BaseDtypeTests):
  51. def test_hashable(self, dtype):
  52. pass
  53. @pytest.mark.parametrize("skipna", [True, False])
  54. def test_infer_dtype(self, data, data_missing, skipna):
  55. # here overriding base test to ensure we fall back to return
  56. # "unknown-array" for an EA pandas doesn't know
  57. assert infer_dtype(data, skipna=skipna) == "unknown-array"
  58. assert infer_dtype(data_missing, skipna=skipna) == "unknown-array"
  59. class TestInterface(base.BaseInterfaceTests):
  60. pass
  61. class TestConstructors(base.BaseConstructorsTests):
  62. pass
  63. class TestReshaping(base.BaseReshapingTests):
  64. pass
  65. class TestGetitem(base.BaseGetitemTests):
  66. def test_take_na_value_other_decimal(self):
  67. arr = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
  68. result = arr.take([0, -1], allow_fill=True, fill_value=decimal.Decimal("-1.0"))
  69. expected = DecimalArray([decimal.Decimal("1.0"), decimal.Decimal("-1.0")])
  70. self.assert_extension_array_equal(result, expected)
  71. class TestIndex(base.BaseIndexTests):
  72. pass
  73. class TestMissing(base.BaseMissingTests):
  74. pass
  75. class Reduce:
  76. def check_reduce(self, s, op_name, skipna):
  77. if op_name in ["median", "skew", "kurt", "sem"]:
  78. msg = r"decimal does not support the .* operation"
  79. with pytest.raises(NotImplementedError, match=msg):
  80. getattr(s, op_name)(skipna=skipna)
  81. elif op_name == "count":
  82. result = getattr(s, op_name)()
  83. expected = len(s) - s.isna().sum()
  84. tm.assert_almost_equal(result, expected)
  85. else:
  86. result = getattr(s, op_name)(skipna=skipna)
  87. expected = getattr(np.asarray(s), op_name)()
  88. tm.assert_almost_equal(result, expected)
  89. class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
  90. pass
  91. class TestBooleanReduce(Reduce, base.BaseBooleanReduceTests):
  92. pass
  93. class TestMethods(base.BaseMethodsTests):
  94. @pytest.mark.parametrize("dropna", [True, False])
  95. def test_value_counts(self, all_data, dropna, request):
  96. all_data = all_data[:10]
  97. if dropna:
  98. other = np.array(all_data[~all_data.isna()])
  99. else:
  100. other = all_data
  101. vcs = pd.Series(all_data).value_counts(dropna=dropna)
  102. vcs_ex = pd.Series(other).value_counts(dropna=dropna)
  103. with decimal.localcontext() as ctx:
  104. # avoid raising when comparing Decimal("NAN") < Decimal(2)
  105. ctx.traps[decimal.InvalidOperation] = False
  106. result = vcs.sort_index()
  107. expected = vcs_ex.sort_index()
  108. tm.assert_series_equal(result, expected)
  109. class TestCasting(base.BaseCastingTests):
  110. pass
  111. class TestGroupby(base.BaseGroupbyTests):
  112. pass
  113. class TestSetitem(base.BaseSetitemTests):
  114. pass
  115. class TestPrinting(base.BasePrintingTests):
  116. def test_series_repr(self, data):
  117. # Overriding this base test to explicitly test that
  118. # the custom _formatter is used
  119. ser = pd.Series(data)
  120. assert data.dtype.name in repr(ser)
  121. assert "Decimal: " in repr(ser)
  122. @pytest.mark.xfail(
  123. reason=(
  124. "DecimalArray constructor raises bc _from_sequence wants Decimals, not ints."
  125. "Easy to fix, just need to do it."
  126. ),
  127. raises=TypeError,
  128. )
  129. def test_series_constructor_coerce_data_to_extension_dtype_raises():
  130. xpr = (
  131. "Cannot cast data to extension dtype 'decimal'. Pass the "
  132. "extension array directly."
  133. )
  134. with pytest.raises(ValueError, match=xpr):
  135. pd.Series([0, 1, 2], dtype=DecimalDtype())
  136. def test_series_constructor_with_dtype():
  137. arr = DecimalArray([decimal.Decimal("10.0")])
  138. result = pd.Series(arr, dtype=DecimalDtype())
  139. expected = pd.Series(arr)
  140. tm.assert_series_equal(result, expected)
  141. result = pd.Series(arr, dtype="int64")
  142. expected = pd.Series([10])
  143. tm.assert_series_equal(result, expected)
  144. def test_dataframe_constructor_with_dtype():
  145. arr = DecimalArray([decimal.Decimal("10.0")])
  146. result = pd.DataFrame({"A": arr}, dtype=DecimalDtype())
  147. expected = pd.DataFrame({"A": arr})
  148. tm.assert_frame_equal(result, expected)
  149. arr = DecimalArray([decimal.Decimal("10.0")])
  150. result = pd.DataFrame({"A": arr}, dtype="int64")
  151. expected = pd.DataFrame({"A": [10]})
  152. tm.assert_frame_equal(result, expected)
  153. @pytest.mark.parametrize("frame", [True, False])
  154. def test_astype_dispatches(frame):
  155. # This is a dtype-specific test that ensures Series[decimal].astype
  156. # gets all the way through to ExtensionArray.astype
  157. # Designing a reliable smoke test that works for arbitrary data types
  158. # is difficult.
  159. data = pd.Series(DecimalArray([decimal.Decimal(2)]), name="a")
  160. ctx = decimal.Context()
  161. ctx.prec = 5
  162. if frame:
  163. data = data.to_frame()
  164. result = data.astype(DecimalDtype(ctx))
  165. if frame:
  166. result = result["a"]
  167. assert result.dtype.context.prec == ctx.prec
  168. class TestArithmeticOps(base.BaseArithmeticOpsTests):
  169. def check_opname(self, s, op_name, other, exc=None):
  170. super().check_opname(s, op_name, other, exc=None)
  171. def test_arith_series_with_array(self, data, all_arithmetic_operators):
  172. op_name = all_arithmetic_operators
  173. s = pd.Series(data)
  174. context = decimal.getcontext()
  175. divbyzerotrap = context.traps[decimal.DivisionByZero]
  176. invalidoptrap = context.traps[decimal.InvalidOperation]
  177. context.traps[decimal.DivisionByZero] = 0
  178. context.traps[decimal.InvalidOperation] = 0
  179. # Decimal supports ops with int, but not float
  180. other = pd.Series([int(d * 100) for d in data])
  181. self.check_opname(s, op_name, other)
  182. if "mod" not in op_name:
  183. self.check_opname(s, op_name, s * 2)
  184. self.check_opname(s, op_name, 0)
  185. self.check_opname(s, op_name, 5)
  186. context.traps[decimal.DivisionByZero] = divbyzerotrap
  187. context.traps[decimal.InvalidOperation] = invalidoptrap
  188. def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
  189. # We implement divmod
  190. super()._check_divmod_op(s, op, other, exc=None)
  191. class TestComparisonOps(base.BaseComparisonOpsTests):
  192. def test_compare_scalar(self, data, comparison_op):
  193. s = pd.Series(data)
  194. self._compare_other(s, data, comparison_op, 0.5)
  195. def test_compare_array(self, data, comparison_op):
  196. s = pd.Series(data)
  197. alter = np.random.choice([-1, 0, 1], len(data))
  198. # Randomly double, halve or keep same value
  199. other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter]
  200. self._compare_other(s, data, comparison_op, other)
  201. class DecimalArrayWithoutFromSequence(DecimalArray):
  202. """Helper class for testing error handling in _from_sequence."""
  203. @classmethod
  204. def _from_sequence(cls, scalars, dtype=None, copy=False):
  205. raise KeyError("For the test")
  206. class DecimalArrayWithoutCoercion(DecimalArrayWithoutFromSequence):
  207. @classmethod
  208. def _create_arithmetic_method(cls, op):
  209. return cls._create_method(op, coerce_to_dtype=False)
  210. DecimalArrayWithoutCoercion._add_arithmetic_ops()
  211. def test_combine_from_sequence_raises(monkeypatch):
  212. # https://github.com/pandas-dev/pandas/issues/22850
  213. cls = DecimalArrayWithoutFromSequence
  214. @classmethod
  215. def construct_array_type(cls):
  216. return DecimalArrayWithoutFromSequence
  217. monkeypatch.setattr(DecimalDtype, "construct_array_type", construct_array_type)
  218. arr = cls([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
  219. ser = pd.Series(arr)
  220. result = ser.combine(ser, operator.add)
  221. # note: object dtype
  222. expected = pd.Series(
  223. [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
  224. )
  225. tm.assert_series_equal(result, expected)
  226. @pytest.mark.parametrize(
  227. "class_", [DecimalArrayWithoutFromSequence, DecimalArrayWithoutCoercion]
  228. )
  229. def test_scalar_ops_from_sequence_raises(class_):
  230. # op(EA, EA) should return an EA, or an ndarray if it's not possible
  231. # to return an EA with the return values.
  232. arr = class_([decimal.Decimal("1.0"), decimal.Decimal("2.0")])
  233. result = arr + arr
  234. expected = np.array(
  235. [decimal.Decimal("2.0"), decimal.Decimal("4.0")], dtype="object"
  236. )
  237. tm.assert_numpy_array_equal(result, expected)
  238. @pytest.mark.parametrize(
  239. "reverse, expected_div, expected_mod",
  240. [(False, [0, 1, 1, 2], [1, 0, 1, 0]), (True, [2, 1, 0, 0], [0, 0, 2, 2])],
  241. )
  242. def test_divmod_array(reverse, expected_div, expected_mod):
  243. # https://github.com/pandas-dev/pandas/issues/22930
  244. arr = to_decimal([1, 2, 3, 4])
  245. if reverse:
  246. div, mod = divmod(2, arr)
  247. else:
  248. div, mod = divmod(arr, 2)
  249. expected_div = to_decimal(expected_div)
  250. expected_mod = to_decimal(expected_mod)
  251. tm.assert_extension_array_equal(div, expected_div)
  252. tm.assert_extension_array_equal(mod, expected_mod)
  253. def test_ufunc_fallback(data):
  254. a = data[:5]
  255. s = pd.Series(a, index=range(3, 8))
  256. result = np.abs(s)
  257. expected = pd.Series(np.abs(a), index=range(3, 8))
  258. tm.assert_series_equal(result, expected)
  259. def test_array_ufunc():
  260. a = to_decimal([1, 2, 3])
  261. result = np.exp(a)
  262. expected = to_decimal(np.exp(a._data))
  263. tm.assert_extension_array_equal(result, expected)
  264. def test_array_ufunc_series():
  265. a = to_decimal([1, 2, 3])
  266. s = pd.Series(a)
  267. result = np.exp(s)
  268. expected = pd.Series(to_decimal(np.exp(a._data)))
  269. tm.assert_series_equal(result, expected)
  270. def test_array_ufunc_series_scalar_other():
  271. # check _HANDLED_TYPES
  272. a = to_decimal([1, 2, 3])
  273. s = pd.Series(a)
  274. result = np.add(s, decimal.Decimal(1))
  275. expected = pd.Series(np.add(a, decimal.Decimal(1)))
  276. tm.assert_series_equal(result, expected)
  277. def test_array_ufunc_series_defer():
  278. a = to_decimal([1, 2, 3])
  279. s = pd.Series(a)
  280. expected = pd.Series(to_decimal([2, 4, 6]))
  281. r1 = np.add(s, a)
  282. r2 = np.add(a, s)
  283. tm.assert_series_equal(r1, expected)
  284. tm.assert_series_equal(r2, expected)
  285. def test_groupby_agg():
  286. # Ensure that the result of agg is inferred to be decimal dtype
  287. # https://github.com/pandas-dev/pandas/issues/29141
  288. data = make_data()[:5]
  289. df = pd.DataFrame(
  290. {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
  291. )
  292. # single key, selected column
  293. expected = pd.Series(to_decimal([data[0], data[3]]))
  294. result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
  295. tm.assert_series_equal(result, expected, check_names=False)
  296. result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0])
  297. tm.assert_series_equal(result, expected, check_names=False)
  298. # multiple keys, selected column
  299. expected = pd.Series(
  300. to_decimal([data[0], data[1], data[3]]),
  301. index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
  302. )
  303. result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0])
  304. tm.assert_series_equal(result, expected, check_names=False)
  305. result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0])
  306. tm.assert_series_equal(result, expected, check_names=False)
  307. # multiple columns
  308. expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
  309. result = df.groupby("id1").agg(lambda x: x.iloc[0])
  310. tm.assert_frame_equal(result, expected, check_names=False)
  311. def test_groupby_agg_ea_method(monkeypatch):
  312. # Ensure that the result of agg is inferred to be decimal dtype
  313. # https://github.com/pandas-dev/pandas/issues/29141
  314. def DecimalArray__my_sum(self):
  315. return np.sum(np.array(self))
  316. monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)
  317. data = make_data()[:5]
  318. df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
  319. expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))
  320. result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
  321. tm.assert_series_equal(result, expected, check_names=False)
  322. s = pd.Series(DecimalArray(data))
  323. grouper = np.array([0, 0, 0, 1, 1], dtype=np.int64)
  324. result = s.groupby(grouper).agg(lambda x: x.values.my_sum())
  325. tm.assert_series_equal(result, expected, check_names=False)
  326. def test_indexing_no_materialize(monkeypatch):
  327. # See https://github.com/pandas-dev/pandas/issues/29708
  328. # Ensure that indexing operations do not materialize (convert to a numpy
  329. # array) the ExtensionArray unnecessary
  330. def DecimalArray__array__(self, dtype=None):
  331. raise Exception("tried to convert a DecimalArray to a numpy array")
  332. monkeypatch.setattr(DecimalArray, "__array__", DecimalArray__array__, raising=False)
  333. data = make_data()
  334. s = pd.Series(DecimalArray(data))
  335. df = pd.DataFrame({"a": s, "b": range(len(s))})
  336. # ensure the following operations do not raise an error
  337. s[s > 0.5]
  338. df[s > 0.5]
  339. s.at[0]
  340. df.at[0, "a"]
  341. def test_to_numpy_keyword():
  342. # test the extra keyword
  343. values = [decimal.Decimal("1.1111"), decimal.Decimal("2.2222")]
  344. expected = np.array(
  345. [decimal.Decimal("1.11"), decimal.Decimal("2.22")], dtype="object"
  346. )
  347. a = pd.array(values, dtype="decimal")
  348. result = a.to_numpy(decimals=2)
  349. tm.assert_numpy_array_equal(result, expected)
  350. result = pd.Series(a).to_numpy(decimals=2)
  351. tm.assert_numpy_array_equal(result, expected)
  352. def test_array_copy_on_write(using_copy_on_write):
  353. df = pd.DataFrame({"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype="object")
  354. df2 = df.astype(DecimalDtype())
  355. df.iloc[0, 0] = 0
  356. if using_copy_on_write:
  357. expected = pd.DataFrame(
  358. {"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype()
  359. )
  360. tm.assert_equal(df2.values, expected.values)