# ops.py

"""
Provide classes to perform the groupby aggregate operations.

These are not exposed to the user and provide implementations of the grouping
operations, primarily in cython. These classes (BaseGrouper and BinGrouper)
are contained *in* the SeriesGroupBy and DataFrameGroupBy objects.
"""
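
# Illustrative note (added for exposition; ``.grouper`` is an internal
# attribute in this pandas version, not a stable API):
#
#   df = DataFrame({"key": ["a", "a", "b"], "val": [1, 2, 3]})
#   gb = df.groupby("key")
#   isinstance(gb.grouper, BaseGrouper)  # -> True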

from __future__ import annotations

import collections
import functools
from typing import (
    TYPE_CHECKING,
    Callable,
    Generic,
    Hashable,
    Iterator,
    Sequence,
    final,
)

import numpy as np

from pandas._libs import (
    NaT,
    lib,
)
import pandas._libs.groupby as libgroupby
import pandas._libs.reduction as libreduction
from pandas._typing import (
    ArrayLike,
    AxisInt,
    DtypeObj,
    NDFrameT,
    Shape,
    npt,
)
from pandas.errors import AbstractMethodError
from pandas.util._decorators import cache_readonly

from pandas.core.dtypes.cast import (
    maybe_cast_pointwise_result,
    maybe_downcast_to_dtype,
)
from pandas.core.dtypes.common import (
    ensure_float64,
    ensure_int64,
    ensure_platform_int,
    ensure_uint64,
    is_1d_only_ea_dtype,
    is_bool_dtype,
    is_complex_dtype,
    is_datetime64_any_dtype,
    is_float_dtype,
    is_integer_dtype,
    is_numeric_dtype,
    is_period_dtype,
    is_sparse,
    is_timedelta64_dtype,
    needs_i8_conversion,
)
from pandas.core.dtypes.dtypes import CategoricalDtype
from pandas.core.dtypes.missing import (
    isna,
    maybe_fill,
)

from pandas.core.arrays import (
    Categorical,
    DatetimeArray,
    ExtensionArray,
    PeriodArray,
    TimedeltaArray,
)
from pandas.core.arrays.masked import (
    BaseMaskedArray,
    BaseMaskedDtype,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.frame import DataFrame
from pandas.core.groupby import grouper
from pandas.core.indexes.api import (
    CategoricalIndex,
    Index,
    MultiIndex,
    ensure_index,
)
from pandas.core.series import Series
from pandas.core.sorting import (
    compress_group_index,
    decons_obs_group_ids,
    get_flattened_list,
    get_group_index,
    get_group_index_sorter,
    get_indexer_dict,
)

if TYPE_CHECKING:
    from pandas.core.generic import NDFrame


class WrappedCythonOp:
    """
    Dispatch logic for functions defined in _libs.groupby

    Parameters
    ----------
    kind : str
        Whether the operation is an aggregate or transform.
    how : str
        Operation name, e.g. "mean".
    has_dropped_na : bool
        True precisely when dropna=True and the grouper contains a null value.
    """

    # Functions for which we do _not_ attempt to cast the cython result
    # back to the original dtype.
    cast_blocklist = frozenset(["rank", "count", "size", "idxmin", "idxmax"])

    def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
        self.kind = kind
        self.how = how
        self.has_dropped_na = has_dropped_na

    _CYTHON_FUNCTIONS = {
        "aggregate": {
            "sum": "group_sum",
            "prod": "group_prod",
            "min": "group_min",
            "max": "group_max",
            "mean": "group_mean",
            "median": "group_median_float64",
            "var": "group_var",
            "first": "group_nth",
            "last": "group_last",
            "ohlc": "group_ohlc",
        },
        "transform": {
            "cumprod": "group_cumprod",
            "cumsum": "group_cumsum",
            "cummin": "group_cummin",
            "cummax": "group_cummax",
            "rank": "group_rank",
        },
    }

    _cython_arity = {"ohlc": 4}  # OHLC
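
    # Illustrative sketch (added for exposition): the tables above map a
    # (kind, how) pair to the name of a cython kernel in _libs.groupby, and
    # _cython_arity records how many output columns a kernel produces.
    #
    #   WrappedCythonOp._CYTHON_FUNCTIONS["aggregate"]["sum"]  # -> "group_sum"
    #   WrappedCythonOp._cython_arity.get("ohlc", 1)           # -> 4
    #   WrappedCythonOp._cython_arity.get("sum", 1)            # -> 1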

    # Note: we make this a classmethod and pass kind+how so that caching
    # works at the class level and not the instance level
    @classmethod
    @functools.lru_cache(maxsize=None)
    def _get_cython_function(
        cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
    ):
        dtype_str = dtype.name
        ftype = cls._CYTHON_FUNCTIONS[kind][how]

        # see if there is a fused-type version of function
        # only valid for numeric
        f = getattr(libgroupby, ftype)
        if is_numeric:
            return f
        elif dtype == np.dtype(object):
            if how in ["median", "cumprod"]:
                # no fused types -> no __signatures__
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{dtype_str}]"
                )
            if "object" not in f.__signatures__:
                # raise NotImplementedError here rather than TypeError later
                raise NotImplementedError(
                    f"function is not implemented for this dtype: "
                    f"[how->{how},dtype->{dtype_str}]"
                )
            return f
        else:
            raise NotImplementedError(
                "This should not be reached. Please report a bug at "
                "github.com/pandas-dev/pandas/",
                dtype,
            )
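
    # Illustrative sketch (assumes the fused-type kernels expose a
    # ``__signatures__`` mapping keyed by dtype name, as relied on above):
    #
    #   f = getattr(libgroupby, "group_sum")
    #   sorted(f.__signatures__)  # dtype names with a compiled specialization
    #   # "object" in f.__signatures__ decides whether object arrays dispatch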

    def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
        """
        Cast numeric dtypes to float64 for functions that only support that.

        Parameters
        ----------
        values : np.ndarray

        Returns
        -------
        values : np.ndarray
        """
        how = self.how

        if how == "median":
            # median only has a float64 implementation
            # We should only get here with is_numeric, as non-numeric cases
            # should raise in _get_cython_function
            values = ensure_float64(values)

        elif values.dtype.kind in ["i", "u"]:
            if how in ["var", "mean"] or (
                self.kind == "transform" and self.has_dropped_na
            ):
                # has_dropped_na check needed for test_null_group_str_transformer
                # result may still include NaN, so we have to cast
                values = ensure_float64(values)

            elif how in ["sum", "ohlc", "prod", "cumsum", "cumprod"]:
                # Avoid overflow during group op
                if values.dtype.kind == "i":
                    values = ensure_int64(values)
                else:
                    values = ensure_uint64(values)

        return values
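
    # Worked examples of the casting rules above (illustrative):
    #   int32 values,  how="mean" -> float64 (result can be fractional)
    #   int32 values,  how="sum"  -> int64   (widen to avoid overflow)
    #   uint32 values, how="prod" -> uint64  (widen to avoid overflow)
    #   float64 values, any how   -> passed through unchanged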

    # TODO: general case implementation overridable by EAs.
    def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
        """
        Check if we can do this operation with our cython functions.

        Raises
        ------
        TypeError
            This is not a valid operation for this dtype.
        NotImplementedError
            This may be a valid operation, but does not have a cython implementation.
        """
        how = self.how

        if is_numeric:
            # never an invalid op for those dtypes, so return early as fastpath
            return

        if isinstance(dtype, CategoricalDtype):
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"{dtype} type does not support {how} operations")
            if how in ["min", "max", "rank"] and not dtype.ordered:
                # raise TypeError instead of NotImplementedError to ensure we
                # don't go down a group-by-group path, since in the empty-groups
                # case that would fail to raise
                raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
            if how not in ["rank"]:
                # only "rank" is implemented in cython
                raise NotImplementedError(f"{dtype} dtype not supported")

        elif is_sparse(dtype):
            raise NotImplementedError(f"{dtype} dtype not supported")
        elif is_datetime64_any_dtype(dtype):
            # Adding/multiplying datetimes is not valid
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"datetime64 type does not support {how} operations")
        elif is_period_dtype(dtype):
            # Adding/multiplying Periods is not valid
            if how in ["sum", "prod", "cumsum", "cumprod"]:
                raise TypeError(f"Period type does not support {how} operations")
        elif is_timedelta64_dtype(dtype):
            # timedeltas we can add but not multiply
            if how in ["prod", "cumprod"]:
                raise TypeError(f"timedelta64 type does not support {how} operations")
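
    # Illustrative sketch of the TypeError path above (added for exposition):
    #
    #   op = WrappedCythonOp(kind="aggregate", how="prod", has_dropped_na=False)
    #   op._disallow_invalid_ops(np.dtype("M8[ns]"))
    #   # -> TypeError: datetime64 type does not support prod operations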

    def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
        how = self.how
        kind = self.kind

        arity = self._cython_arity.get(how, 1)

        out_shape: Shape
        if how == "ohlc":
            out_shape = (ngroups, arity)
        elif arity > 1:
            raise NotImplementedError(
                "arity of more than 1 is not supported for the 'how' argument"
            )
        elif kind == "transform":
            out_shape = values.shape
        else:
            out_shape = (ngroups,) + values.shape[1:]
        return out_shape
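
    # Shape bookkeeping above, worked through (illustrative):
    #   transform         -> out_shape == values.shape (one row per input row)
    #   aggregate         -> leading axis collapses to ngroups
    #   aggregate, "ohlc" -> (ngroups, 4): open/high/low/close per group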

    def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
        how = self.how

        if how == "rank":
            out_dtype = "float64"
        else:
            if is_numeric_dtype(dtype):
                out_dtype = f"{dtype.kind}{dtype.itemsize}"
            else:
                out_dtype = "object"
        return np.dtype(out_dtype)

    def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
        """
        Get the desired dtype of a result based on the
        input dtype and how it was computed.

        Parameters
        ----------
        dtype : np.dtype

        Returns
        -------
        np.dtype
            The desired dtype of the result.
        """
        how = self.how

        if how in ["sum", "cumsum", "prod", "cumprod"]:
            if dtype == np.dtype(bool):
                return np.dtype(np.int64)
        elif how in ["mean", "median", "var"]:
            if is_float_dtype(dtype) or is_complex_dtype(dtype):
                return dtype
            elif is_numeric_dtype(dtype):
                return np.dtype(np.float64)
        return dtype
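
    # Worked examples of the result-dtype rules above (illustrative):
    #   bool,    how="sum"  -> int64   (True counts as 1, sums are integers)
    #   int64,   how="mean" -> float64 (integer means are generally fractional)
    #   float32, how="mean" -> float32 (float inputs keep their precision)
    #   int64,   how="min"  -> int64   (falls through; input dtype preserved)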

    @final
    def _ea_wrap_cython_operation(
        self,
        values: ExtensionArray,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        **kwargs,
    ) -> ArrayLike:
        """
        If we have an ExtensionArray, unwrap, call _cython_operation, and
        re-wrap if appropriate.
        """
        if isinstance(values, BaseMaskedArray):
            return self._masked_ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

        elif isinstance(values, Categorical):
            assert self.how == "rank"  # the only one implemented ATM
            assert values.ordered  # checked earlier
            mask = values.isna()
            npvalues = values._ndarray

            res_values = self._cython_op_ndim_compat(
                npvalues,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                mask=mask,
                **kwargs,
            )

            # If we ever have more than just "rank" here, we'll need to do
            # `if self.how in self.cast_blocklist` like we do for other dtypes.
            return res_values

        npvalues = self._ea_to_cython_values(values)

        res_values = self._cython_op_ndim_compat(
            npvalues,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=None,
            **kwargs,
        )

        if self.how in self.cast_blocklist:
            # i.e. how in ["rank"], since other cast_blocklist methods don't go
            # through cython_operation
            return res_values

        return self._reconstruct_ea_result(values, res_values)

    # TODO: general case implementation overridable by EAs.
    def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
        # GH#43682
        if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
            # All of the functions implemented here are ordinal, so we can
            # operate on the tz-naive equivalents
            npvalues = values._ndarray.view("M8[ns]")
        elif isinstance(values.dtype, StringDtype):
            # StringArray
            npvalues = values.to_numpy(object, na_value=np.nan)
        else:
            raise NotImplementedError(
                f"function is not implemented for this dtype: {values.dtype}"
            )
        return npvalues

    # TODO: general case implementation overridable by EAs.
    def _reconstruct_ea_result(
        self, values: ExtensionArray, res_values: np.ndarray
    ) -> ExtensionArray:
        """
        Construct an ExtensionArray result from an ndarray result.
        """
        dtype: BaseMaskedDtype | StringDtype

        if isinstance(values.dtype, StringDtype):
            dtype = values.dtype
            string_array_cls = dtype.construct_array_type()
            return string_array_cls._from_sequence(res_values, dtype=dtype)

        elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
            # In to_cython_values we took a view as M8[ns]
            assert res_values.dtype == "M8[ns]"
            res_values = res_values.view(values._ndarray.dtype)
            return values._from_backing_data(res_values)

        raise NotImplementedError

    @final
    def _masked_ea_wrap_cython_operation(
        self,
        values: BaseMaskedArray,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        **kwargs,
    ) -> BaseMaskedArray:
        """
        Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EA's
        and cython algorithms which accept a mask.
        """
        orig_values = values

        # libgroupby functions are responsible for NOT altering mask
        mask = values._mask
        if self.kind != "aggregate":
            result_mask = mask.copy()
        else:
            result_mask = np.zeros(ngroups, dtype=bool)

        arr = values._data

        res_values = self._cython_op_ndim_compat(
            arr,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=mask,
            result_mask=result_mask,
            **kwargs,
        )

        if self.how == "ohlc":
            arity = self._cython_arity.get(self.how, 1)
            result_mask = np.tile(result_mask, (arity, 1)).T

        # res_values should already have the correct dtype, we just need to
        # wrap in a MaskedArray
        return orig_values._maybe_mask_result(res_values, result_mask)

    @final
    def _cython_op_ndim_compat(
        self,
        values: np.ndarray,
        *,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        mask: npt.NDArray[np.bool_] | None = None,
        result_mask: npt.NDArray[np.bool_] | None = None,
        **kwargs,
    ) -> np.ndarray:
        if values.ndim == 1:
            # expand to 2d, dispatch, then squeeze if appropriate
            values2d = values[None, :]
            if mask is not None:
                mask = mask[None, :]
            if result_mask is not None:
                result_mask = result_mask[None, :]
            res = self._call_cython_op(
                values2d,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                mask=mask,
                result_mask=result_mask,
                **kwargs,
            )
            if res.shape[0] == 1:
                return res[0]

            # otherwise we have OHLC
            return res.T

        return self._call_cython_op(
            values,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=mask,
            result_mask=result_mask,
            **kwargs,
        )

    @final
    def _call_cython_op(
        self,
        values: np.ndarray,  # np.ndarray[ndim=2]
        *,
        min_count: int,
        ngroups: int,
        comp_ids: np.ndarray,
        mask: npt.NDArray[np.bool_] | None,
        result_mask: npt.NDArray[np.bool_] | None,
        **kwargs,
    ) -> np.ndarray:  # np.ndarray[ndim=2]
        orig_values = values

        dtype = values.dtype
        is_numeric = is_numeric_dtype(dtype)

        is_datetimelike = needs_i8_conversion(dtype)

        if is_datetimelike:
            values = values.view("int64")
            is_numeric = True
        elif is_bool_dtype(dtype):
            values = values.view("uint8")
        if values.dtype == "float16":
            values = values.astype(np.float32)

        values = values.T
        if mask is not None:
            mask = mask.T
            if result_mask is not None:
                result_mask = result_mask.T

        out_shape = self._get_output_shape(ngroups, values)
        func = self._get_cython_function(self.kind, self.how, values.dtype, is_numeric)
        values = self._get_cython_vals(values)
        out_dtype = self._get_out_dtype(values.dtype)

        result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
        if self.kind == "aggregate":
            counts = np.zeros(ngroups, dtype=np.int64)
            if self.how in ["min", "max", "mean", "last", "first", "sum"]:
                func(
                    out=result,
                    counts=counts,
                    values=values,
                    labels=comp_ids,
                    min_count=min_count,
                    mask=mask,
                    result_mask=result_mask,
                    is_datetimelike=is_datetimelike,
                )
            elif self.how in ["var", "ohlc", "prod", "median"]:
                func(
                    result,
                    counts,
                    values,
                    comp_ids,
                    min_count=min_count,
                    mask=mask,
                    result_mask=result_mask,
                    **kwargs,
                )
            else:
                raise NotImplementedError(f"{self.how} is not implemented")
        else:
            # TODO: min_count
            if self.how != "rank":
                # TODO: should rank take result_mask?
                kwargs["result_mask"] = result_mask
            func(
                out=result,
                values=values,
                labels=comp_ids,
                ngroups=ngroups,
                is_datetimelike=is_datetimelike,
                mask=mask,
                **kwargs,
            )

        if self.kind == "aggregate":
            # i.e. counts is defined. Locations where count<min_count
            # need to have the result set to np.nan, which may require casting,
            # see GH#40767
            if is_integer_dtype(result.dtype) and not is_datetimelike:
                # if the op keeps the int dtypes, we have to use 0
                cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
                empty_groups = counts < cutoff
                if empty_groups.any():
                    if result_mask is not None:
                        assert result_mask[empty_groups].all()
                    else:
                        # Note: this conversion could be lossy, see GH#40767
                        result = result.astype("float64")
                        result[empty_groups] = np.nan

        result = result.T

        if self.how not in self.cast_blocklist:
            # e.g. if we are int64 and need to restore to datetime64/timedelta64
            # "rank" is the only member of cast_blocklist we get here
            # Casting only needed for float16, bool, datetimelike,
            # and self.how in ["sum", "prod", "ohlc", "cumprod"]
            res_dtype = self._get_result_dtype(orig_values.dtype)
            op_result = maybe_downcast_to_dtype(result, res_dtype)
        else:
            op_result = result

        return op_result

    @final
    def cython_operation(
        self,
        *,
        values: ArrayLike,
        axis: AxisInt,
        min_count: int = -1,
        comp_ids: np.ndarray,
        ngroups: int,
        **kwargs,
    ) -> ArrayLike:
        """
        Call our cython function, with appropriate pre- and post- processing.
        """
        if values.ndim > 2:
            raise NotImplementedError("number of dimensions is currently limited to 2")
        if values.ndim == 2:
            assert axis == 1, axis
        elif not is_1d_only_ea_dtype(values.dtype):
            # Note: it is *not* the case that axis is always 0 for 1-dim values,
            # as we can have 1D ExtensionArrays that we need to treat as 2D
            assert axis == 0

        dtype = values.dtype
        is_numeric = is_numeric_dtype(dtype)

        # can we do this operation with our cython functions
        # if not raise NotImplementedError
        self._disallow_invalid_ops(dtype, is_numeric)

        if not isinstance(values, np.ndarray):
            # i.e. ExtensionArray
            return self._ea_wrap_cython_operation(
                values,
                min_count=min_count,
                ngroups=ngroups,
                comp_ids=comp_ids,
                **kwargs,
            )

        return self._cython_op_ndim_compat(
            values,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=comp_ids,
            mask=None,
            **kwargs,
        )
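

# Illustrative usage sketch (added for exposition; WrappedCythonOp is an
# internal helper, so this reflects the contract assumed from the code above
# rather than a documented API): sum float values into two groups identified
# by the compressed codes in ``comp_ids``.
#
#   op = WrappedCythonOp(kind="aggregate", how="sum", has_dropped_na=False)
#   out = op.cython_operation(
#       values=np.array([1.0, 2.0, 3.0]),
#       axis=0,
#       comp_ids=np.array([0, 0, 1], dtype=np.intp),
#       ngroups=2,
#   )
#   # out -> array([3., 3.])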


class BaseGrouper:
    """
    This is an internal Grouper class, which actually holds
    the generated groups

    Parameters
    ----------
    axis : Index
    groupings : Sequence[Grouping]
        all the grouping instances to handle in this grouper
        for example for grouper list to groupby, need to pass the list
    sort : bool, default True
        whether this grouper will give sorted result or not
    """

    axis: Index

    def __init__(
        self,
        axis: Index,
        groupings: Sequence[grouper.Grouping],
        sort: bool = True,
        dropna: bool = True,
    ) -> None:
        assert isinstance(axis, Index), axis

        self.axis = axis
        self._groupings: list[grouper.Grouping] = list(groupings)
        self._sort = sort
        self.dropna = dropna

    @property
    def groupings(self) -> list[grouper.Grouping]:
        return self._groupings

    @property
    def shape(self) -> Shape:
        return tuple(ping.ngroups for ping in self.groupings)

    def __iter__(self) -> Iterator[Hashable]:
        return iter(self.indices)

    @property
    def nkeys(self) -> int:
        return len(self.groupings)

    def get_iterator(
        self, data: NDFrameT, axis: AxisInt = 0
    ) -> Iterator[tuple[Hashable, NDFrameT]]:
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        splitter = self._get_splitter(data, axis=axis)
        keys = self.group_keys_seq
        yield from zip(keys, splitter)

    @final
    def _get_splitter(self, data: NDFrame, axis: AxisInt = 0) -> DataSplitter:
        """
        Returns
        -------
        Generator yielding subsetted objects
        """
        ids, _, ngroups = self.group_info
        return _get_splitter(data, ids, ngroups, axis=axis)

    @final
    @cache_readonly
    def group_keys_seq(self):
        if len(self.groupings) == 1:
            return self.levels[0]
        else:
            ids, _, ngroups = self.group_info

            # provide "flattened" iterator for multi-group setting
            return get_flattened_list(ids, ngroups, self.levels, self.codes)

    @final
    def apply(
        self, f: Callable, data: DataFrame | Series, axis: AxisInt = 0
    ) -> tuple[list, bool]:
        mutated = False
        splitter = self._get_splitter(data, axis=axis)
        group_keys = self.group_keys_seq
        result_values = []

        # This calls DataSplitter.__iter__
        zipped = zip(group_keys, splitter)

        for key, group in zipped:
            object.__setattr__(group, "name", key)

            # group might be modified
            group_axes = group.axes
            res = f(group)
            if not mutated and not _is_indexed_like(res, group_axes, axis):
                mutated = True
            result_values.append(res)

        # getattr pattern for __name__ is needed for functools.partial objects
        if len(group_keys) == 0 and getattr(f, "__name__", None) in [
            "skew",
            "sum",
            "prod",
        ]:
            # If group_keys is empty, then no function calls have been made,
            # so we will not have raised even if this is an invalid dtype.
            # So do one dummy call here to raise appropriate TypeError.
            f(data.iloc[:0])

        return result_values, mutated

    @cache_readonly
    def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
        """dict {group name -> group indices}"""
        if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
            # This shows unused categories in indices GH#38642
            return self.groupings[0].indices
        codes_list = [ping.codes for ping in self.groupings]
        keys = [ping.group_index for ping in self.groupings]
        return get_indexer_dict(codes_list, keys)

    @final
    def result_ilocs(self) -> npt.NDArray[np.intp]:
        """
        Get the original integer locations of result_index in the input.
        """
        # Original indices are where group_index would go via sorting.
        # But when dropna is true, we need to remove null values while accounting for
        # any gaps that then occur because of them.
        group_index = get_group_index(
            self.codes, self.shape, sort=self._sort, xnull=True
        )
        group_index, _ = compress_group_index(group_index, sort=self._sort)

        if self.has_dropped_na:
            mask = np.where(group_index >= 0)
            # Count how many gaps are caused by previous null values for each position
            null_gaps = np.cumsum(group_index == -1)[mask]
            group_index = group_index[mask]

        result = get_group_index_sorter(group_index, self.ngroups)

        if self.has_dropped_na:
            # Shift by the number of prior null gaps
            result += np.take(null_gaps, result)

        return result
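
    # Worked example of the null-gap adjustment above (illustrative): with
    # dropna and compressed codes [0, -1, 1], position 1 is removed before
    # sorting, so later positions shift left; adding back the cumulative
    # null count restores the original integer locations:
    #   kept positions [0, 2] -> sorter [0, 1] -> result [0, 2]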

    @final
    @property
    def codes(self) -> list[npt.NDArray[np.signedinteger]]:
        return [ping.codes for ping in self.groupings]

    @property
    def levels(self) -> list[Index]:
        return [ping.group_index for ping in self.groupings]

    @property
    def names(self) -> list[Hashable]:
        return [ping.name for ping in self.groupings]

    @final
    def size(self) -> Series:
        """
        Compute group sizes.
        """
        ids, _, ngroups = self.group_info
        out: np.ndarray | list
        if ngroups:
            out = np.bincount(ids[ids != -1], minlength=ngroups)
        else:
            out = []
        return Series(out, index=self.result_index, dtype="int64")
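
    # Illustrative: with ids [0, 0, 1, -1] (-1 marks dropped-NA rows) and
    # ngroups=2, np.bincount(ids[ids != -1], minlength=2) -> [2, 1], i.e.
    # two rows in the first group and one in the second.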

    @cache_readonly
    def groups(self) -> dict[Hashable, np.ndarray]:
        """dict {group name -> group labels}"""
        if len(self.groupings) == 1:
            return self.groupings[0].groups
        else:
            to_groupby = zip(*(ping.grouping_vector for ping in self.groupings))
            index = Index(to_groupby)
            return self.axis.groupby(index)

    @final
    @cache_readonly
    def is_monotonic(self) -> bool:
        # return if my group orderings are monotonic
        return Index(self.group_info[0]).is_monotonic_increasing

    @final
    @cache_readonly
    def has_dropped_na(self) -> bool:
        """
        Whether grouper has null value(s) that are dropped.
        """
        return bool((self.group_info[0] < 0).any())

    @cache_readonly
    def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
        comp_ids, obs_group_ids = self._get_compressed_codes()

        ngroups = len(obs_group_ids)
        comp_ids = ensure_platform_int(comp_ids)

        return comp_ids, obs_group_ids, ngroups

    @cache_readonly
    def codes_info(self) -> npt.NDArray[np.intp]:
        # return the codes of items in original grouped axis
        ids, _, _ = self.group_info
        return ids

    @final
    def _get_compressed_codes(
        self,
    ) -> tuple[npt.NDArray[np.signedinteger], npt.NDArray[np.intp]]:
        # The first returned ndarray may have any signed integer dtype
        if len(self.groupings) > 1:
            group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
            return compress_group_index(group_index, sort=self._sort)
            # FIXME: compress_group_index's second return value is int64, not intp

        ping = self.groupings[0]
        return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)

    @final
    @cache_readonly
    def ngroups(self) -> int:
        return len(self.result_index)

    @property
    def reconstructed_codes(self) -> list[npt.NDArray[np.intp]]:
        codes = self.codes
        ids, obs_ids, _ = self.group_info
        return decons_obs_group_ids(ids, obs_ids, self.shape, codes, xnull=True)

    @cache_readonly
    def result_index(self) -> Index:
        if len(self.groupings) == 1:
            return self.groupings[0].result_index.rename(self.names[0])

        codes = self.reconstructed_codes
        levels = [ping.result_index for ping in self.groupings]
        return MultiIndex(
            levels=levels, codes=codes, verify_integrity=False, names=self.names
        )

    @final
    def get_group_levels(self) -> list[ArrayLike]:
        # Note: only called from _insert_inaxis_grouper, which
        # is only called for BaseGrouper, never for BinGrouper
        if len(self.groupings) == 1:
            return [self.groupings[0].group_arraylike]

        name_list = []
        for ping, codes in zip(self.groupings, self.reconstructed_codes):
            codes = ensure_platform_int(codes)
            levels = ping.group_arraylike.take(codes)

            name_list.append(levels)

        return name_list

    # ------------------------------------------------------------
    # Aggregation functions

    @final
    def _cython_operation(
        self,
        kind: str,
        values,
        how: str,
        axis: AxisInt,
        min_count: int = -1,
        **kwargs,
    ) -> ArrayLike:
        """
        Returns the values of a cython operation.
        """
        assert kind in ["transform", "aggregate"]

        cy_op = WrappedCythonOp(kind=kind, how=how, has_dropped_na=self.has_dropped_na)

        ids, _, _ = self.group_info
        ngroups = self.ngroups
        return cy_op.cython_operation(
            values=values,
            axis=axis,
            min_count=min_count,
            comp_ids=ids,
            ngroups=ngroups,
            **kwargs,
        )

    @final
    def agg_series(
        self, obj: Series, func: Callable, preserve_dtype: bool = False
    ) -> ArrayLike:
        """
        Parameters
        ----------
        obj : Series
        func : function taking a Series and returning a scalar-like
        preserve_dtype : bool
            Whether the aggregation is known to be dtype-preserving.

        Returns
        -------
        np.ndarray or ExtensionArray
        """
        # test_groupby_empty_with_category gets here with self.ngroups == 0
        # and len(obj) > 0
        if len(obj) > 0 and not isinstance(obj._values, np.ndarray):
            # we can preserve a little bit more aggressively with EA dtype
            # because maybe_cast_pointwise_result will do a try/except
            # with _from_sequence. NB we are assuming here that _from_sequence
            # is sufficiently strict that it casts appropriately.
            preserve_dtype = True

        result = self._aggregate_series_pure_python(obj, func)

        npvalues = lib.maybe_convert_objects(result, try_float=False)
        if preserve_dtype:
            out = maybe_cast_pointwise_result(npvalues, obj.dtype, numeric_only=True)
        else:
            out = npvalues
        return out
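
    # Illustrative flow for agg_series (added for exposition): split ``obj``
    # into per-group Series, apply ``func`` to each, collect the scalars in
    # an object ndarray, then infer a tighter dtype where possible, e.g.
    #   obj=[1, 2, 3] with groups [0, 0, 1] and func=lambda s: s.max()
    #   -> object array [2, 3] -> maybe_convert_objects -> int64 array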

    @final
    def _aggregate_series_pure_python(
        self, obj: Series, func: Callable
    ) -> npt.NDArray[np.object_]:
        _, _, ngroups = self.group_info

        result = np.empty(ngroups, dtype="O")
        initialized = False

        splitter = self._get_splitter(obj, axis=0)

        for i, group in enumerate(splitter):
            res = func(group)
            res = libreduction.extract_result(res)

            if not initialized:
                # We only do this validation on the first iteration
                libreduction.check_result_array(res, group.dtype)
                initialized = True

            result[i] = res

        return result


class BinGrouper(BaseGrouper):
    """
    This is an internal Grouper class

    Parameters
    ----------
    bins : the split index of binlabels to group the item of axis
    binlabels : the label list
    indexer : np.ndarray[np.intp], optional
        the indexer created by Grouper
        some groupers (TimeGrouper) will sort its axis and its
        group_info is also sorted, so need the indexer to reorder

    Examples
    --------
    bins: [2, 4, 6, 8, 10]
    binlabels: DatetimeIndex(['2005-01-01', '2005-01-03',
        '2005-01-05', '2005-01-07', '2005-01-09'],
        dtype='datetime64[ns]', freq='2D')

    the group_info, which contains the label of each item in the grouped
    axis, the index of each label in the label list, and the group number, is

    (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), array([0, 1, 2, 3, 4]), 5)

    meaning that the grouped axis has 10 items which fall into 5 labels:
    the first and second items belong to the first label, the third and
    fourth items belong to the second label, and so on
    """

    bins: npt.NDArray[np.int64]
    binlabels: Index

    def __init__(
        self,
        bins,
        binlabels,
        indexer=None,
    ) -> None:
        self.bins = ensure_int64(bins)
        self.binlabels = ensure_index(binlabels)
        self.indexer = indexer

        # These lengths must match, otherwise we could call agg_series
        # with empty self.bins, which would raise in libreduction.
        assert len(self.binlabels) == len(self.bins)

    @cache_readonly
    def groups(self):
        """dict {group name -> group labels}"""
        # this is mainly for compat
        # GH 3881
        result = {
            key: value
            for key, value in zip(self.binlabels, self.bins)
            if key is not NaT
        }
        return result

    @property
    def nkeys(self) -> int:
        # still matches len(self.groupings), but we can hard-code
        return 1

    @cache_readonly
    def codes_info(self) -> npt.NDArray[np.intp]:
        # return the codes of items in original grouped axis
        ids, _, _ = self.group_info
        if self.indexer is not None:
            sorter = np.lexsort((ids, self.indexer))
            ids = ids[sorter]
        return ids

    def get_iterator(self, data: NDFrame, axis: AxisInt = 0):
        """
        Groupby iterator

        Returns
        -------
        Generator yielding sequence of (name, subsetted object)
        for each group
        """
        if axis == 0:
            slicer = lambda start, edge: data.iloc[start:edge]
        else:
            slicer = lambda start, edge: data.iloc[:, start:edge]

        length = len(data.axes[axis])

        start = 0
        for edge, label in zip(self.bins, self.binlabels):
            if label is not NaT:
                yield label, slicer(start, edge)
            start = edge

        if start < length:
            yield self.binlabels[-1], slicer(start, None)

    @cache_readonly
    def indices(self):
        indices = collections.defaultdict(list)

        i = 0
        for label, bin in zip(self.binlabels, self.bins):
            if i < bin:
                if label is not NaT:
                    indices[label] = list(range(i, bin))
                i = bin
        return indices
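
    # Worked example (illustrative): with bins [2, 5] and binlabels
    # ["x", "y"], the first bin covers positions 0-1 and the second
    # positions 2-4, so indices -> {"x": [0, 1], "y": [2, 3, 4]}.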

    @cache_readonly
    def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
        ngroups = self.ngroups
        obs_group_ids = np.arange(ngroups, dtype=np.intp)
        rep = np.diff(np.r_[0, self.bins])

        rep = ensure_platform_int(rep)
        if ngroups == len(self.bins):
            comp_ids = np.repeat(np.arange(ngroups), rep)
        else:
            comp_ids = np.repeat(np.r_[-1, np.arange(ngroups)], rep)

        return (
            ensure_platform_int(comp_ids),
            obs_group_ids,
            ngroups,
        )

    @cache_readonly
    def reconstructed_codes(self) -> list[np.ndarray]:
        # get unique result indices, and prepend 0 as groupby starts from the first
        return [np.r_[0, np.flatnonzero(self.bins[1:] != self.bins[:-1]) + 1]]

    @cache_readonly
    def result_index(self) -> Index:
        if len(self.binlabels) != 0 and isna(self.binlabels[0]):
            return self.binlabels[1:]

        return self.binlabels

    @property
    def levels(self) -> list[Index]:
        return [self.binlabels]

    @property
    def names(self) -> list[Hashable]:
        return [self.binlabels.name]

    @property
    def groupings(self) -> list[grouper.Grouping]:
        lev = self.binlabels
        codes = self.group_info[0]
        labels = lev.take(codes)
        ping = grouper.Grouping(
            labels, labels, in_axis=False, level=None, uniques=lev._values
        )
        return [ping]


def _is_indexed_like(obj, axes, axis: AxisInt) -> bool:
    if isinstance(obj, Series):
        if len(axes) > 1:
            return False
        return obj.axes[axis].equals(axes[axis])
    elif isinstance(obj, DataFrame):
        return obj.axes[axis].equals(axes[axis])

    return False


# ----------------------------------------------------------------------
# Splitting / application


class DataSplitter(Generic[NDFrameT]):
    def __init__(
        self,
        data: NDFrameT,
        labels: npt.NDArray[np.intp],
        ngroups: int,
        axis: AxisInt = 0,
    ) -> None:
        self.data = data
        self.labels = ensure_platform_int(labels)  # _should_ already be np.intp
        self.ngroups = ngroups

        self.axis = axis
        assert isinstance(axis, int), axis

    @cache_readonly
    def _slabels(self) -> npt.NDArray[np.intp]:
        # Sorted labels
        return self.labels.take(self._sort_idx)

    @cache_readonly
    def _sort_idx(self) -> npt.NDArray[np.intp]:
        # Counting sort indexer
        return get_group_index_sorter(self.labels, self.ngroups)

    def __iter__(self) -> Iterator:
        sdata = self._sorted_data

        if self.ngroups == 0:
            # we are inside a generator; rather than raise StopIteration
            # we merely return to signal the end
            return

        starts, ends = lib.generate_slices(self._slabels, self.ngroups)

        for start, end in zip(starts, ends):
            yield self._chop(sdata, slice(start, end))

    @cache_readonly
    def _sorted_data(self) -> NDFrameT:
        return self.data.take(self._sort_idx, axis=self.axis)

    def _chop(self, sdata, slice_obj: slice) -> NDFrame:
        raise AbstractMethodError(self)


class SeriesSplitter(DataSplitter):
    def _chop(self, sdata: Series, slice_obj: slice) -> Series:
        # fastpath equivalent to `sdata.iloc[slice_obj]`
        mgr = sdata._mgr.get_slice(slice_obj)
        ser = sdata._constructor(mgr, name=sdata.name, fastpath=True)
        return ser.__finalize__(sdata, method="groupby")


class FrameSplitter(DataSplitter):
    def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
        # Fastpath equivalent to:
        # if self.axis == 0:
        #     return sdata.iloc[slice_obj]
        # else:
        #     return sdata.iloc[:, slice_obj]
        mgr = sdata._mgr.get_slice(slice_obj, axis=1 - self.axis)
        df = sdata._constructor(mgr)
        return df.__finalize__(sdata, method="groupby")


def _get_splitter(
    data: NDFrame, labels: np.ndarray, ngroups: int, axis: AxisInt = 0
) -> DataSplitter:
    if isinstance(data, Series):
        klass: type[DataSplitter] = SeriesSplitter
    else:
        # i.e. DataFrame
        klass = FrameSplitter

    return klass(data, labels, ngroups, axis)
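

# Illustrative usage sketch (added for exposition; these are private helpers,
# so the call reflects the internal contract seen in this module): split a
# Series into contiguous per-group chunks given integer group labels.
#
#   ser = Series([10, 20, 30], index=["a", "b", "c"])
#   splitter = _get_splitter(ser, np.array([0, 1, 0], dtype=np.intp), ngroups=2)
#   [chunk.tolist() for chunk in splitter]  # -> [[10, 30], [20]]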