# counting.py — Stat transforms that count and bin observations.
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import ClassVar
  4. import numpy as np
  5. import pandas as pd
  6. from pandas import DataFrame
  7. from seaborn._core.groupby import GroupBy
  8. from seaborn._core.scales import Scale
  9. from seaborn._stats.base import Stat
  10. from typing import TYPE_CHECKING
  11. if TYPE_CHECKING:
  12. from numpy.typing import ArrayLike
  13. @dataclass
  14. class Count(Stat):
  15. """
  16. Count distinct observations within groups.
  17. See Also
  18. --------
  19. Hist : A more fully-featured transform including binning and/or normalization.
  20. Examples
  21. --------
  22. .. include:: ../docstrings/objects.Count.rst
  23. """
  24. group_by_orient: ClassVar[bool] = True
  25. def __call__(
  26. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  27. ) -> DataFrame:
  28. var = {"x": "y", "y": "x"}[orient]
  29. res = (
  30. groupby
  31. .agg(data.assign(**{var: data[orient]}), {var: len})
  32. .dropna(subset=["x", "y"])
  33. .reset_index(drop=True)
  34. )
  35. return res
  36. @dataclass
  37. class Hist(Stat):
  38. """
  39. Bin observations, count them, and optionally normalize or cumulate.
  40. Parameters
  41. ----------
  42. stat : str
  43. Aggregate statistic to compute in each bin:
  44. - `count`: the number of observations
  45. - `density`: normalize so that the total area of the histogram equals 1
  46. - `percent`: normalize so that bar heights sum to 100
  47. - `probability` or `proportion`: normalize so that bar heights sum to 1
  48. - `frequency`: divide the number of observations by the bin width
  49. bins : str, int, or ArrayLike
  50. Generic parameter that can be the name of a reference rule, the number
  51. of bins, or the bin breaks. Passed to :func:`numpy.histogram_bin_edges`.
  52. binwidth : float
  53. Width of each bin; overrides `bins` but can be used with `binrange`.
  54. Note that if `binwidth` does not evenly divide the bin range, the actual
  55. bin width used will be only approximately equal to the parameter value.
  56. binrange : (min, max)
  57. Lowest and highest value for bin edges; can be used with either
  58. `bins` (when a number) or `binwidth`. Defaults to data extremes.
  59. common_norm : bool or list of variables
  60. When not `False`, the normalization is applied across groups. Use
  61. `True` to normalize across all groups, or pass variable name(s) that
  62. define normalization groups.
  63. common_bins : bool or list of variables
  64. When not `False`, the same bins are used for all groups. Use `True` to
  65. share bins across all groups, or pass variable name(s) to share within.
  66. cumulative : bool
  67. If True, cumulate the bin values.
  68. discrete : bool
  69. If True, set `binwidth` and `binrange` so that bins have unit width and
  70. are centered on integer values
  71. Notes
  72. -----
  73. The choice of bins for computing and plotting a histogram can exert
  74. substantial influence on the insights that one is able to draw from the
  75. visualization. If the bins are too large, they may erase important features.
  76. On the other hand, bins that are too small may be dominated by random
  77. variability, obscuring the shape of the true underlying distribution. The
  78. default bin size is determined using a reference rule that depends on the
  79. sample size and variance. This works well in many cases, (i.e., with
  80. "well-behaved" data) but it fails in others. It is always a good to try
  81. different bin sizes to be sure that you are not missing something important.
  82. This function allows you to specify bins in several different ways, such as
  83. by setting the total number of bins to use, the width of each bin, or the
  84. specific locations where the bins should break.
  85. Examples
  86. --------
  87. .. include:: ../docstrings/objects.Hist.rst
  88. """
  89. stat: str = "count"
  90. bins: str | int | ArrayLike = "auto"
  91. binwidth: float | None = None
  92. binrange: tuple[float, float] | None = None
  93. common_norm: bool | list[str] = True
  94. common_bins: bool | list[str] = True
  95. cumulative: bool = False
  96. discrete: bool = False
  97. def __post_init__(self):
  98. stat_options = [
  99. "count", "density", "percent", "probability", "proportion", "frequency"
  100. ]
  101. self._check_param_one_of("stat", stat_options)
  102. def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete):
  103. """Inner function that takes bin parameters as arguments."""
  104. vals = vals.replace(-np.inf, np.nan).replace(np.inf, np.nan).dropna()
  105. if binrange is None:
  106. start, stop = vals.min(), vals.max()
  107. else:
  108. start, stop = binrange
  109. if discrete:
  110. bin_edges = np.arange(start - .5, stop + 1.5)
  111. else:
  112. if binwidth is not None:
  113. bins = int(round((stop - start) / binwidth))
  114. bin_edges = np.histogram_bin_edges(vals, bins, binrange, weight)
  115. # TODO warning or cap on too many bins?
  116. return bin_edges
  117. def _define_bin_params(self, data, orient, scale_type):
  118. """Given data, return numpy.histogram parameters to define bins."""
  119. vals = data[orient]
  120. weights = data.get("weight", None)
  121. # TODO We'll want this for ordinal / discrete scales too
  122. # (Do we need discrete as a parameter or just infer from scale?)
  123. discrete = self.discrete or scale_type == "nominal"
  124. bin_edges = self._define_bin_edges(
  125. vals, weights, self.bins, self.binwidth, self.binrange, discrete,
  126. )
  127. if isinstance(self.bins, (str, int)):
  128. n_bins = len(bin_edges) - 1
  129. bin_range = bin_edges.min(), bin_edges.max()
  130. bin_kws = dict(bins=n_bins, range=bin_range)
  131. else:
  132. bin_kws = dict(bins=bin_edges)
  133. return bin_kws
  134. def _get_bins_and_eval(self, data, orient, groupby, scale_type):
  135. bin_kws = self._define_bin_params(data, orient, scale_type)
  136. return groupby.apply(data, self._eval, orient, bin_kws)
  137. def _eval(self, data, orient, bin_kws):
  138. vals = data[orient]
  139. weights = data.get("weight", None)
  140. density = self.stat == "density"
  141. hist, edges = np.histogram(vals, **bin_kws, weights=weights, density=density)
  142. width = np.diff(edges)
  143. center = edges[:-1] + width / 2
  144. return pd.DataFrame({orient: center, "count": hist, "space": width})
  145. def _normalize(self, data):
  146. hist = data["count"]
  147. if self.stat == "probability" or self.stat == "proportion":
  148. hist = hist.astype(float) / hist.sum()
  149. elif self.stat == "percent":
  150. hist = hist.astype(float) / hist.sum() * 100
  151. elif self.stat == "frequency":
  152. hist = hist.astype(float) / data["space"]
  153. if self.cumulative:
  154. if self.stat in ["density", "frequency"]:
  155. hist = (hist * data["space"]).cumsum()
  156. else:
  157. hist = hist.cumsum()
  158. return data.assign(**{self.stat: hist})
  159. def __call__(
  160. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  161. ) -> DataFrame:
  162. scale_type = scales[orient].__class__.__name__.lower()
  163. grouping_vars = [str(v) for v in data if v in groupby.order]
  164. if not grouping_vars or self.common_bins is True:
  165. bin_kws = self._define_bin_params(data, orient, scale_type)
  166. data = groupby.apply(data, self._eval, orient, bin_kws)
  167. else:
  168. if self.common_bins is False:
  169. bin_groupby = GroupBy(grouping_vars)
  170. else:
  171. bin_groupby = GroupBy(self.common_bins)
  172. self._check_grouping_vars("common_bins", grouping_vars)
  173. data = bin_groupby.apply(
  174. data, self._get_bins_and_eval, orient, groupby, scale_type,
  175. )
  176. if not grouping_vars or self.common_norm is True:
  177. data = self._normalize(data)
  178. else:
  179. if self.common_norm is False:
  180. norm_groupby = GroupBy(grouping_vars)
  181. else:
  182. norm_groupby = GroupBy(self.common_norm)
  183. self._check_grouping_vars("common_norm", grouping_vars)
  184. data = norm_groupby.apply(data, self._normalize)
  185. other = {"x": "y", "y": "x"}[orient]
  186. return data.assign(**{other: data[self.stat]})