selectn.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. """
  2. Implementation of nlargest and nsmallest.
  3. """
  4. from __future__ import annotations
  5. from typing import (
  6. TYPE_CHECKING,
  7. Hashable,
  8. Sequence,
  9. cast,
  10. final,
  11. )
  12. import numpy as np
  13. from pandas._libs import algos as libalgos
  14. from pandas._typing import (
  15. DtypeObj,
  16. IndexLabel,
  17. )
  18. from pandas.core.dtypes.common import (
  19. is_bool_dtype,
  20. is_complex_dtype,
  21. is_integer_dtype,
  22. is_list_like,
  23. is_numeric_dtype,
  24. needs_i8_conversion,
  25. )
  26. from pandas.core.dtypes.dtypes import BaseMaskedDtype
  27. if TYPE_CHECKING:
  28. from pandas import (
  29. DataFrame,
  30. Series,
  31. )
  32. class SelectN:
  33. def __init__(self, obj, n: int, keep: str) -> None:
  34. self.obj = obj
  35. self.n = n
  36. self.keep = keep
  37. if self.keep not in ("first", "last", "all"):
  38. raise ValueError('keep must be either "first", "last" or "all"')
  39. def compute(self, method: str) -> DataFrame | Series:
  40. raise NotImplementedError
  41. @final
  42. def nlargest(self):
  43. return self.compute("nlargest")
  44. @final
  45. def nsmallest(self):
  46. return self.compute("nsmallest")
  47. @final
  48. @staticmethod
  49. def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
  50. """
  51. Helper function to determine if dtype is valid for
  52. nsmallest/nlargest methods
  53. """
  54. if is_numeric_dtype(dtype):
  55. return not is_complex_dtype(dtype)
  56. return needs_i8_conversion(dtype)
  57. class SelectNSeries(SelectN):
  58. """
  59. Implement n largest/smallest for Series
  60. Parameters
  61. ----------
  62. obj : Series
  63. n : int
  64. keep : {'first', 'last'}, default 'first'
  65. Returns
  66. -------
  67. nordered : Series
  68. """
  69. def compute(self, method: str) -> Series:
  70. from pandas.core.reshape.concat import concat
  71. n = self.n
  72. dtype = self.obj.dtype
  73. if not self.is_valid_dtype_n_method(dtype):
  74. raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")
  75. if n <= 0:
  76. return self.obj[[]]
  77. dropped = self.obj.dropna()
  78. nan_index = self.obj.drop(dropped.index)
  79. # slow method
  80. if n >= len(self.obj):
  81. ascending = method == "nsmallest"
  82. return self.obj.sort_values(ascending=ascending).head(n)
  83. # fast method
  84. new_dtype = dropped.dtype
  85. # Similar to algorithms._ensure_data
  86. arr = dropped._values
  87. if needs_i8_conversion(arr.dtype):
  88. arr = arr.view("i8")
  89. elif isinstance(arr.dtype, BaseMaskedDtype):
  90. arr = arr._data
  91. else:
  92. arr = np.asarray(arr)
  93. if arr.dtype.kind == "b":
  94. arr = arr.view(np.uint8)
  95. if method == "nlargest":
  96. arr = -arr
  97. if is_integer_dtype(new_dtype):
  98. # GH 21426: ensure reverse ordering at boundaries
  99. arr -= 1
  100. elif is_bool_dtype(new_dtype):
  101. # GH 26154: ensure False is smaller than True
  102. arr = 1 - (-arr)
  103. if self.keep == "last":
  104. arr = arr[::-1]
  105. nbase = n
  106. narr = len(arr)
  107. n = min(n, narr)
  108. # arr passed into kth_smallest must be contiguous. We copy
  109. # here because kth_smallest will modify its input
  110. kth_val = libalgos.kth_smallest(arr.copy(order="C"), n - 1)
  111. (ns,) = np.nonzero(arr <= kth_val)
  112. inds = ns[arr[ns].argsort(kind="mergesort")]
  113. if self.keep != "all":
  114. inds = inds[:n]
  115. findex = nbase
  116. else:
  117. if len(inds) < nbase <= len(nan_index) + len(inds):
  118. findex = len(nan_index) + len(inds)
  119. else:
  120. findex = len(inds)
  121. if self.keep == "last":
  122. # reverse indices
  123. inds = narr - 1 - inds
  124. return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
  125. class SelectNFrame(SelectN):
  126. """
  127. Implement n largest/smallest for DataFrame
  128. Parameters
  129. ----------
  130. obj : DataFrame
  131. n : int
  132. keep : {'first', 'last'}, default 'first'
  133. columns : list or str
  134. Returns
  135. -------
  136. nordered : DataFrame
  137. """
  138. def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
  139. super().__init__(obj, n, keep)
  140. if not is_list_like(columns) or isinstance(columns, tuple):
  141. columns = [columns]
  142. columns = cast(Sequence[Hashable], columns)
  143. columns = list(columns)
  144. self.columns = columns
  145. def compute(self, method: str) -> DataFrame:
  146. from pandas.core.api import Index
  147. n = self.n
  148. frame = self.obj
  149. columns = self.columns
  150. for column in columns:
  151. dtype = frame[column].dtype
  152. if not self.is_valid_dtype_n_method(dtype):
  153. raise TypeError(
  154. f"Column {repr(column)} has dtype {dtype}, "
  155. f"cannot use method {repr(method)} with this dtype"
  156. )
  157. def get_indexer(current_indexer, other_indexer):
  158. """
  159. Helper function to concat `current_indexer` and `other_indexer`
  160. depending on `method`
  161. """
  162. if method == "nsmallest":
  163. return current_indexer.append(other_indexer)
  164. else:
  165. return other_indexer.append(current_indexer)
  166. # Below we save and reset the index in case index contains duplicates
  167. original_index = frame.index
  168. cur_frame = frame = frame.reset_index(drop=True)
  169. cur_n = n
  170. indexer = Index([], dtype=np.int64)
  171. for i, column in enumerate(columns):
  172. # For each column we apply method to cur_frame[column].
  173. # If it's the last column or if we have the number of
  174. # results desired we are done.
  175. # Otherwise there are duplicates of the largest/smallest
  176. # value and we need to look at the rest of the columns
  177. # to determine which of the rows with the largest/smallest
  178. # value in the column to keep.
  179. series = cur_frame[column]
  180. is_last_column = len(columns) - 1 == i
  181. values = getattr(series, method)(
  182. cur_n, keep=self.keep if is_last_column else "all"
  183. )
  184. if is_last_column or len(values) <= cur_n:
  185. indexer = get_indexer(indexer, values.index)
  186. break
  187. # Now find all values which are equal to
  188. # the (nsmallest: largest)/(nlargest: smallest)
  189. # from our series.
  190. border_value = values == values[values.index[-1]]
  191. # Some of these values are among the top-n
  192. # some aren't.
  193. unsafe_values = values[border_value]
  194. # These values are definitely among the top-n
  195. safe_values = values[~border_value]
  196. indexer = get_indexer(indexer, safe_values.index)
  197. # Go on and separate the unsafe_values on the remaining
  198. # columns.
  199. cur_frame = cur_frame.loc[unsafe_values.index]
  200. cur_n = n - len(indexer)
  201. frame = frame.take(indexer)
  202. # Restore the index on frame
  203. frame.index = original_index.take(indexer)
  204. # If there is only one column, the frame is already sorted.
  205. if len(columns) == 1:
  206. return frame
  207. ascending = method == "nsmallest"
  208. return frame.sort_values(columns, ascending=ascending, kind="mergesort")