_mixins.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. from __future__ import annotations
  2. from functools import wraps
  3. from typing import (
  4. TYPE_CHECKING,
  5. Any,
  6. Literal,
  7. Sequence,
  8. TypeVar,
  9. cast,
  10. overload,
  11. )
  12. import numpy as np
  13. from pandas._libs import lib
  14. from pandas._libs.arrays import NDArrayBacked
  15. from pandas._typing import (
  16. ArrayLike,
  17. AxisInt,
  18. Dtype,
  19. F,
  20. PositionalIndexer2D,
  21. PositionalIndexerTuple,
  22. ScalarIndexer,
  23. SequenceIndexer,
  24. Shape,
  25. TakeIndexer,
  26. npt,
  27. type_t,
  28. )
  29. from pandas.errors import AbstractMethodError
  30. from pandas.util._decorators import doc
  31. from pandas.util._validators import (
  32. validate_bool_kwarg,
  33. validate_fillna_kwargs,
  34. validate_insert_loc,
  35. )
  36. from pandas.core.dtypes.common import (
  37. is_dtype_equal,
  38. pandas_dtype,
  39. )
  40. from pandas.core.dtypes.dtypes import (
  41. DatetimeTZDtype,
  42. ExtensionDtype,
  43. PeriodDtype,
  44. )
  45. from pandas.core.dtypes.missing import array_equivalent
  46. from pandas.core import missing
  47. from pandas.core.algorithms import (
  48. take,
  49. unique,
  50. value_counts,
  51. )
  52. from pandas.core.array_algos.quantile import quantile_with_mask
  53. from pandas.core.array_algos.transforms import shift
  54. from pandas.core.arrays.base import ExtensionArray
  55. from pandas.core.construction import extract_array
  56. from pandas.core.indexers import check_array_indexer
  57. from pandas.core.sorting import nargminmax
  58. NDArrayBackedExtensionArrayT = TypeVar(
  59. "NDArrayBackedExtensionArrayT", bound="NDArrayBackedExtensionArray"
  60. )
  61. if TYPE_CHECKING:
  62. from pandas._typing import (
  63. NumpySorter,
  64. NumpyValueArrayLike,
  65. )
  66. from pandas import Series
  67. def ravel_compat(meth: F) -> F:
  68. """
  69. Decorator to ravel a 2D array before passing it to a cython operation,
  70. then reshape the result to our own shape.
  71. """
  72. @wraps(meth)
  73. def method(self, *args, **kwargs):
  74. if self.ndim == 1:
  75. return meth(self, *args, **kwargs)
  76. flags = self._ndarray.flags
  77. flat = self.ravel("K")
  78. result = meth(flat, *args, **kwargs)
  79. order = "F" if flags.f_contiguous else "C"
  80. return result.reshape(self.shape, order=order)
  81. return cast(F, method)
class NDArrayBackedExtensionArray(NDArrayBacked, ExtensionArray):
    """
    ExtensionArray that is backed by a single NumPy ndarray.

    Most methods operate directly on ``self._ndarray`` and re-wrap the
    resulting ndarray with ``_from_backing_data`` (provided by
    ``NDArrayBacked``), which per the ``NDArrayBackedExtensionArrayT``
    return annotations yields an array of the caller's own type.

    Subclasses must implement ``_validate_scalar`` (the base version raises
    AbstractMethodError). The hooks ``_box_func``,
    ``_validate_setitem_value`` and ``_cast_quantile_result`` return their
    argument unchanged here and may be overridden.
    """

    # The NumPy array holding this array's data.
    _ndarray: np.ndarray

    # scalar used to denote NA value inside our self._ndarray, e.g. -1
    # for Categorical, iNaT for Period. Outside of object dtype,
    # self.isna() should be exactly locations in self._ndarray with
    # _internal_fill_value.
    _internal_fill_value: Any

    def _box_func(self, x):
        """
        Wrap numpy type in our dtype.type if necessary.

        This base implementation returns ``x`` unchanged. It is applied to
        scalar results in ``__getitem__`` and ``_wrap_reduction_result``.
        """
        return x

    def _validate_scalar(self, value):
        """
        Convert a scalar into a value that can be stored in ``self._ndarray``.

        Used for ``fill_value`` in ``take`` and ``shift`` and for the new
        item in ``insert``; subclasses must implement it.

        Raises
        ------
        AbstractMethodError
            Always, in this base implementation.
        """
        # used by NDArrayBackedExtensionIndex.insert
        raise AbstractMethodError(self)

    # ------------------------------------------------------------------------

    def view(self, dtype: Dtype | None = None) -> ArrayLike:
        """
        Reinterpret the backing data, optionally as another dtype.

        Parameters
        ----------
        dtype : dtype, type or None, default None
            * ``None`` or this array's own dtype object (checked with ``is``):
              a new array of the same type wrapping ``self._ndarray``.
            * A Python type (e.g. ``np.ndarray``): passed straight to
              ``ndarray.view``.
            * ``PeriodDtype`` / ``DatetimeTZDtype``: that dtype's array type,
              constructed from an ``"i8"`` view of the data.
            * ``"M8[ns]"`` / ``"m8[ns]"``: ``DatetimeArray`` /
              ``TimedeltaArray`` constructed from an ``"i8"`` view.
            * Anything else: ``ndarray.view(dtype=dtype)``.

        Returns
        -------
        ExtensionArray or np.ndarray
        """
        # We handle datetime64, datetime64tz, timedelta64, and period
        # dtypes here. Everything else we pass through to the underlying
        # ndarray.
        if dtype is None or dtype is self.dtype:
            return self._from_backing_data(self._ndarray)

        if isinstance(dtype, type):
            # we sometimes pass non-dtype objects, e.g np.ndarray;
            # pass those through to the underlying ndarray
            return self._ndarray.view(dtype)

        dtype = pandas_dtype(dtype)
        arr = self._ndarray

        if isinstance(dtype, (PeriodDtype, DatetimeTZDtype)):
            cls = dtype.construct_array_type()
            return cls(arr.view("i8"), dtype=dtype)
        elif dtype == "M8[ns]":
            # Local import, presumably to avoid an import cycle with
            # pandas.core.arrays.
            from pandas.core.arrays import DatetimeArray

            return DatetimeArray(arr.view("i8"), dtype=dtype)
        elif dtype == "m8[ns]":
            from pandas.core.arrays import TimedeltaArray

            return TimedeltaArray(arr.view("i8"), dtype=dtype)

        # error: Argument "dtype" to "view" of "_ArrayOrScalarCommon" has incompatible
        # type "Union[ExtensionDtype, dtype[Any]]"; expected "Union[dtype[Any], None,
        # type, _SupportsDType, str, Union[Tuple[Any, int], Tuple[Any, Union[int,
        # Sequence[int]]], List[Any], _DTypeDict, Tuple[Any, Any]]]"
        return arr.view(dtype=dtype)  # type: ignore[arg-type]

    def take(
        self: NDArrayBackedExtensionArrayT,
        indices: TakeIndexer,
        *,
        allow_fill: bool = False,
        fill_value: Any = None,
        axis: AxisInt = 0,
    ) -> NDArrayBackedExtensionArrayT:
        """
        Take elements from the backing ndarray along ``axis``.

        When ``allow_fill`` is True, ``fill_value`` is first converted with
        ``_validate_scalar`` so it can be stored in the backing ndarray. The
        actual selection is done by ``pandas.core.algorithms.take``, and the
        result is wrapped with ``_from_backing_data``.
        """
        if allow_fill:
            fill_value = self._validate_scalar(fill_value)

        # ``take`` here is the module-level pandas.core.algorithms.take,
        # not this method.
        new_data = take(
            self._ndarray,
            indices,
            allow_fill=allow_fill,
            fill_value=fill_value,
            axis=axis,
        )
        return self._from_backing_data(new_data)

    # ------------------------------------------------------------------------

    def equals(self, other) -> bool:
        """
        Return whether ``other`` holds the same data as ``self``.

        ``other`` must be of exactly the same type (a subclass does not
        match) and have an equal dtype. The backing ndarrays are then
        compared with ``array_equivalent``.
        """
        if type(self) is not type(other):
            return False
        if not is_dtype_equal(self.dtype, other.dtype):
            return False
        return bool(array_equivalent(self._ndarray, other._ndarray))

    @classmethod
    def _from_factorized(cls, values, original):
        """
        Rebuild an array from factorized ``values``.

        The result is constructed through ``original._from_backing_data``.
        ``values`` must share the numpy dtype of ``original._ndarray``, which
        is checked with an ``assert``.
        """
        assert values.dtype == original._ndarray.dtype
        return original._from_backing_data(values)

    def _values_for_argsort(self) -> np.ndarray:
        """Return the backing ndarray as the values to argsort by."""
        return self._ndarray

    def _values_for_factorize(self):
        """
        Return the backing ndarray together with the NA sentinel stored in
        it (``_internal_fill_value``).
        """
        return self._ndarray, self._internal_fill_value

    # Signature of "argmin" incompatible with supertype "ExtensionArray"
    def argmin(self, axis: AxisInt = 0, skipna: bool = True):  # type: ignore[override]
        """
        Return the position(s) of the minimum along ``axis``.

        Delegates to ``nargminmax``.

        Raises
        ------
        NotImplementedError
            If ``skipna`` is False and the array contains NA values.
        """
        # override base class by adding axis keyword
        validate_bool_kwarg(skipna, "skipna")
        if not skipna and self._hasna:
            raise NotImplementedError
        return nargminmax(self, "argmin", axis=axis)

    # Signature of "argmax" incompatible with supertype "ExtensionArray"
    def argmax(self, axis: AxisInt = 0, skipna: bool = True):  # type: ignore[override]
        """
        Return the position(s) of the maximum along ``axis``.

        Delegates to ``nargminmax``.

        Raises
        ------
        NotImplementedError
            If ``skipna`` is False and the array contains NA values.
        """
        # override base class by adding axis keyword
        validate_bool_kwarg(skipna, "skipna")
        if not skipna and self._hasna:
            raise NotImplementedError
        return nargminmax(self, "argmax", axis=axis)

    def unique(self: NDArrayBackedExtensionArrayT) -> NDArrayBackedExtensionArrayT:
        """
        Return the unique values.

        Uniqueness is computed on the backing ndarray, and the result is
        re-wrapped with ``_from_backing_data``.
        """
        # ``unique`` here is the module-level pandas.core.algorithms.unique,
        # not this method.
        new_data = unique(self._ndarray)
        return self._from_backing_data(new_data)

    @classmethod
    @doc(ExtensionArray._concat_same_type)
    def _concat_same_type(
        cls: type[NDArrayBackedExtensionArrayT],
        to_concat: Sequence[NDArrayBackedExtensionArrayT],
        axis: AxisInt = 0,
    ) -> NDArrayBackedExtensionArrayT:
        # Documentation comes from ``@doc`` (given the base-class method), so
        # notes are kept as comments here.
        # Dtypes are compared by their string form. Exactly one distinct
        # string is required, so an empty ``to_concat`` also raises here.
        dtypes = {str(x.dtype) for x in to_concat}
        if len(dtypes) != 1:
            raise ValueError("to_concat must have the same dtype (tz)", dtypes)

        new_values = [x._ndarray for x in to_concat]
        new_arr = np.concatenate(new_values, axis=axis)
        # The result is built through the first element's _from_backing_data,
        # not through ``cls``.
        return to_concat[0]._from_backing_data(new_arr)

    @doc(ExtensionArray.searchsorted)
    def searchsorted(
        self,
        value: NumpyValueArrayLike | ExtensionArray,
        side: Literal["left", "right"] = "left",
        sorter: NumpySorter = None,
    ) -> npt.NDArray[np.intp] | np.intp:
        # Convert ``value`` with the same hook used for assignment, then defer
        # to numpy's searchsorted on the backing ndarray.
        npvalue = self._validate_setitem_value(value)
        return self._ndarray.searchsorted(npvalue, side=side, sorter=sorter)

    @doc(ExtensionArray.shift)
    def shift(self, periods: int = 1, fill_value=None, axis: AxisInt = 0):
        # fill_value goes through _validate_scalar even when it is None
        # (presumably so subclasses can map None to their own NA value).
        fill_value = self._validate_scalar(fill_value)
        # ``shift`` here is the module-level
        # pandas.core.array_algos.transforms.shift, not this method.
        new_values = shift(self._ndarray, periods, axis, fill_value)
        return self._from_backing_data(new_values)

    def __setitem__(self, key, value) -> None:
        """
        Set ``self[key] = value`` in place on the backing ndarray.

        ``key`` is checked with ``check_array_indexer``, and ``value`` is
        converted with ``_validate_setitem_value`` before assignment.
        """
        key = check_array_indexer(self, key)
        value = self._validate_setitem_value(value)
        self._ndarray[key] = value

    def _validate_setitem_value(self, value):
        """
        Convert ``value`` into something that can be written into
        ``self._ndarray``.

        Used by ``__setitem__``, ``searchsorted``, ``fillna``, ``_putmask``
        and ``_where``. This base implementation returns ``value`` unchanged.
        """
        return value

    # Typing-only overloads: a scalar indexer returns a scalar; sequence or
    # tuple indexers return an array of the same type.
    @overload
    def __getitem__(self, key: ScalarIndexer) -> Any:
        ...

    @overload
    def __getitem__(
        self: NDArrayBackedExtensionArrayT,
        key: SequenceIndexer | PositionalIndexerTuple,
    ) -> NDArrayBackedExtensionArrayT:
        ...

    def __getitem__(
        self: NDArrayBackedExtensionArrayT,
        key: PositionalIndexer2D,
    ) -> NDArrayBackedExtensionArrayT | Any:
        """
        Index into the array.

        Scalar results are boxed with ``_box_func``. Array results are
        wrapped with ``_from_backing_data``.
        """
        if lib.is_integer(key):
            # fast-path
            # An integer key on a 1D array yields a scalar to box. On a 2D
            # array it yields a lower-dimensional sub-array to re-wrap.
            result = self._ndarray[key]
            if self.ndim == 1:
                return self._box_func(result)
            return self._from_backing_data(result)

        # error: Incompatible types in assignment (expression has type "ExtensionArray",
        # variable has type "Union[int, slice, ndarray]")
        key = extract_array(key, extract_numpy=True)  # type: ignore[assignment]
        key = check_array_indexer(self, key)
        result = self._ndarray[key]
        if lib.is_scalar(result):
            return self._box_func(result)

        result = self._from_backing_data(result)
        return result

    def _fill_mask_inplace(
        self, method: str, limit, mask: npt.NDArray[np.bool_]
    ) -> None:
        """
        Fill the positions flagged in ``mask`` in place on the backing data.

        The fill function is the one that ``missing.get_fill_func`` returns
        for ``method``.
        """
        # (for now) when self.ndim == 2, we assume axis=0
        func = missing.get_fill_func(method, ndim=self.ndim)
        func(self._ndarray.T, limit=limit, mask=mask.T)

    @doc(ExtensionArray.fillna)
    def fillna(
        self: NDArrayBackedExtensionArrayT, value=None, method=None, limit=None
    ) -> NDArrayBackedExtensionArrayT:
        # Documentation comes from ``@doc``. A new array is always returned:
        # either a filled copy of the backing data or ``self.copy()``.
        value, method = validate_fillna_kwargs(
            value, method, validate_scalar_dict_value=False
        )

        mask = self.isna()
        # error: Argument 2 to "check_value_size" has incompatible type
        # "ExtensionArray"; expected "ndarray"
        value = missing.check_value_size(
            value, mask, len(self)  # type: ignore[arg-type]
        )

        if mask.any():
            if method is not None:
                # TODO: check value is None
                # (for now) when self.ndim == 2, we assume axis=0
                func = missing.get_fill_func(method, ndim=self.ndim)
                # The fill function mutates its input, so work on a copy to
                # leave self untouched.
                npvalues = self._ndarray.T.copy()
                func(npvalues, limit=limit, mask=mask.T)
                npvalues = npvalues.T

                # TODO: PandasArray didn't used to copy, need tests for this
                new_values = self._from_backing_data(npvalues)
            else:
                # fill with value
                new_values = self.copy()
                new_values[mask] = value
        else:
            # We validate the fill_value even if there is nothing to fill
            if value is not None:
                self._validate_setitem_value(value)

            new_values = self.copy()
        return new_values

    # ------------------------------------------------------------------------
    # Reductions

    def _wrap_reduction_result(self, axis: AxisInt | None, result):
        """
        Box the raw result of a reduction.

        When ``axis`` is None or the array is 1D, ``result`` is passed to
        ``_box_func``. Otherwise it is wrapped with ``_from_backing_data``.
        """
        if axis is None or self.ndim == 1:
            return self._box_func(result)
        return self._from_backing_data(result)

    # ------------------------------------------------------------------------
    # __array_function__ methods

    def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
        """
        Analogue to np.putmask(self, mask, value)

        Modifies ``self._ndarray`` in place after converting ``value`` with
        ``_validate_setitem_value``.

        Parameters
        ----------
        mask : np.ndarray[bool]
        value : scalar or listlike

        Raises
        ------
        TypeError
            If value cannot be cast to self.dtype.
        """
        value = self._validate_setitem_value(value)
        np.putmask(self._ndarray, mask, value)

    def _where(
        self: NDArrayBackedExtensionArrayT, mask: npt.NDArray[np.bool_], value
    ) -> NDArrayBackedExtensionArrayT:
        """
        Analogue to np.where(mask, self, value)

        Parameters
        ----------
        mask : np.ndarray[bool]
        value : scalar or listlike

        Returns
        -------
        Array of the same type as self. The original is not modified.

        Raises
        ------
        TypeError
            If value cannot be cast to self.dtype.
        """
        value = self._validate_setitem_value(value)
        res_values = np.where(mask, self._ndarray, value)
        return self._from_backing_data(res_values)

    # ------------------------------------------------------------------------
    # Index compat methods

    def insert(
        self: NDArrayBackedExtensionArrayT, loc: int, item
    ) -> NDArrayBackedExtensionArrayT:
        """
        Make new ExtensionArray inserting new item at location.

        ``loc`` is validated and normalized against ``len(self)`` by
        ``validate_insert_loc``. NOTE(review): an earlier version of this
        docstring said negative values follow "list.append semantics";
        presumably ``list.insert``-style counting from the end is meant.
        Confirm against validate_insert_loc.

        ``item`` is converted with ``_validate_scalar``, and the result is
        built with ``np.concatenate``, so ``self`` is not modified.

        Parameters
        ----------
        loc : int
        item : object

        Returns
        -------
        type(self)
        """
        loc = validate_insert_loc(loc, len(self))

        code = self._validate_scalar(item)

        new_vals = np.concatenate(
            (
                self._ndarray[:loc],
                np.asarray([code], dtype=self._ndarray.dtype),
                self._ndarray[loc:],
            )
        )
        return self._from_backing_data(new_vals)

    # ------------------------------------------------------------------------
    # Additional array methods
    #  These are not part of the EA API, but we implement them because
    #  pandas assumes they're there.

    def value_counts(self, dropna: bool = True) -> Series:
        """
        Return a Series containing counts of unique values.

        Counts are computed with ``sort=False``. The result's index is
        re-wrapped with ``_from_backing_data``, so it uses this array's type
        rather than a plain ndarray.

        Parameters
        ----------
        dropna : bool, default True
            Don't include counts of NA values.

        Returns
        -------
        Series

        Raises
        ------
        NotImplementedError
            If the array is not 1D.
        """
        if self.ndim != 1:
            raise NotImplementedError

        from pandas import (
            Index,
            Series,
        )

        if dropna:
            # error: Unsupported operand type for ~ ("ExtensionArray")
            values = self[~self.isna()]._ndarray  # type: ignore[operator]
        else:
            values = self._ndarray

        # ``value_counts`` here is the module-level
        # pandas.core.algorithms.value_counts, not this method.
        result = value_counts(values, sort=False, dropna=dropna)

        index_arr = self._from_backing_data(np.asarray(result.index._data))
        index = Index(index_arr, name=result.index.name)
        return Series(result._values, index=index, name=result.name, copy=False)

    def _quantile(
        self: NDArrayBackedExtensionArrayT,
        qs: npt.NDArray[np.float64],
        interpolation: str,
    ) -> NDArrayBackedExtensionArrayT:
        """
        Compute the quantiles ``qs`` of the array.

        The positions where ``self.isna()`` is True are passed as the mask to
        ``quantile_with_mask``, together with ``_internal_fill_value``. The
        raw result goes through ``_cast_quantile_result`` and is then wrapped
        with ``_from_backing_data``.

        Parameters
        ----------
        qs : np.ndarray[float64]
            Quantiles to compute.
        interpolation : str
            Interpolation method, forwarded to ``quantile_with_mask``.
        """
        # TODO: disable for Categorical if not ordered?

        mask = np.asarray(self.isna())
        arr = self._ndarray
        fill_value = self._internal_fill_value

        res_values = quantile_with_mask(arr, mask, fill_value, qs, interpolation)

        res_values = self._cast_quantile_result(res_values)
        return self._from_backing_data(res_values)

    # TODO: see if we can share this with other dispatch-wrapping methods
    def _cast_quantile_result(self, res_values: np.ndarray) -> np.ndarray:
        """
        Cast the result of quantile_with_mask to an appropriate dtype
        to pass to _from_backing_data in _quantile.

        This base implementation returns ``res_values`` unchanged.
        """
        return res_values

    # ------------------------------------------------------------------------
    # numpy-like methods

    @classmethod
    def _empty(
        cls: type_t[NDArrayBackedExtensionArrayT], shape: Shape, dtype: ExtensionDtype
    ) -> NDArrayBackedExtensionArrayT:
        """
        Analogous to np.empty(shape, dtype=dtype)

        Parameters
        ----------
        shape : tuple[int]
        dtype : ExtensionDtype

        Returns
        -------
        Array of type ``cls`` whose backing ndarray is uninitialized
        (allocated with ``np.empty``).
        """
        # The base implementation uses a naive approach to find the dtype
        #  for the backing ndarray
        # An empty array built from the sequence reveals which numpy dtype
        # backs ``dtype``; np.empty then allocates storage of that dtype.
        arr = cls._from_sequence([], dtype=dtype)
        backing = np.empty(shape, dtype=arr._ndarray.dtype)
        return arr._from_backing_data(backing)