test_nlargest.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. """
  2. Note: for naming purposes, most tests are titled e.g. "test_nlargest_foo"
  3. but implicitly also test nsmallest_foo.
  4. """
  5. from itertools import product
  6. import numpy as np
  7. import pytest
  8. import pandas as pd
  9. from pandas import Series
  10. import pandas._testing as tm
  11. main_dtypes = [
  12. "datetime",
  13. "datetimetz",
  14. "timedelta",
  15. "int8",
  16. "int16",
  17. "int32",
  18. "int64",
  19. "float32",
  20. "float64",
  21. "uint8",
  22. "uint16",
  23. "uint32",
  24. "uint64",
  25. ]
  26. @pytest.fixture
  27. def s_main_dtypes():
  28. """
  29. A DataFrame with many dtypes
  30. * datetime
  31. * datetimetz
  32. * timedelta
  33. * [u]int{8,16,32,64}
  34. * float{32,64}
  35. The columns are the name of the dtype.
  36. """
  37. df = pd.DataFrame(
  38. {
  39. "datetime": pd.to_datetime(["2003", "2002", "2001", "2002", "2005"]),
  40. "datetimetz": pd.to_datetime(
  41. ["2003", "2002", "2001", "2002", "2005"]
  42. ).tz_localize("US/Eastern"),
  43. "timedelta": pd.to_timedelta(["3d", "2d", "1d", "2d", "5d"]),
  44. }
  45. )
  46. for dtype in [
  47. "int8",
  48. "int16",
  49. "int32",
  50. "int64",
  51. "float32",
  52. "float64",
  53. "uint8",
  54. "uint16",
  55. "uint32",
  56. "uint64",
  57. ]:
  58. df[dtype] = Series([3, 2, 1, 2, 5], dtype=dtype)
  59. return df
  60. @pytest.fixture(params=main_dtypes)
  61. def s_main_dtypes_split(request, s_main_dtypes):
  62. """Each series in s_main_dtypes."""
  63. return s_main_dtypes[request.param]
  64. def assert_check_nselect_boundary(vals, dtype, method):
  65. # helper function for 'test_boundary_{dtype}' tests
  66. ser = Series(vals, dtype=dtype)
  67. result = getattr(ser, method)(3)
  68. expected_idxr = [0, 1, 2] if method == "nsmallest" else [3, 2, 1]
  69. expected = ser.loc[expected_idxr]
  70. tm.assert_series_equal(result, expected)
  71. class TestSeriesNLargestNSmallest:
  72. @pytest.mark.parametrize(
  73. "r",
  74. [
  75. Series([3.0, 2, 1, 2, "5"], dtype="object"),
  76. Series([3.0, 2, 1, 2, 5], dtype="object"),
  77. # not supported on some archs
  78. # Series([3., 2, 1, 2, 5], dtype='complex256'),
  79. Series([3.0, 2, 1, 2, 5], dtype="complex128"),
  80. Series(list("abcde")),
  81. Series(list("abcde"), dtype="category"),
  82. ],
  83. )
  84. def test_nlargest_error(self, r):
  85. dt = r.dtype
  86. msg = f"Cannot use method 'n(largest|smallest)' with dtype {dt}"
  87. args = 2, len(r), 0, -1
  88. methods = r.nlargest, r.nsmallest
  89. for method, arg in product(methods, args):
  90. with pytest.raises(TypeError, match=msg):
  91. method(arg)
  92. def test_nsmallest_nlargest(self, s_main_dtypes_split):
  93. # float, int, datetime64 (use i8), timedelts64 (same),
  94. # object that are numbers, object that are strings
  95. ser = s_main_dtypes_split
  96. tm.assert_series_equal(ser.nsmallest(2), ser.iloc[[2, 1]])
  97. tm.assert_series_equal(ser.nsmallest(2, keep="last"), ser.iloc[[2, 3]])
  98. empty = ser.iloc[0:0]
  99. tm.assert_series_equal(ser.nsmallest(0), empty)
  100. tm.assert_series_equal(ser.nsmallest(-1), empty)
  101. tm.assert_series_equal(ser.nlargest(0), empty)
  102. tm.assert_series_equal(ser.nlargest(-1), empty)
  103. tm.assert_series_equal(ser.nsmallest(len(ser)), ser.sort_values())
  104. tm.assert_series_equal(ser.nsmallest(len(ser) + 1), ser.sort_values())
  105. tm.assert_series_equal(ser.nlargest(len(ser)), ser.iloc[[4, 0, 1, 3, 2]])
  106. tm.assert_series_equal(ser.nlargest(len(ser) + 1), ser.iloc[[4, 0, 1, 3, 2]])
  107. def test_nlargest_misc(self):
  108. ser = Series([3.0, np.nan, 1, 2, 5])
  109. result = ser.nlargest()
  110. expected = ser.iloc[[4, 0, 3, 2, 1]]
  111. tm.assert_series_equal(result, expected)
  112. result = ser.nsmallest()
  113. expected = ser.iloc[[2, 3, 0, 4, 1]]
  114. tm.assert_series_equal(result, expected)
  115. msg = 'keep must be either "first", "last"'
  116. with pytest.raises(ValueError, match=msg):
  117. ser.nsmallest(keep="invalid")
  118. with pytest.raises(ValueError, match=msg):
  119. ser.nlargest(keep="invalid")
  120. # GH#15297
  121. ser = Series([1] * 5, index=[1, 2, 3, 4, 5])
  122. expected_first = Series([1] * 3, index=[1, 2, 3])
  123. expected_last = Series([1] * 3, index=[5, 4, 3])
  124. result = ser.nsmallest(3)
  125. tm.assert_series_equal(result, expected_first)
  126. result = ser.nsmallest(3, keep="last")
  127. tm.assert_series_equal(result, expected_last)
  128. result = ser.nlargest(3)
  129. tm.assert_series_equal(result, expected_first)
  130. result = ser.nlargest(3, keep="last")
  131. tm.assert_series_equal(result, expected_last)
  132. @pytest.mark.parametrize("n", range(1, 5))
  133. def test_nlargest_n(self, n):
  134. # GH 13412
  135. ser = Series([1, 4, 3, 2], index=[0, 0, 1, 1])
  136. result = ser.nlargest(n)
  137. expected = ser.sort_values(ascending=False).head(n)
  138. tm.assert_series_equal(result, expected)
  139. result = ser.nsmallest(n)
  140. expected = ser.sort_values().head(n)
  141. tm.assert_series_equal(result, expected)
  142. def test_nlargest_boundary_integer(self, nselect_method, any_int_numpy_dtype):
  143. # GH#21426
  144. dtype_info = np.iinfo(any_int_numpy_dtype)
  145. min_val, max_val = dtype_info.min, dtype_info.max
  146. vals = [min_val, min_val + 1, max_val - 1, max_val]
  147. assert_check_nselect_boundary(vals, any_int_numpy_dtype, nselect_method)
  148. def test_nlargest_boundary_float(self, nselect_method, float_numpy_dtype):
  149. # GH#21426
  150. dtype_info = np.finfo(float_numpy_dtype)
  151. min_val, max_val = dtype_info.min, dtype_info.max
  152. min_2nd, max_2nd = np.nextafter([min_val, max_val], 0, dtype=float_numpy_dtype)
  153. vals = [min_val, min_2nd, max_2nd, max_val]
  154. assert_check_nselect_boundary(vals, float_numpy_dtype, nselect_method)
  155. @pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"])
  156. def test_nlargest_boundary_datetimelike(self, nselect_method, dtype):
  157. # GH#21426
  158. # use int64 bounds and +1 to min_val since true minimum is NaT
  159. # (include min_val/NaT at end to maintain same expected_idxr)
  160. dtype_info = np.iinfo("int64")
  161. min_val, max_val = dtype_info.min, dtype_info.max
  162. vals = [min_val + 1, min_val + 2, max_val - 1, max_val, min_val]
  163. assert_check_nselect_boundary(vals, dtype, nselect_method)
  164. def test_nlargest_duplicate_keep_all_ties(self):
  165. # see GH#16818
  166. ser = Series([10, 9, 8, 7, 7, 7, 7, 6])
  167. result = ser.nlargest(4, keep="all")
  168. expected = Series([10, 9, 8, 7, 7, 7, 7])
  169. tm.assert_series_equal(result, expected)
  170. result = ser.nsmallest(2, keep="all")
  171. expected = Series([6, 7, 7, 7, 7], index=[7, 3, 4, 5, 6])
  172. tm.assert_series_equal(result, expected)
  173. @pytest.mark.parametrize(
  174. "data,expected", [([True, False], [True]), ([True, False, True, True], [True])]
  175. )
  176. def test_nlargest_boolean(self, data, expected):
  177. # GH#26154 : ensure True > False
  178. ser = Series(data)
  179. result = ser.nlargest(1)
  180. expected = Series(expected)
  181. tm.assert_series_equal(result, expected)
  182. def test_nlargest_nullable(self, any_numeric_ea_dtype):
  183. # GH#42816
  184. dtype = any_numeric_ea_dtype
  185. if dtype.startswith("UInt"):
  186. # Can't cast from negative float to uint on some platforms
  187. arr = np.random.randint(1, 10, 10)
  188. else:
  189. arr = np.random.randn(10)
  190. arr = arr.astype(dtype.lower(), copy=False)
  191. ser = Series(arr.copy(), dtype=dtype)
  192. ser[1] = pd.NA
  193. result = ser.nlargest(5)
  194. expected = (
  195. Series(np.delete(arr, 1), index=ser.index.delete(1))
  196. .nlargest(5)
  197. .astype(dtype)
  198. )
  199. tm.assert_series_equal(result, expected)
  200. def test_nsmallest_nan_when_keep_is_all(self):
  201. # GH#46589
  202. s = Series([1, 2, 3, 3, 3, None])
  203. result = s.nsmallest(3, keep="all")
  204. expected = Series([1.0, 2.0, 3.0, 3.0, 3.0])
  205. tm.assert_series_equal(result, expected)
  206. s = Series([1, 2, None, None, None])
  207. result = s.nsmallest(3, keep="all")
  208. expected = Series([1, 2, None, None, None])
  209. tm.assert_series_equal(result, expected)