aggregation.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import ClassVar, Callable
  4. import pandas as pd
  5. from pandas import DataFrame
  6. from seaborn._core.scales import Scale
  7. from seaborn._core.groupby import GroupBy
  8. from seaborn._stats.base import Stat
  9. from seaborn._statistics import EstimateAggregator
  10. from seaborn._core.typing import Vector
  11. @dataclass
  12. class Agg(Stat):
  13. """
  14. Aggregate data along the value axis using given method.
  15. Parameters
  16. ----------
  17. func : str or callable
  18. Name of a :class:`pandas.Series` method or a vector -> scalar function.
  19. See Also
  20. --------
  21. objects.Est : Aggregation with error bars.
  22. Examples
  23. --------
  24. .. include:: ../docstrings/objects.Agg.rst
  25. """
  26. func: str | Callable[[Vector], float] = "mean"
  27. group_by_orient: ClassVar[bool] = True
  28. def __call__(
  29. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  30. ) -> DataFrame:
  31. var = {"x": "y", "y": "x"}.get(orient)
  32. res = (
  33. groupby
  34. .agg(data, {var: self.func})
  35. .dropna(subset=[var])
  36. .reset_index(drop=True)
  37. )
  38. return res
  39. @dataclass
  40. class Est(Stat):
  41. """
  42. Calculate a point estimate and error bar interval.
  43. For additional information about the various `errorbar` choices, see
  44. the :doc:`errorbar tutorial </tutorial/error_bars>`.
  45. Parameters
  46. ----------
  47. func : str or callable
  48. Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
  49. errorbar : str, (str, float) tuple, or callable
  50. Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
  51. with a method name ane a level parameter, or a function that maps from a
  52. vector to a (min, max) interval.
  53. n_boot : int
  54. Number of bootstrap samples to draw for "ci" errorbars.
  55. seed : int
  56. Seed for the PRNG used to draw bootstrap samples.
  57. Examples
  58. --------
  59. .. include:: ../docstrings/objects.Est.rst
  60. """
  61. func: str | Callable[[Vector], float] = "mean"
  62. errorbar: str | tuple[str, float] = ("ci", 95)
  63. n_boot: int = 1000
  64. seed: int | None = None
  65. group_by_orient: ClassVar[bool] = True
  66. def _process(
  67. self, data: DataFrame, var: str, estimator: EstimateAggregator
  68. ) -> DataFrame:
  69. # Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
  70. # which we could probably make more general to allow Series return
  71. res = estimator(data, var)
  72. return pd.DataFrame([res])
  73. def __call__(
  74. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  75. ) -> DataFrame:
  76. boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
  77. engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)
  78. var = {"x": "y", "y": "x"}[orient]
  79. res = (
  80. groupby
  81. .apply(data, self._process, var, engine)
  82. .dropna(subset=[var])
  83. .reset_index(drop=True)
  84. )
  85. res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})
  86. return res
  87. @dataclass
  88. class Rolling(Stat):
  89. ...
  90. def __call__(self, data, groupby, orient, scales):
  91. ...