from __future__ import annotations

from typing import (
    TYPE_CHECKING,
    Callable,
    Hashable,
    Sequence,
    cast,
)

import numpy as np

from pandas._libs import lib
from pandas._typing import (
    AggFuncType,
    AggFuncTypeBase,
    AggFuncTypeDict,
    IndexLabel,
)
from pandas.util._decorators import (
    Appender,
    Substitution,
)

from pandas.core.dtypes.cast import maybe_downcast_to_dtype
from pandas.core.dtypes.common import (
    is_extension_array_dtype,
    is_integer_dtype,
    is_list_like,
    is_nested_list_like,
    is_scalar,
)
from pandas.core.dtypes.generic import (
    ABCDataFrame,
    ABCSeries,
)

import pandas.core.common as com
from pandas.core.frame import _shared_docs
from pandas.core.groupby import Grouper
from pandas.core.indexes.api import (
    Index,
    MultiIndex,
    get_objs_combined_axis,
)
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import cartesian_product
from pandas.core.series import Series

if TYPE_CHECKING:
    from pandas import DataFrame


# Note: We need to make sure `frame` is imported before `pivot`, otherwise
# _shared_docs['pivot_table'] will not yet exist.  TODO: Fix this dependency
@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot_table"], indents=1)
def pivot_table(
    data: DataFrame,
    values=None,
    index=None,
    columns=None,
    aggfunc: AggFuncType = "mean",
    fill_value=None,
    margins: bool = False,
    dropna: bool = True,
    margins_name: Hashable = "All",
    observed: bool = False,
    sort: bool = True,
) -> DataFrame:
    index = _convert_by(index)
    columns = _convert_by(columns)

    if isinstance(aggfunc, list):
        pieces: list[DataFrame] = []
        keys = []
        for func in aggfunc:
            _table = __internal_pivot_table(
                data,
                values=values,
                index=index,
                columns=columns,
                fill_value=fill_value,
                aggfunc=func,
                margins=margins,
                dropna=dropna,
                margins_name=margins_name,
                observed=observed,
                sort=sort,
            )
            pieces.append(_table)
            keys.append(getattr(func, "__name__", func))

        table = concat(pieces, keys=keys, axis=1)
        return table.__finalize__(data, method="pivot_table")

    table = __internal_pivot_table(
        data,
        values,
        index,
        columns,
        aggfunc,
        fill_value,
        margins,
        dropna,
        margins_name,
        observed,
        sort,
    )
    return table.__finalize__(data, method="pivot_table")
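
# Illustrative sketch (not part of the module): a minimal call on a toy frame
# with columns "A", "B", "C". With a list ``aggfunc``, each function is pivoted
# separately and the pieces are concatenated along the columns under the
# function names, as implemented above.
#
#   df = pd.DataFrame({"A": ["x", "x", "y"], "B": ["u", "v", "u"], "C": [1, 2, 3]})
#   pivot_table(df, values="C", index="A", columns="B", aggfunc=["sum", "mean"])
#   # the result's columns form a MultiIndex: ("sum", "u"), ("sum", "v"),
#   # ("mean", "u"), ("mean", "v")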


def __internal_pivot_table(
    data: DataFrame,
    values,
    index,
    columns,
    aggfunc: AggFuncTypeBase | AggFuncTypeDict,
    fill_value,
    margins: bool,
    dropna: bool,
    margins_name: Hashable,
    observed: bool,
    sort: bool,
) -> DataFrame:
    """
    Helper of :func:`pandas.pivot_table` for any non-list ``aggfunc``.
    """
    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=observed, sort=sort)
    agged = grouped.agg(aggfunc)

    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how="all")

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in values:
            if (
                v in data
                and is_integer_dtype(data[v])
                and v in agged
                and not is_integer_dtype(agged[v])
            ):
                if not isinstance(agged[v], ABCDataFrame) and isinstance(
                    data[v].dtype, np.dtype
                ):
                    # exclude DataFrame case bc maybe_downcast_to_dtype expects
                    #  ArrayLike
                    # e.g. test_pivot_table_multiindex_columns_doctest_case
                    #  agged.columns is a MultiIndex and 'v' is indexing only
                    #  on its first level.
                    agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged

    # GH17038, this check should only happen if index is defined (not None)
    if table.index.nlevels > 1 and index:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[: len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack)

    if not dropna:
        if isinstance(table.index, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.index.levels), names=table.index.names
            )
            table = table.reindex(m, axis=0)

        if isinstance(table.columns, MultiIndex):
            m = MultiIndex.from_arrays(
                cartesian_product(table.columns.levels), names=table.columns.names
            )
            table = table.reindex(m, axis=1)

    if sort is True and isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        table = table.fillna(fill_value, downcast="infer")

    if margins:
        if dropna:
            data = data[data.notna().all(axis=1)]
        table = _add_margins(
            table,
            data,
            values,
            rows=index,
            cols=columns,
            aggfunc=aggfunc,
            observed=dropna,
            margins_name=margins_name,
            fill_value=fill_value,
        )

    # discard the top level
    if values_passed and not values_multi and table.columns.nlevels > 1:
        table = table.droplevel(0, axis=1)
    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how="all", axis=1)

    return table
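
# Shape sketch (illustrative): with index=["A"] and columns=["B"], the keys
# are ["A", "B"]; the grouped aggregate carries a two-level index that is then
# unstacked on the column keys, roughly equivalent to
#
#   df.groupby(["A", "B"]).agg("mean").unstack("B")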


def _add_margins(
    table: DataFrame | Series,
    data: DataFrame,
    values,
    rows,
    cols,
    aggfunc,
    observed=None,
    margins_name: Hashable = "All",
    fill_value=None,
):
    if not isinstance(margins_name, str):
        raise ValueError("margins_name argument must be a string")

    msg = f'Conflicting name "{margins_name}" in margins'
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    if table.ndim == 2:
        # i.e. DataFrame
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    key: str | tuple[str, ...]
    if len(rows) > 1:
        key = (margins_name,) + ("",) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        return table._append(Series({key: grand_margin[margins_name]}))

    elif values:
        marginal_result_set = _generate_marginal_results(
            table, data, values, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        # no values, and table is a DataFrame
        assert isinstance(table, ABCDataFrame)
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name
        )
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, str):
            row_margin[k] = grand_margin[k]
        else:
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame

    margin_dummy = DataFrame(row_margin, columns=Index([key])).T

    row_names = result.index.names
    # check the result column and leave floats
    for dtype in set(result.dtypes):
        if is_extension_array_dtype(dtype):
            # Can hold NA already
            continue

        cols = result.select_dtypes([dtype]).columns
        margin_dummy[cols] = margin_dummy[cols].apply(
            maybe_downcast_to_dtype, args=(dtype,)
        )
    result = result._append(margin_dummy)
    result.index.names = row_names

    return result
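
# Illustrative sketch: with margins=True, the pivoted table gains a rightmost
# "All" column and a bottom "All" row holding the partial and grand
# aggregates, e.g.
#
#   pivot_table(df, values="C", index="A", columns="B",
#               aggfunc="sum", margins=True)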


def _compute_grand_margin(
    data: DataFrame, values, aggfunc, margins_name: Hashable = "All"
):

    if values:
        grand_margin = {}
        for k, v in data[values].items():
            try:
                if isinstance(aggfunc, str):
                    grand_margin[k] = getattr(v, aggfunc)()
                elif isinstance(aggfunc, dict):
                    if isinstance(aggfunc[k], str):
                        grand_margin[k] = getattr(v, aggfunc[k])()
                    else:
                        grand_margin[k] = aggfunc[k](v)
                else:
                    grand_margin[k] = aggfunc(v)
            except TypeError:
                pass
        return grand_margin
    else:
        return {margins_name: aggfunc(data.index)}
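
# Illustrative behavior: the grand margin is one scalar per value column.
# With values=["C"] and aggfunc="sum" it is roughly {"C": data["C"].sum()};
# a dict aggfunc can select a different function per column. Columns whose
# dtype rejects the function (TypeError) are silently skipped.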


def _generate_marginal_results(
    table, data, values, rows, cols, aggfunc, observed, margins_name: Hashable = "All"
):
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            return (key, margins_name) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            from pandas import DataFrame

            cat_axis = 0
            for key, piece in table.groupby(level=0, axis=cat_axis, observed=observed):
                if len(cols) > 1:
                    all_key = _all_key(key)
                else:
                    all_key = margins_name

                table_pieces.append(piece)
                # GH31016 this is to calculate margin for each group, and assign
                # corresponded key as index
                transformed_piece = DataFrame(piece.apply(aggfunc)).T
                if isinstance(piece.index, MultiIndex):
                    # We are adding an empty level
                    transformed_piece.index = MultiIndex.from_tuples(
                        [all_key], names=piece.index.names + [None]
                    )
                else:
                    transformed_piece.index = Index([all_key], name=piece.index.name)

                # append piece for margin into table_piece
                table_pieces.append(transformed_piece)
                margin_keys.append(all_key)

        if not table_pieces:
            # GH 49240
            return table
        else:
            result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack
        new_order = [len(cols)] + list(range(len(cols)))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
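
# Illustrative sketch of the "interleave" step above: for a value column "C"
# and a single column key, _all_key("C") is ("C", "All"), so each per-value
# block gains a ("C", "All") column holding its row-wise margin before the
# blocks are concatenated back together.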


def _generate_marginal_results_without_values(
    table: DataFrame,
    data,
    rows,
    cols,
    aggfunc,
    observed,
    margins_name: Hashable = "All",
):
    if len(cols) > 0:
        # need to "interleave" the margins
        margin_keys: list | Index = []

        def _all_key():
            if len(cols) == 1:
                return margins_name
            return (margins_name,) + ("",) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows].groupby(rows, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)

        else:
            margin = data.groupby(level=0, axis=0, observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols):
        row_margin = data[cols].groupby(cols, observed=observed).apply(aggfunc)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin


def _convert_by(by):
    if by is None:
        by = []
    elif (
        is_scalar(by)
        or isinstance(by, (np.ndarray, Index, ABCSeries, Grouper))
        or callable(by)
    ):
        by = [by]
    else:
        by = list(by)
    return by
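
# Illustrative behavior:
#
#   _convert_by(None)        # -> []
#   _convert_by("a")         # -> ["a"]
#   _convert_by(["a", "b"])  # -> ["a", "b"]
#
# Scalars, arrays, Index/Series, Groupers, and callables are wrapped in a
# one-element list; any other iterable is converted with list().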


@Substitution("\ndata : DataFrame")
@Appender(_shared_docs["pivot"], indents=1)
def pivot(
    data: DataFrame,
    *,
    columns: IndexLabel,
    index: IndexLabel | lib.NoDefault = lib.NoDefault,
    values: IndexLabel | lib.NoDefault = lib.NoDefault,
) -> DataFrame:
    columns_listlike = com.convert_to_list_like(columns)

    # If columns is None we will create a MultiIndex level with None as name
    # which might cause duplicated names because None is the default for
    # level names
    data = data.copy(deep=False)
    data.index = data.index.copy()
    data.index.names = [
        name if name is not None else lib.NoDefault for name in data.index.names
    ]

    indexed: DataFrame | Series
    if values is lib.NoDefault:
        if index is not lib.NoDefault:
            cols = com.convert_to_list_like(index)
        else:
            cols = []

        append = index is lib.NoDefault
        # error: Unsupported operand types for + ("List[Any]" and "ExtensionArray")
        # error: Unsupported left operand type for + ("ExtensionArray")
        indexed = data.set_index(
            cols + columns_listlike, append=append  # type: ignore[operator]
        )
    else:
        if index is lib.NoDefault:
            if isinstance(data.index, MultiIndex):
                # GH 23955
                index_list = [
                    data.index.get_level_values(i) for i in range(data.index.nlevels)
                ]
            else:
                index_list = [Series(data.index, name=data.index.name)]
        else:
            index_list = [data[idx] for idx in com.convert_to_list_like(index)]

        data_columns = [data[col] for col in columns_listlike]
        index_list.extend(data_columns)
        multiindex = MultiIndex.from_arrays(index_list)

        if is_list_like(values) and not isinstance(values, tuple):
            # Exclude tuple because it is seen as a single column name
            values = cast(Sequence[Hashable], values)
            indexed = data._constructor(
                data[values]._values, index=multiindex, columns=values
            )
        else:
            indexed = data._constructor_sliced(data[values]._values, index=multiindex)
    # error: Argument 1 to "unstack" of "DataFrame" has incompatible type "Union
    # [List[Any], ExtensionArray, ndarray[Any, Any], Index, Series]"; expected
    # "Hashable"
    result = indexed.unstack(columns_listlike)  # type: ignore[arg-type]
    result.index.names = [
        name if name is not lib.NoDefault else None for name in result.index.names
    ]

    return result
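
# Illustrative sketch: pivot reshapes without aggregating (duplicate
# index/column pairs raise ValueError), e.g.
#
#   df = pd.DataFrame({"foo": ["one", "one", "two", "two"],
#                      "bar": ["A", "B", "A", "B"],
#                      "baz": [1, 2, 3, 4]})
#   pivot(df, index="foo", columns="bar", values="baz")
#   # bar  A  B
#   # foo
#   # one  1  2
#   # two  3  4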


def crosstab(
    index,
    columns,
    values=None,
    rownames=None,
    colnames=None,
    aggfunc=None,
    margins: bool = False,
    margins_name: Hashable = "All",
    dropna: bool = True,
    normalize: bool = False,
) -> DataFrame:
    """
    Compute a simple cross tabulation of two (or more) factors.

    By default, computes a frequency table of the factors unless an
    array of values and an aggregation function are passed.

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows.
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns.
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed.
    colnames : sequence, default None
        If passed, must match number of column arrays passed.
    aggfunc : function, optional
        If specified, requires `values` be specified as well.
    margins : bool, default False
        Add row/column margins (subtotals).
    margins_name : str, default 'All'
        Name of the row/column that will contain the totals
        when margins is True.
    dropna : bool, default True
        Do not include columns whose entries are all NaN.
    normalize : bool, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

    Returns
    -------
    DataFrame
        Cross tabulation of the data.

    See Also
    --------
    DataFrame.pivot : Reshape data based on column values.
    pivot_table : Create a pivot table as a DataFrame.

    Notes
    -----
    Any Series passed will have their name attributes used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    In the event that there aren't overlapping indexes an empty DataFrame will
    be returned.

    Reference :ref:`the user guide <reshaping.crosstabulations>` for more examples.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    Here 'c' and 'f' are not represented in the data and will not be
    shown in the output because dropna is True by default. Set
    dropna=False to preserve categories with no data.

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> pd.crosstab(foo, bar)
    col_0  d  e
    row_0
    a      1  0
    b      0  1
    >>> pd.crosstab(foo, bar, dropna=False)
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """
    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    if not is_nested_list_like(index):
        index = [index]
    if not is_nested_list_like(columns):
        columns = [columns]

    common_idx = None
    pass_objs = [x for x in index + columns if isinstance(x, (ABCSeries, ABCDataFrame))]
    if pass_objs:
        common_idx = get_objs_combined_axis(pass_objs, intersect=True, sort=False)

    rownames = _get_names(index, rownames, prefix="row")
    colnames = _get_names(columns, colnames, prefix="col")

    # duplicate names mapped to unique names for pivot op
    (
        rownames_mapper,
        unique_rownames,
        colnames_mapper,
        unique_colnames,
    ) = _build_names_mapper(rownames, colnames)

    from pandas import DataFrame

    data = {
        **dict(zip(unique_rownames, index)),
        **dict(zip(unique_colnames, columns)),
    }
    df = DataFrame(data, index=common_idx)

    if values is None:
        df["__dummy__"] = 0
        kwargs = {"aggfunc": len, "fill_value": 0}
    else:
        df["__dummy__"] = values
        kwargs = {"aggfunc": aggfunc}

    # error: Argument 7 to "pivot_table" of "DataFrame" has incompatible type
    # "**Dict[str, object]"; expected "Union[...]"
    table = df.pivot_table(
        "__dummy__",
        index=unique_rownames,
        columns=unique_colnames,
        margins=margins,
        margins_name=margins_name,
        dropna=dropna,
        **kwargs,  # type: ignore[arg-type]
    )

    # Post-process
    if normalize is not False:
        table = _normalize(
            table, normalize=normalize, margins=margins, margins_name=margins_name
        )

    table = table.rename_axis(index=rownames_mapper, axis=0)
    table = table.rename_axis(columns=colnames_mapper, axis=1)

    return table
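
# Illustrative sketch (reusing the docstring's ``a`` and ``b`` arrays):
# normalize turns counts into proportions,
#
#   pd.crosstab(a, b, normalize="index")  # each row sums to 1.0
#   pd.crosstab(a, b, normalize="all")    # the whole table sums to 1.0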


def _normalize(
    table: DataFrame, normalize, margins: bool, margins_name: Hashable = "All"
) -> DataFrame:

    if not isinstance(normalize, (bool, str)):
        axis_subs = {0: "index", 1: "columns"}
        try:
            normalize = axis_subs[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

    if margins is False:

        # Actual Normalizations
        normalizers: dict[bool | str, Callable] = {
            "all": lambda x: x / x.sum(axis=1).sum(axis=0),
            "columns": lambda x: x / x.sum(),
            "index": lambda x: x.div(x.sum(axis=1), axis=0),
        }

        normalizers[True] = normalizers["all"]

        try:
            f = normalizers[normalize]
        except KeyError as err:
            raise ValueError("Not a valid normalize argument") from err

        table = f(table)
        table = table.fillna(0)

    elif margins is True:
        # keep index and column of pivoted table
        table_index = table.index
        table_columns = table.columns
        last_ind_or_col = table.iloc[-1, :].name

        # check if margin name is not in (for MI cases) and not equal to last
        # index/column and save the column and index margin
        if (margins_name not in last_ind_or_col) & (margins_name != last_ind_or_col):
            raise ValueError(f"{margins_name} not in pivoted DataFrame")
        column_margin = table.iloc[:-1, -1]
        index_margin = table.iloc[-1, :-1]

        # keep the core table
        table = table.iloc[:-1, :-1]

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == "columns":
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)
            table.columns = table_columns

        elif normalize == "index":
            index_margin = index_margin / index_margin.sum()
            table = table._append(index_margin)
            table = table.fillna(0)
            table.index = table_index

        elif normalize == "all" or normalize is True:
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table._append(index_margin)

            table = table.fillna(0)
            table.index = table_index
            table.columns = table_columns

        else:
            raise ValueError("Not a valid normalize argument")

    else:
        raise ValueError("Not a valid margins argument")

    return table
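
# Illustrative behavior of the normalizers used above (margins=False case):
#
#   t = pd.DataFrame({"u": [1, 3], "v": [1, 0]})
#   t / t.sum()                    # "columns": each column sums to 1.0
#   t.div(t.sum(axis=1), axis=0)   # "index":   each row sums to 1.0
#   t / t.sum(axis=1).sum(axis=0)  # "all":     the whole table sums to 1.0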


def _get_names(arrs, names, prefix: str = "row"):
    if names is None:
        names = []
        for i, arr in enumerate(arrs):
            if isinstance(arr, ABCSeries) and arr.name is not None:
                names.append(arr.name)
            else:
                names.append(f"{prefix}_{i}")
    else:
        if len(names) != len(arrs):
            raise AssertionError("arrays and names must have the same length")
        if not isinstance(names, list):
            names = list(names)

    return names
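
# Illustrative behavior: unnamed inputs fall back to positional names, e.g.
#
#   _get_names([np.array([1]), pd.Series([2], name="s")], None, prefix="row")
#   # -> ["row_0", "s"]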


def _build_names_mapper(
    rownames: list[str], colnames: list[str]
) -> tuple[dict[str, str], list[str], dict[str, str], list[str]]:
    """
    Given the names of a DataFrame's rows and columns, returns a set of unique row
    and column names and mappers that convert to original names.

    A row or column name is replaced if it is duplicate among the rows of the inputs,
    among the columns of the inputs or between the rows and the columns.

    Parameters
    ----------
    rownames: list[str]
    colnames: list[str]

    Returns
    -------
    Tuple(Dict[str, str], List[str], Dict[str, str], List[str])
        rownames_mapper: dict[str, str]
            a dictionary with new row names as keys and original rownames as values
        unique_rownames: list[str]
            a list of rownames with duplicate names replaced by dummy names
        colnames_mapper: dict[str, str]
            a dictionary with new column names as keys and original column names as
            values
        unique_colnames: list[str]
            a list of column names with duplicate names replaced by dummy names
    """

    def get_duplicates(names):
        # Collect names occurring more than once. (The original set
        # comprehension never updated ``seen`` and so returned every name,
        # causing all names to be remapped rather than only duplicates.)
        seen: set = set()
        dups = set()
        for name in names:
            if name in seen:
                dups.add(name)
            seen.add(name)
        return dups

    shared_names = set(rownames).intersection(set(colnames))
    dup_names = get_duplicates(rownames) | get_duplicates(colnames) | shared_names

    rownames_mapper = {
        f"row_{i}": name for i, name in enumerate(rownames) if name in dup_names
    }
    unique_rownames = [
        f"row_{i}" if name in dup_names else name for i, name in enumerate(rownames)
    ]

    colnames_mapper = {
        f"col_{i}": name for i, name in enumerate(colnames) if name in dup_names
    }
    unique_colnames = [
        f"col_{i}" if name in dup_names else name for i, name in enumerate(colnames)
    ]

    return rownames_mapper, unique_rownames, colnames_mapper, unique_colnames
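
# Illustrative behavior (with the duplicate detection fixed above): only
# clashing names are replaced, e.g.
#
#   _build_names_mapper(["a", "b"], ["b", "c"])
#   # -> ({"row_1": "b"}, ["a", "row_1"], {"col_0": "b"}, ["col_0", "c"])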