groupby.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. """Simplified split-apply-combine paradigm on dataframes for internal use."""
  2. from __future__ import annotations
  3. from typing import cast, Iterable
  4. import pandas as pd
  5. from seaborn._core.rules import categorical_order
  6. from typing import TYPE_CHECKING
  7. if TYPE_CHECKING:
  8. from typing import Callable
  9. from pandas import DataFrame, MultiIndex, Index
  class GroupBy:
      """
      Interface for Pandas GroupBy operations allowing specified group order.

      Writing our own class to do this has a few advantages:
      - It constrains the interface between Plot and Stat/Move objects
      - It allows control over the row order of the GroupBy result, which is
        important when using in the context of some Move operations (dodge, stack, ...)
      - It simplifies some complexities regarding the return type and Index contents
        one encounters with Pandas, especially for DataFrame -> DataFrame applies
      - It increases future flexibility regarding alternate DataFrame libraries

      """
      def __init__(
          self,
          order: str | list[str] | tuple[str, ...] | dict[str, list | None],
      ):
          """
          Initialize the GroupBy from grouping variables and optional level orders.

          Parameters
          ----------
          order
              A single variable name, a list/tuple of variable names, or a dict
              mapping names to desired level orders. Level order values can be
              None to use default ordering rules (``categorical_order`` of the
              column). The variables can include names that are not expected
              to appear in the data; these will be dropped before the groups
              are defined.

          Raises
          ------
          ValueError
              If ``order`` is empty.

          """
          if not order:
              raise ValueError("GroupBy requires at least one grouping variable")
          if isinstance(order, str):
              # A bare string names one variable; don't iterate its characters
              order = [order]
          if isinstance(order, (list, tuple)):
              # Plain sequence of names -> default (None) level order for each
              order = dict.fromkeys(order)
          self.order = order

      def _get_groups(
          self, data: DataFrame
      ) -> tuple[str | list[str], Index | MultiIndex]:
          """
          Return the grouping key(s) and an index of the ordered group levels.

          Variables in ``self.order`` that are not columns of ``data`` are
          skipped. A level order of None is replaced by ``categorical_order``
          applied to that column.

          Returns
          -------
          grouper
              ``[]`` if no grouping variable is present in ``data``, the single
              variable name if exactly one is present, otherwise the list of
              present names.
          groups
              An empty Index, an Index of the single variable's levels, or a
              MultiIndex holding the Cartesian product of all variables' levels.

          """
          levels = {}
          for var, order in self.order.items():
              if var in data:
                  if order is None:
                      order = categorical_order(data[var])
                  levels[var] = order

          grouper: str | list[str]
          groups: Index | MultiIndex
          if not levels:
              grouper = []
              groups = pd.Index([])
          elif len(levels) > 1:
              grouper = list(levels)
              groups = pd.MultiIndex.from_product(levels.values(), names=grouper)
          else:
              grouper, = list(levels)
              groups = pd.Index(levels[grouper], name=grouper)
          return grouper, groups

      def _reorder_columns(self, res, data):
          """Reorder result columns to match original order with new columns appended."""
          # Columns of the input that survive in the result, in input order ...
          cols = [c for c in data if c in res]
          # ... followed by columns that only exist in the result, in result order
          cols += [c for c in res if c not in data]
          return res.reindex(columns=pd.Index(cols))

      def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame:
          """
          Reduce each group to a single row in the output.

          The output will have a row for each unique combination of the grouping
          variable levels with null values for the aggregated variable(s) where
          those combinations do not appear in the dataset.

          ``*args`` and ``**kwargs`` are passed to pandas ``DataFrameGroupBy.agg``.

          Raises
          ------
          ValueError
              If none of the grouping variables are columns of ``data``.

          """
          grouper, groups = self._get_groups(data)

          if not grouper:
              # We will need to see whether there are valid usecases that end up here
              raise ValueError("No grouping variables are present in dataframe")

          res = (
              data
              .groupby(grouper, sort=False, observed=False)
              .agg(*args, **kwargs)
              # Expand to the full (ordered) product of levels; missing -> NaN
              .reindex(groups)
              .reset_index()
              .pipe(self._reorder_columns, data)
          )

          return res

      def apply(
          self, data: DataFrame, func: Callable[..., DataFrame],
          *args, **kwargs,
      ) -> DataFrame:
          """
          Apply a DataFrame -> DataFrame mapping to each group.

          ``func`` is called as ``func(part, *args, **kwargs)`` on each group's
          subset of ``data``. The outputs are stacked in the order of the group
          levels, with the grouping column(s) set from each group's key. Groups
          whose key is not among the ordered levels are dropped. If no grouping
          variable is present in ``data``, ``func`` is applied once to the whole
          frame.

          If no group produces output (e.g. ``data`` has no rows, or none of its
          keys appear in the requested level orders), an empty frame with the
          columns and dtypes of ``data`` is returned.

          """
          grouper, groups = self._get_groups(data)

          if not grouper:
              return self._reorder_columns(func(data, *args, **kwargs), data)

          parts = {}
          for key, part_df in data.groupby(grouper, sort=False, observed=False):
              parts[key] = func(part_df, *args, **kwargs)

          stack = []
          for key in groups:
              if key not in parts:
                  continue
              if isinstance(grouper, list):
                  # Implies that we had a MultiIndex so key is iterable
                  group_ids = dict(zip(grouper, cast(Iterable, key)))
              else:
                  group_ids = {grouper: key}
              stack.append(parts[key].assign(**group_ids))

          if not stack:
              # pd.concat([]) raises "No objects to concatenate"; return an
              # empty frame shaped like the input instead of crashing.
              return data.iloc[:0].reset_index(drop=True)

          res = pd.concat(stack, ignore_index=True)
          return self._reorder_columns(res, data)