# test_nlargest.py
"""
Note: for naming purposes, most tests are titled e.g. "test_nlargest_foo"
but are implicitly also testing nsmallest_foo.
"""
import re
from string import ascii_lowercase

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm
from pandas.util.version import Version
  11. @pytest.fixture
  12. def df_duplicates():
  13. return pd.DataFrame(
  14. {"a": [1, 2, 3, 4, 4], "b": [1, 1, 1, 1, 1], "c": [0, 1, 2, 5, 4]},
  15. index=[0, 0, 1, 1, 1],
  16. )
  17. @pytest.fixture
  18. def df_strings():
  19. return pd.DataFrame(
  20. {
  21. "a": np.random.permutation(10),
  22. "b": list(ascii_lowercase[:10]),
  23. "c": np.random.permutation(10).astype("float64"),
  24. }
  25. )
  26. @pytest.fixture
  27. def df_main_dtypes():
  28. return pd.DataFrame(
  29. {
  30. "group": [1, 1, 2],
  31. "int": [1, 2, 3],
  32. "float": [4.0, 5.0, 6.0],
  33. "string": list("abc"),
  34. "category_string": pd.Series(list("abc")).astype("category"),
  35. "category_int": [7, 8, 9],
  36. "datetime": pd.date_range("20130101", periods=3),
  37. "datetimetz": pd.date_range("20130101", periods=3, tz="US/Eastern"),
  38. "timedelta": pd.timedelta_range("1 s", periods=3, freq="s"),
  39. },
  40. columns=[
  41. "group",
  42. "int",
  43. "float",
  44. "string",
  45. "category_string",
  46. "category_int",
  47. "datetime",
  48. "datetimetz",
  49. "timedelta",
  50. ],
  51. )
  52. class TestNLargestNSmallest:
  53. # ----------------------------------------------------------------------
  54. # Top / bottom
  55. @pytest.mark.parametrize(
  56. "order",
  57. [
  58. ["a"],
  59. ["c"],
  60. ["a", "b"],
  61. ["a", "c"],
  62. ["b", "a"],
  63. ["b", "c"],
  64. ["a", "b", "c"],
  65. ["c", "a", "b"],
  66. ["c", "b", "a"],
  67. ["b", "c", "a"],
  68. ["b", "a", "c"],
  69. # dups!
  70. ["b", "c", "c"],
  71. ],
  72. )
  73. @pytest.mark.parametrize("n", range(1, 11))
  74. def test_nlargest_n(self, df_strings, nselect_method, n, order):
  75. # GH#10393
  76. df = df_strings
  77. if "b" in order:
  78. error_msg = (
  79. f"Column 'b' has dtype object, "
  80. f"cannot use method '{nselect_method}' with this dtype"
  81. )
  82. with pytest.raises(TypeError, match=error_msg):
  83. getattr(df, nselect_method)(n, order)
  84. else:
  85. ascending = nselect_method == "nsmallest"
  86. result = getattr(df, nselect_method)(n, order)
  87. expected = df.sort_values(order, ascending=ascending).head(n)
  88. tm.assert_frame_equal(result, expected)
  89. @pytest.mark.parametrize(
  90. "columns", [["group", "category_string"], ["group", "string"]]
  91. )
  92. def test_nlargest_error(self, df_main_dtypes, nselect_method, columns):
  93. df = df_main_dtypes
  94. col = columns[1]
  95. error_msg = (
  96. f"Column '{col}' has dtype {df[col].dtype}, "
  97. f"cannot use method '{nselect_method}' with this dtype"
  98. )
  99. # escape some characters that may be in the repr
  100. error_msg = (
  101. error_msg.replace("(", "\\(")
  102. .replace(")", "\\)")
  103. .replace("[", "\\[")
  104. .replace("]", "\\]")
  105. )
  106. with pytest.raises(TypeError, match=error_msg):
  107. getattr(df, nselect_method)(2, columns)
  108. def test_nlargest_all_dtypes(self, df_main_dtypes):
  109. df = df_main_dtypes
  110. df.nsmallest(2, list(set(df) - {"category_string", "string"}))
  111. df.nlargest(2, list(set(df) - {"category_string", "string"}))
  112. def test_nlargest_duplicates_on_starter_columns(self):
  113. # regression test for GH#22752
  114. df = pd.DataFrame({"a": [2, 2, 2, 1, 1, 1], "b": [1, 2, 3, 3, 2, 1]})
  115. result = df.nlargest(4, columns=["a", "b"])
  116. expected = pd.DataFrame(
  117. {"a": [2, 2, 2, 1], "b": [3, 2, 1, 3]}, index=[2, 1, 0, 3]
  118. )
  119. tm.assert_frame_equal(result, expected)
  120. result = df.nsmallest(4, columns=["a", "b"])
  121. expected = pd.DataFrame(
  122. {"a": [1, 1, 1, 2], "b": [1, 2, 3, 1]}, index=[5, 4, 3, 0]
  123. )
  124. tm.assert_frame_equal(result, expected)
  125. def test_nlargest_n_identical_values(self):
  126. # GH#15297
  127. df = pd.DataFrame({"a": [1] * 5, "b": [1, 2, 3, 4, 5]})
  128. result = df.nlargest(3, "a")
  129. expected = pd.DataFrame({"a": [1] * 3, "b": [1, 2, 3]}, index=[0, 1, 2])
  130. tm.assert_frame_equal(result, expected)
  131. result = df.nsmallest(3, "a")
  132. expected = pd.DataFrame({"a": [1] * 3, "b": [1, 2, 3]})
  133. tm.assert_frame_equal(result, expected)
  134. @pytest.mark.parametrize(
  135. "order",
  136. [["a", "b", "c"], ["c", "b", "a"], ["a"], ["b"], ["a", "b"], ["c", "b"]],
  137. )
  138. @pytest.mark.parametrize("n", range(1, 6))
  139. def test_nlargest_n_duplicate_index(self, df_duplicates, n, order, request):
  140. # GH#13412
  141. df = df_duplicates
  142. result = df.nsmallest(n, order)
  143. expected = df.sort_values(order).head(n)
  144. tm.assert_frame_equal(result, expected)
  145. result = df.nlargest(n, order)
  146. expected = df.sort_values(order, ascending=False).head(n)
  147. if Version(np.__version__) >= Version("1.25") and (
  148. (order == ["a"] and n in (1, 2, 3, 4)) or (order == ["a", "b"]) and n == 5
  149. ):
  150. request.node.add_marker(
  151. pytest.mark.xfail(
  152. reason=(
  153. "pandas default unstable sorting of duplicates"
  154. "issue with numpy>=1.25 with AVX instructions"
  155. ),
  156. strict=False,
  157. )
  158. )
  159. tm.assert_frame_equal(result, expected)
  160. def test_nlargest_duplicate_keep_all_ties(self):
  161. # GH#16818
  162. df = pd.DataFrame(
  163. {"a": [5, 4, 4, 2, 3, 3, 3, 3], "b": [10, 9, 8, 7, 5, 50, 10, 20]}
  164. )
  165. result = df.nlargest(4, "a", keep="all")
  166. expected = pd.DataFrame(
  167. {
  168. "a": {0: 5, 1: 4, 2: 4, 4: 3, 5: 3, 6: 3, 7: 3},
  169. "b": {0: 10, 1: 9, 2: 8, 4: 5, 5: 50, 6: 10, 7: 20},
  170. }
  171. )
  172. tm.assert_frame_equal(result, expected)
  173. result = df.nsmallest(2, "a", keep="all")
  174. expected = pd.DataFrame(
  175. {
  176. "a": {3: 2, 4: 3, 5: 3, 6: 3, 7: 3},
  177. "b": {3: 7, 4: 5, 5: 50, 6: 10, 7: 20},
  178. }
  179. )
  180. tm.assert_frame_equal(result, expected)
  181. def test_nlargest_multiindex_column_lookup(self):
  182. # Check whether tuples are correctly treated as multi-level lookups.
  183. # GH#23033
  184. df = pd.DataFrame(
  185. columns=pd.MultiIndex.from_product([["x"], ["a", "b"]]),
  186. data=[[0.33, 0.13], [0.86, 0.25], [0.25, 0.70], [0.85, 0.91]],
  187. )
  188. # nsmallest
  189. result = df.nsmallest(3, ("x", "a"))
  190. expected = df.iloc[[2, 0, 3]]
  191. tm.assert_frame_equal(result, expected)
  192. # nlargest
  193. result = df.nlargest(3, ("x", "b"))
  194. expected = df.iloc[[3, 2, 1]]
  195. tm.assert_frame_equal(result, expected)
  196. def test_nlargest_nan(self):
  197. # GH#43060
  198. df = pd.DataFrame([np.nan, np.nan, 0, 1, 2, 3])
  199. result = df.nlargest(5, 0)
  200. expected = df.sort_values(0, ascending=False).head(5)
  201. tm.assert_frame_equal(result, expected)
  202. def test_nsmallest_nan_after_n_element(self):
  203. # GH#46589
  204. df = pd.DataFrame(
  205. {
  206. "a": [1, 2, 3, 4, 5, None, 7],
  207. "b": [7, 6, 5, 4, 3, 2, 1],
  208. "c": [1, 1, 2, 2, 3, 3, 3],
  209. },
  210. index=range(7),
  211. )
  212. result = df.nsmallest(5, columns=["a", "b"])
  213. expected = pd.DataFrame(
  214. {
  215. "a": [1, 2, 3, 4, 5],
  216. "b": [7, 6, 5, 4, 3],
  217. "c": [1, 1, 2, 2, 3],
  218. },
  219. index=range(5),
  220. ).astype({"a": "float"})
  221. tm.assert_frame_equal(result, expected)