base.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. """Base module for statistical transformations."""
  2. from __future__ import annotations
  3. from collections.abc import Iterable
  4. from dataclasses import dataclass
  5. from typing import ClassVar, Any
  6. import warnings
  7. from typing import TYPE_CHECKING
  8. if TYPE_CHECKING:
  9. from pandas import DataFrame
  10. from seaborn._core.groupby import GroupBy
  11. from seaborn._core.scales import Scale
  12. @dataclass
  13. class Stat:
  14. """Base class for objects that apply statistical transformations."""
  15. # The class supports a partial-function application pattern. The object is
  16. # initialized with desired parameters and the result is a callable that
  17. # accepts and returns dataframes.
  18. # The statistical transformation logic should not add any state to the instance
  19. # beyond what is defined with the initialization parameters.
  20. # Subclasses can declare whether the orient dimension should be used in grouping
  21. # TODO consider whether this should be a parameter. Motivating example:
  22. # use the same KDE class violin plots and univariate density estimation.
  23. # In the former case, we would expect separate densities for each unique
  24. # value on the orient axis, but we would not in the latter case.
  25. group_by_orient: ClassVar[bool] = False
  26. def _check_param_one_of(self, param: str, options: Iterable[Any]) -> None:
  27. """Raise when parameter value is not one of a specified set."""
  28. value = getattr(self, param)
  29. if value not in options:
  30. *most, last = options
  31. option_str = ", ".join(f"{x!r}" for x in most[:-1]) + f" or {last!r}"
  32. err = " ".join([
  33. f"The `{param}` parameter for `{self.__class__.__name__}` must be",
  34. f"one of {option_str}; not {value!r}.",
  35. ])
  36. raise ValueError(err)
  37. def _check_grouping_vars(
  38. self, param: str, data_vars: list[str], stacklevel: int = 2,
  39. ) -> None:
  40. """Warn if vars are named in parameter without being present in the data."""
  41. param_vars = getattr(self, param)
  42. undefined = set(param_vars) - set(data_vars)
  43. if undefined:
  44. param = f"{self.__class__.__name__}.{param}"
  45. names = ", ".join(f"{x!r}" for x in undefined)
  46. msg = f"Undefined variable(s) passed for {param}: {names}."
  47. warnings.warn(msg, stacklevel=stacklevel)
  48. def __call__(
  49. self,
  50. data: DataFrame,
  51. groupby: GroupBy,
  52. orient: str,
  53. scales: dict[str, Scale],
  54. ) -> DataFrame:
  55. """Apply statistical transform to data subgroups and return combined result."""
  56. return data