123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262 |
- """
- Implementation of nlargest and nsmallest.
- """
- from __future__ import annotations
- from typing import (
- TYPE_CHECKING,
- Hashable,
- Sequence,
- cast,
- final,
- )
- import numpy as np
- from pandas._libs import algos as libalgos
- from pandas._typing import (
- DtypeObj,
- IndexLabel,
- )
- from pandas.core.dtypes.common import (
- is_bool_dtype,
- is_complex_dtype,
- is_integer_dtype,
- is_list_like,
- is_numeric_dtype,
- needs_i8_conversion,
- )
- from pandas.core.dtypes.dtypes import BaseMaskedDtype
- if TYPE_CHECKING:
- from pandas import (
- DataFrame,
- Series,
- )
class SelectN:
    """
    Shared base for the n-largest / n-smallest selectors.

    Stores the target object together with ``n`` and the tie-handling
    policy ``keep``; subclasses supply the actual selection in ``compute``.

    Parameters
    ----------
    obj : object
        The object to select from (stored unchanged on ``self.obj``).
    n : int
        Number of entries requested.
    keep : str
        One of ``"first"``, ``"last"`` or ``"all"``.

    Raises
    ------
    ValueError
        If ``keep`` is not one of the three accepted values.
    """

    def __init__(self, obj, n: int, keep: str) -> None:
        self.obj, self.n, self.keep = obj, n, keep

        # Tuple membership (not a set) so an unhashable ``keep`` still ends
        # up in the ValueError branch.
        if keep not in ("first", "last", "all"):
            raise ValueError('keep must be either "first", "last" or "all"')

    def compute(self, method: str) -> DataFrame | Series:
        """Perform the selection named by ``method``; subclasses override."""
        raise NotImplementedError

    @final
    def nlargest(self):
        """Return ``self.compute("nlargest")``."""
        return self.compute("nlargest")

    @final
    def nsmallest(self):
        """Return ``self.compute("nsmallest")``."""
        return self.compute("nsmallest")

    @final
    @staticmethod
    def is_valid_dtype_n_method(dtype: DtypeObj) -> bool:
        """
        Report whether ``dtype`` can be used with nsmallest/nlargest.

        Numeric dtypes are accepted unless they are complex; non-numeric
        dtypes are accepted only when ``needs_i8_conversion`` says so.
        """
        if not is_numeric_dtype(dtype):
            return needs_i8_conversion(dtype)
        return not is_complex_dtype(dtype)
class SelectNSeries(SelectN):
    """
    Implement n largest/smallest for Series

    Parameters
    ----------
    obj : Series
    n : int
    keep : {'first', 'last', 'all'}

    Returns
    -------
    nordered : Series
    """

    def compute(self, method: str) -> Series:
        """
        Select the ``n`` largest or smallest entries of ``self.obj``.

        Parameters
        ----------
        method : str
            ``"nlargest"`` or ``"nsmallest"``.

        Returns
        -------
        Series
            The selected non-NA entries in order, followed by NA entries
            when fewer than ``n`` non-NA values are available.

        Raises
        ------
        TypeError
            If the Series dtype is rejected by ``is_valid_dtype_n_method``.
        """
        from pandas.core.reshape.concat import concat

        n = self.n
        dtype = self.obj.dtype
        if not self.is_valid_dtype_n_method(dtype):
            raise TypeError(f"Cannot use method '{method}' with dtype {dtype}")

        if n <= 0:
            return self.obj[[]]

        # slow method: everything is requested, so this is just a sort
        if n >= len(self.obj):
            ascending = method == "nsmallest"
            return self.obj.sort_values(ascending=ascending).head(n)

        # fast method
        dropped = self.obj.dropna()
        # Select the NA rows with a boolean mask rather than dropping the
        # non-NA labels: label-based dropping also discards NA rows whose
        # label is shared with a non-NA row when the index has duplicates.
        nan_index = self.obj[self.obj.isna()]

        new_dtype = dropped.dtype

        # Similar to algorithms._ensure_data
        arr = dropped._values
        if needs_i8_conversion(arr.dtype):
            arr = arr.view("i8")
        elif isinstance(arr.dtype, BaseMaskedDtype):
            arr = arr._data
        else:
            arr = np.asarray(arr)
        if arr.dtype.kind == "b":
            arr = arr.view(np.uint8)

        if method == "nlargest":
            # Negate so that the "smallest" machinery below finds the largest.
            arr = -arr
            if is_integer_dtype(new_dtype):
                # GH 21426: ensure reverse ordering at boundaries
                arr -= 1

            elif is_bool_dtype(new_dtype):
                # GH 26154: ensure False is smaller than True
                arr = 1 - (-arr)

        if self.keep == "last":
            # Scan from the end so that ties resolve to the last occurrence.
            arr = arr[::-1]

        nbase = n
        narr = len(arr)
        n = min(n, narr)

        if narr:
            # arr passed into kth_smallest must be contiguous. We copy
            # here because kth_smallest will modify its input
            kth_val = libalgos.kth_smallest(arr.copy(order="C"), n - 1)
            (ns,) = np.nonzero(arr <= kth_val)
            # Stable sort keeps tied values in encounter order.
            inds = ns[arr[ns].argsort(kind="mergesort")]
        else:
            # Every value is NA: there is nothing to rank, and calling
            # kth_smallest on an empty buffer with k == -1 would read out
            # of bounds.
            inds = np.array([], dtype=np.intp)

        if self.keep != "all":
            inds = inds[:n]
            findex = nbase
        else:
            if len(inds) < nbase <= len(nan_index) + len(inds):
                findex = len(nan_index) + len(inds)
            else:
                findex = len(inds)

        if self.keep == "last":
            # reverse indices
            inds = narr - 1 - inds

        return concat([dropped.iloc[inds], nan_index]).iloc[:findex]
class SelectNFrame(SelectN):
    """
    Implement n largest/smallest for DataFrame

    Parameters
    ----------
    obj : DataFrame
    n : int
    keep : {'first', 'last'}, default 'first'
    columns : list or str

    Returns
    -------
    nordered : DataFrame
    """

    def __init__(self, obj: DataFrame, n: int, keep: str, columns: IndexLabel) -> None:
        super().__init__(obj, n, keep)

        # A scalar label, or a tuple (wrapped whole rather than iterated),
        # becomes a one-element list; any other list-like is copied.
        if isinstance(columns, tuple) or not is_list_like(columns):
            columns = [columns]
        self.columns = list(cast(Sequence[Hashable], columns))

    def compute(self, method: str) -> DataFrame:
        """
        Select the rows holding the ``n`` largest/smallest values.

        Columns are consulted in order: each later column is used only to
        break ties left at the boundary by the earlier ones.

        Parameters
        ----------
        method : str
            Name of the Series method to apply per column
            (``"nlargest"`` or ``"nsmallest"``).

        Returns
        -------
        DataFrame
            Selected rows carrying their original index labels.

        Raises
        ------
        TypeError
            If any requested column has a dtype rejected by
            ``is_valid_dtype_n_method``.
        """
        from pandas.core.api import Index

        target = self.n
        source = self.obj
        by = self.columns

        # Validate every column up front, before doing any work.
        for label in by:
            col_dtype = source[label].dtype
            if not self.is_valid_dtype_n_method(col_dtype):
                raise TypeError(
                    f"Column {label!r} has dtype {col_dtype}, "
                    f"cannot use method {method!r} with this dtype"
                )

        smallest_first = method == "nsmallest"

        def merge(accumulated, incoming):
            """
            Join two indexers; the side on which ``incoming`` lands
            depends on ``method``.
            """
            if smallest_first:
                return accumulated.append(incoming)
            return incoming.append(accumulated)

        # Work on a positional index so duplicate labels cannot interfere;
        # the original labels are restored at the end.
        original_index = source.index
        source = source.reset_index(drop=True)
        remaining = source
        still_needed = target
        picked = Index([], dtype=np.int64)
        last_pos = len(by) - 1

        for pos, label in enumerate(by):
            # Apply the method to the current column of the rows still in
            # contention. Non-final columns keep all ties so that later
            # columns can decide between them.
            on_last = pos == last_pos
            keep_mode = self.keep if on_last else "all"
            candidates = getattr(remaining[label], method)(
                still_needed, keep=keep_mode
            )

            # Done when this is the final column or there is no surplus.
            if on_last or len(candidates) <= still_needed:
                picked = merge(picked, candidates.index)
                break

            # Rows equal to the last returned value sit on the boundary:
            # only some of them can make the cut.
            tied = candidates == candidates[candidates.index[-1]]
            boundary_rows = candidates[tied]
            certain_rows = candidates[~tied]

            # Everything strictly inside the boundary is definitely selected.
            picked = merge(picked, certain_rows.index)

            # Let the remaining columns separate the boundary rows.
            remaining = remaining.loc[boundary_rows.index]
            still_needed = target - len(picked)

        result = source.take(picked)
        # Restore the original labels on the selected rows.
        result.index = original_index.take(picked)

        # A single column is already in order.
        if len(by) == 1:
            return result

        return result.sort_values(by, ascending=smallest_first, kind="mergesort")
|