test_boolean.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. """
  2. This file contains a minimal set of tests for compliance with the extension
  3. array interface test suite, and should contain no other tests.
  4. The test suite for the full functionality of the array is located in
  5. `pandas/tests/arrays/`.
  6. The tests in this file are inherited from the BaseExtensionTests, and only
  7. minimal tweaks should be applied to get the tests passing (by overwriting a
  8. parent method).
  9. Additional tests should either be added to one of the BaseExtensionTests
  10. classes (if they are relevant for the extension interface for all dtypes), or
  11. be added to the array-specific tests in `pandas/tests/arrays/`.
  12. """
  13. import numpy as np
  14. import pytest
  15. from pandas.core.dtypes.common import is_bool_dtype
  16. import pandas as pd
  17. import pandas._testing as tm
  18. from pandas.core.arrays.boolean import BooleanDtype
  19. from pandas.tests.extension import base
  20. def make_data():
  21. return [True, False] * 4 + [np.nan] + [True, False] * 44 + [np.nan] + [True, False]
  22. @pytest.fixture
  23. def dtype():
  24. return BooleanDtype()
  25. @pytest.fixture
  26. def data(dtype):
  27. return pd.array(make_data(), dtype=dtype)
  28. @pytest.fixture
  29. def data_for_twos(dtype):
  30. return pd.array(np.ones(100), dtype=dtype)
  31. @pytest.fixture
  32. def data_missing(dtype):
  33. return pd.array([np.nan, True], dtype=dtype)
  34. @pytest.fixture
  35. def data_for_sorting(dtype):
  36. return pd.array([True, True, False], dtype=dtype)
  37. @pytest.fixture
  38. def data_missing_for_sorting(dtype):
  39. return pd.array([True, np.nan, False], dtype=dtype)
  40. @pytest.fixture
  41. def na_cmp():
  42. # we are pd.NA
  43. return lambda x, y: x is pd.NA and y is pd.NA
  44. @pytest.fixture
  45. def na_value():
  46. return pd.NA
  47. @pytest.fixture
  48. def data_for_grouping(dtype):
  49. b = True
  50. a = False
  51. na = np.nan
  52. return pd.array([b, b, na, na, a, a, b], dtype=dtype)
  53. class TestDtype(base.BaseDtypeTests):
  54. pass
  55. class TestInterface(base.BaseInterfaceTests):
  56. pass
  57. class TestConstructors(base.BaseConstructorsTests):
  58. pass
  59. class TestGetitem(base.BaseGetitemTests):
  60. pass
  61. class TestSetitem(base.BaseSetitemTests):
  62. pass
  63. class TestIndex(base.BaseIndexTests):
  64. pass
  65. class TestMissing(base.BaseMissingTests):
  66. pass
  67. class TestArithmeticOps(base.BaseArithmeticOpsTests):
  68. implements = {"__sub__", "__rsub__"}
  69. def check_opname(self, s, op_name, other, exc=None):
  70. # overwriting to indicate ops don't raise an error
  71. exc = None
  72. if op_name.strip("_").lstrip("r") in ["pow", "truediv", "floordiv"]:
  73. # match behavior with non-masked bool dtype
  74. exc = NotImplementedError
  75. super().check_opname(s, op_name, other, exc=exc)
  76. def _check_op(self, obj, op, other, op_name, exc=NotImplementedError):
  77. if exc is None:
  78. if op_name in self.implements:
  79. msg = r"numpy boolean subtract"
  80. with pytest.raises(TypeError, match=msg):
  81. op(obj, other)
  82. return
  83. result = op(obj, other)
  84. expected = self._combine(obj, other, op)
  85. if op_name in (
  86. "__floordiv__",
  87. "__rfloordiv__",
  88. "__pow__",
  89. "__rpow__",
  90. "__mod__",
  91. "__rmod__",
  92. ):
  93. # combine keeps boolean type
  94. expected = expected.astype("Int8")
  95. elif op_name in ("__truediv__", "__rtruediv__"):
  96. # combine with bools does not generate the correct result
  97. # (numpy behaviour for div is to regard the bools as numeric)
  98. expected = self._combine(obj.astype(float), other, op)
  99. expected = expected.astype("Float64")
  100. if op_name == "__rpow__":
  101. # for rpow, combine does not propagate NaN
  102. expected[result.isna()] = np.nan
  103. self.assert_equal(result, expected)
  104. else:
  105. with pytest.raises(exc):
  106. op(obj, other)
  107. @pytest.mark.xfail(
  108. reason="Inconsistency between floordiv and divmod; we raise for floordiv "
  109. "but not for divmod. This matches what we do for non-masked bool dtype."
  110. )
  111. def test_divmod_series_array(self, data, data_for_twos):
  112. super().test_divmod_series_array(data, data_for_twos)
  113. @pytest.mark.xfail(
  114. reason="Inconsistency between floordiv and divmod; we raise for floordiv "
  115. "but not for divmod. This matches what we do for non-masked bool dtype."
  116. )
  117. def test_divmod(self, data):
  118. super().test_divmod(data)
  119. class TestComparisonOps(base.BaseComparisonOpsTests):
  120. def check_opname(self, s, op_name, other, exc=None):
  121. # overwriting to indicate ops don't raise an error
  122. super().check_opname(s, op_name, other, exc=None)
  123. class TestReshaping(base.BaseReshapingTests):
  124. pass
  125. class TestMethods(base.BaseMethodsTests):
  126. _combine_le_expected_dtype = "boolean"
  127. def test_factorize(self, data_for_grouping):
  128. # override because we only have 2 unique values
  129. labels, uniques = pd.factorize(data_for_grouping, use_na_sentinel=True)
  130. expected_labels = np.array([0, 0, -1, -1, 1, 1, 0], dtype=np.intp)
  131. expected_uniques = data_for_grouping.take([0, 4])
  132. tm.assert_numpy_array_equal(labels, expected_labels)
  133. self.assert_extension_array_equal(uniques, expected_uniques)
  134. def test_searchsorted(self, data_for_sorting, as_series):
  135. # override because we only have 2 unique values
  136. data_for_sorting = pd.array([True, False], dtype="boolean")
  137. b, a = data_for_sorting
  138. arr = type(data_for_sorting)._from_sequence([a, b])
  139. if as_series:
  140. arr = pd.Series(arr)
  141. assert arr.searchsorted(a) == 0
  142. assert arr.searchsorted(a, side="right") == 1
  143. assert arr.searchsorted(b) == 1
  144. assert arr.searchsorted(b, side="right") == 2
  145. result = arr.searchsorted(arr.take([0, 1]))
  146. expected = np.array([0, 1], dtype=np.intp)
  147. tm.assert_numpy_array_equal(result, expected)
  148. # sorter
  149. sorter = np.array([1, 0])
  150. assert data_for_sorting.searchsorted(a, sorter=sorter) == 0
  151. def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting):
  152. # override because there are only 2 unique values
  153. # data_for_sorting -> [B, C, A] with A < B < C -> here True, True, False
  154. assert data_for_sorting.argmax() == 0
  155. assert data_for_sorting.argmin() == 2
  156. # with repeated values -> first occurrence
  157. data = data_for_sorting.take([2, 0, 0, 1, 1, 2])
  158. assert data.argmax() == 1
  159. assert data.argmin() == 0
  160. # with missing values
  161. # data_missing_for_sorting -> [B, NA, A] with A < B and NA missing.
  162. assert data_missing_for_sorting.argmax() == 0
  163. assert data_missing_for_sorting.argmin() == 2
  164. class TestCasting(base.BaseCastingTests):
  165. pass
  166. class TestGroupby(base.BaseGroupbyTests):
  167. """
  168. Groupby-specific tests are overridden because boolean only has 2
  169. unique values, base tests uses 3 groups.
  170. """
  171. def test_grouping_grouper(self, data_for_grouping):
  172. df = pd.DataFrame(
  173. {"A": ["B", "B", None, None, "A", "A", "B"], "B": data_for_grouping}
  174. )
  175. gr1 = df.groupby("A").grouper.groupings[0]
  176. gr2 = df.groupby("B").grouper.groupings[0]
  177. tm.assert_numpy_array_equal(gr1.grouping_vector, df.A.values)
  178. tm.assert_extension_array_equal(gr2.grouping_vector, data_for_grouping)
  179. @pytest.mark.parametrize("as_index", [True, False])
  180. def test_groupby_extension_agg(self, as_index, data_for_grouping):
  181. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  182. result = df.groupby("B", as_index=as_index).A.mean()
  183. _, uniques = pd.factorize(data_for_grouping, sort=True)
  184. if as_index:
  185. index = pd.Index(uniques, name="B")
  186. expected = pd.Series([3.0, 1.0], index=index, name="A")
  187. self.assert_series_equal(result, expected)
  188. else:
  189. expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0]})
  190. self.assert_frame_equal(result, expected)
  191. def test_groupby_agg_extension(self, data_for_grouping):
  192. # GH#38980 groupby agg on extension type fails for non-numeric types
  193. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  194. expected = df.iloc[[0, 2, 4]]
  195. expected = expected.set_index("A")
  196. result = df.groupby("A").agg({"B": "first"})
  197. self.assert_frame_equal(result, expected)
  198. result = df.groupby("A").agg("first")
  199. self.assert_frame_equal(result, expected)
  200. result = df.groupby("A").first()
  201. self.assert_frame_equal(result, expected)
  202. def test_groupby_extension_no_sort(self, data_for_grouping):
  203. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  204. result = df.groupby("B", sort=False).A.mean()
  205. _, index = pd.factorize(data_for_grouping, sort=False)
  206. index = pd.Index(index, name="B")
  207. expected = pd.Series([1.0, 3.0], index=index, name="A")
  208. self.assert_series_equal(result, expected)
  209. def test_groupby_extension_transform(self, data_for_grouping):
  210. valid = data_for_grouping[~data_for_grouping.isna()]
  211. df = pd.DataFrame({"A": [1, 1, 3, 3, 1], "B": valid})
  212. result = df.groupby("B").A.transform(len)
  213. expected = pd.Series([3, 3, 2, 2, 3], name="A")
  214. self.assert_series_equal(result, expected)
  215. def test_groupby_extension_apply(self, data_for_grouping, groupby_apply_op):
  216. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  217. df.groupby("B", group_keys=False).apply(groupby_apply_op)
  218. df.groupby("B", group_keys=False).A.apply(groupby_apply_op)
  219. df.groupby("A", group_keys=False).apply(groupby_apply_op)
  220. df.groupby("A", group_keys=False).B.apply(groupby_apply_op)
  221. def test_groupby_apply_identity(self, data_for_grouping):
  222. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  223. result = df.groupby("A").B.apply(lambda x: x.array)
  224. expected = pd.Series(
  225. [
  226. df.B.iloc[[0, 1, 6]].array,
  227. df.B.iloc[[2, 3]].array,
  228. df.B.iloc[[4, 5]].array,
  229. ],
  230. index=pd.Index([1, 2, 3], name="A"),
  231. name="B",
  232. )
  233. self.assert_series_equal(result, expected)
  234. def test_in_numeric_groupby(self, data_for_grouping):
  235. df = pd.DataFrame(
  236. {
  237. "A": [1, 1, 2, 2, 3, 3, 1],
  238. "B": data_for_grouping,
  239. "C": [1, 1, 1, 1, 1, 1, 1],
  240. }
  241. )
  242. result = df.groupby("A").sum().columns
  243. if data_for_grouping.dtype._is_numeric:
  244. expected = pd.Index(["B", "C"])
  245. else:
  246. expected = pd.Index(["C"])
  247. tm.assert_index_equal(result, expected)
  248. @pytest.mark.parametrize("min_count", [0, 10])
  249. def test_groupby_sum_mincount(self, data_for_grouping, min_count):
  250. df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1], "B": data_for_grouping})
  251. result = df.groupby("A").sum(min_count=min_count)
  252. if min_count == 0:
  253. expected = pd.DataFrame(
  254. {"B": pd.array([3, 0, 0], dtype="Int64")},
  255. index=pd.Index([1, 2, 3], name="A"),
  256. )
  257. tm.assert_frame_equal(result, expected)
  258. else:
  259. expected = pd.DataFrame(
  260. {"B": pd.array([pd.NA] * 3, dtype="Int64")},
  261. index=pd.Index([1, 2, 3], name="A"),
  262. )
  263. tm.assert_frame_equal(result, expected)
  264. class TestNumericReduce(base.BaseNumericReduceTests):
  265. def check_reduce(self, s, op_name, skipna):
  266. if op_name == "count":
  267. result = getattr(s, op_name)()
  268. expected = getattr(s.astype("float64"), op_name)()
  269. else:
  270. result = getattr(s, op_name)(skipna=skipna)
  271. expected = getattr(s.astype("float64"), op_name)(skipna=skipna)
  272. # override parent function to cast to bool for min/max
  273. if np.isnan(expected):
  274. expected = pd.NA
  275. elif op_name in ("min", "max"):
  276. expected = bool(expected)
  277. tm.assert_almost_equal(result, expected)
  278. class TestBooleanReduce(base.BaseBooleanReduceTests):
  279. pass
  280. class TestPrinting(base.BasePrintingTests):
  281. pass
  282. class TestUnaryOps(base.BaseUnaryOpsTests):
  283. pass
  284. class TestAccumulation(base.BaseAccumulateTests):
  285. def check_accumulate(self, s, op_name, skipna):
  286. result = getattr(s, op_name)(skipna=skipna)
  287. expected = getattr(pd.Series(s.astype("float64")), op_name)(skipna=skipna)
  288. tm.assert_series_equal(result, expected, check_dtype=False)
  289. if op_name in ("cummin", "cummax"):
  290. assert is_bool_dtype(result)
  291. @pytest.mark.parametrize("skipna", [True, False])
  292. def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
  293. pass
  294. class TestParsing(base.BaseParsingTests):
  295. pass
  296. class Test2DCompat(base.Dim2CompatTests):
  297. pass