density.py

from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable

import numpy as np
from numpy import ndarray
import pandas as pd
from pandas import DataFrame

try:
    from scipy.stats import gaussian_kde
    _no_scipy = False
except ImportError:
    from seaborn.external.kde import gaussian_kde
    _no_scipy = True

from seaborn._core.groupby import GroupBy
from seaborn._core.scales import Scale
from seaborn._stats.base import Stat

@dataclass
class KDE(Stat):
    """
    Compute a univariate kernel density estimate.

    Parameters
    ----------
    bw_adjust : float
        Factor that multiplicatively scales the value chosen using
        `bw_method`. Increasing will make the curve smoother. See Notes.
    bw_method : string, scalar, or callable
        Method for determining the smoothing bandwidth to use. Passed directly
        to :class:`scipy.stats.gaussian_kde`; see there for options.
    common_norm : bool or list of variables
        If `True`, normalize so that the areas of all curves sum to 1.
        If `False`, normalize each curve independently. If a list, defines
        variable(s) to group by and normalize within.
    common_grid : bool or list of variables
        If `True`, all curves will share the same evaluation grid.
        If `False`, each evaluation grid is independent. If a list, defines
        variable(s) to group by and share a grid within.
    gridsize : int or None
        Number of points in the evaluation grid. If None, the density is
        evaluated at the original datapoints.
    cut : float
        Factor, multiplied by the kernel bandwidth, that determines how far
        the evaluation grid extends past the extreme datapoints. When set to 0,
        the curve is truncated at the data limits.
    cumulative : bool
        If True, estimate a cumulative distribution function. Requires scipy.

    Notes
    -----
    The *bandwidth*, or standard deviation of the smoothing kernel, is an
    important parameter. Much like histogram bin width, using the wrong
    bandwidth can produce a distorted representation. Over-smoothing can erase
    true features, while under-smoothing can create false ones. The default
    uses a rule-of-thumb that works best for distributions that are roughly
    bell-shaped. It is a good idea to check the default by varying `bw_adjust`.

    Because the smoothing is performed with a Gaussian kernel, the estimated
    density curve can extend to values that may not make sense. For example, the
    curve may be drawn over negative values when smoothing data that are
    naturally positive. The `cut` parameter can be used to control the
    evaluation range, but datasets that have many observations close to a
    natural boundary may be better served by a different method.

    Similar distortions may arise when a dataset is naturally discrete or "spiky"
    (containing many repeated observations of the same value). A KDE will always
    produce a smooth curve, which could be misleading.

    The units on the density axis are a common source of confusion. While kernel
    density estimation produces a probability distribution, the height of the curve
    at each point gives a density, not a probability. A probability can be obtained
    only by integrating the density across a range. The curve is normalized so
    that the integral over all possible values is 1, meaning that the scale of
    the density axis depends on the data values.

    If scipy is installed, its cython-accelerated implementation will be used.

    Examples
    --------
    .. include:: ../docstrings/objects.KDE.rst

    """
    bw_adjust: float = 1
    bw_method: str | float | Callable[[gaussian_kde], float] = "scott"
    common_norm: bool | list[str] = True
    common_grid: bool | list[str] = True
    gridsize: int | None = 200
    cut: float = 3
    cumulative: bool = False

    def __post_init__(self):

        if self.cumulative and _no_scipy:
            raise RuntimeError("Cumulative KDE evaluation requires scipy")

    def _check_var_list_or_boolean(self, param: str, grouping_vars: Any) -> None:
        """Do input checks on grouping parameters."""
        value = getattr(self, param)
        if not (
            isinstance(value, bool)
            or (isinstance(value, list) and all(isinstance(v, str) for v in value))
        ):
            param_name = f"{self.__class__.__name__}.{param}"
            raise TypeError(f"{param_name} must be a boolean or list of strings.")
        self._check_grouping_vars(param, grouping_vars, stacklevel=3)

    def _fit(self, data: DataFrame, orient: str) -> gaussian_kde:
        """Fit and return a KDE object."""
        # TODO need to handle singular data

        fit_kws: dict[str, Any] = {"bw_method": self.bw_method}
        if "weight" in data:
            fit_kws["weights"] = data["weight"]

        kde = gaussian_kde(data[orient], **fit_kws)
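        # Rescale the rule-of-thumb bandwidth factor by `bw_adjust`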
        kde.set_bandwidth(kde.factor * self.bw_adjust)

        return kde

    def _get_support(self, data: DataFrame, orient: str) -> ndarray:
        """Define the grid that the KDE will be evaluated on."""
        if self.gridsize is None:
            return data[orient].to_numpy()
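
        # Use the fitted bandwidth to extend the grid `cut` bandwidths
        # beyond the observed data limits on each side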
        kde = self._fit(data, orient)
        bw = np.sqrt(kde.covariance.squeeze())
        gridmin = data[orient].min() - bw * self.cut
        gridmax = data[orient].max() + bw * self.cut
        return np.linspace(gridmin, gridmax, self.gridsize)

    def _fit_and_evaluate(
        self, data: DataFrame, orient: str, support: ndarray
    ) -> DataFrame:
        """Transform single group by fitting a KDE and evaluating on a support grid."""
        empty = pd.DataFrame(columns=[orient, "weight", "density"], dtype=float)
        if len(data) < 2:
            return empty
        try:
            kde = self._fit(data, orient)
        except np.linalg.LinAlgError:
            return empty
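
        # For a cumulative estimate, integrate the density from the start of
        # the grid up to each evaluation point; otherwise evaluate it directly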
        if self.cumulative:
            s_0 = support[0]
            density = np.array([kde.integrate_box_1d(s_0, s_i) for s_i in support])
        else:
            density = kde(support)
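
        # Keep the group's total weight so the curve can be normalized later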
        weight = data["weight"].sum()
        return pd.DataFrame({orient: support, "weight": weight, "density": density})

    def _transform(
        self, data: DataFrame, orient: str, grouping_vars: list[str]
    ) -> DataFrame:
        """Transform multiple groups by fitting KDEs and evaluating."""
        empty = pd.DataFrame(columns=[*data.columns, "density"], dtype=float)
        if len(data) < 2:
            return empty
        try:
            support = self._get_support(data, orient)
        except np.linalg.LinAlgError:
            return empty
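
        # Ignore grouping variables that take only a single value in this data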
        grouping_vars = [x for x in grouping_vars if data[x].nunique() > 1]
        if not grouping_vars:
            return self._fit_and_evaluate(data, orient, support)
        groupby = GroupBy(grouping_vars)
        return groupby.apply(data, self._fit_and_evaluate, orient, support)

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
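
        # Assume unit weights when no weight variable has been assigned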
  149. if "weight" not in data:
  150. data = data.assign(weight=1)
  151. data = data.dropna(subset=[orient, "weight"])

        # Transform each group separately
        grouping_vars = [str(v) for v in data if v in groupby.order]
        if not grouping_vars or self.common_grid is True:
            res = self._transform(data, orient, grouping_vars)
        else:
            if self.common_grid is False:
                grid_vars = grouping_vars
            else:
                self._check_var_list_or_boolean("common_grid", grouping_vars)
                grid_vars = [v for v in self.common_grid if v in grouping_vars]

            res = (
                GroupBy(grid_vars)
                .apply(data, self._transform, orient, grouping_vars)
            )

        # Normalize, potentially within groups
        if not grouping_vars or self.common_norm is True:
            res = res.assign(group_weight=data["weight"].sum())
        else:
            if self.common_norm is False:
                norm_vars = grouping_vars
            else:
                self._check_var_list_or_boolean("common_norm", grouping_vars)
                norm_vars = [v for v in self.common_norm if v in grouping_vars]
            res = res.join(
                data.groupby(norm_vars)["weight"].sum().rename("group_weight"),
                on=norm_vars,
            )
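
        # Scale each curve by its share of the total (group) weight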
  179. res["density"] *= res.eval("weight / group_weight")
        value = {"x": "y", "y": "x"}[orient]
        res[value] = res["density"]

        return res.drop(["weight", "group_weight"], axis=1)
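
# A minimal usage sketch, not part of this module: with the seaborn objects
# interface, this Stat is composed with a mark, and its parameters control the
# smoothing and normalization. Dataset and variable names below are assumed
# for illustration only.
#
#   import seaborn.objects as so
#   from seaborn import load_dataset
#
#   penguins = load_dataset("penguins")
#   (
#       so.Plot(penguins, x="flipper_length_mm", color="species")
#       .add(so.Line(), so.KDE(bw_adjust=0.5, common_norm=False))
#   )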