  1. """
  2. This module was copied from the scipy project.
  3. In the process of copying, some methods were removed because they depended on
  4. other parts of scipy (especially on compiled components), allowing seaborn to
  5. have a simple and pure Python implementation. These include:
  6. - integrate_gaussian
  7. - integrate_box
  8. - integrate_box_1d
  9. - integrate_kde
  10. - logpdf
  11. - resample
  12. Additionally, the numpy.linalg module was substituted for scipy.linalg,
  13. and the examples section (with doctests) was removed from the docstring
  14. The original scipy license is copied below:
  15. Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
  16. All rights reserved.
  17. Redistribution and use in source and binary forms, with or without
  18. modification, are permitted provided that the following conditions
  19. are met:
  20. 1. Redistributions of source code must retain the above copyright
  21. notice, this list of conditions and the following disclaimer.
  22. 2. Redistributions in binary form must reproduce the above
  23. copyright notice, this list of conditions and the following
  24. disclaimer in the documentation and/or other materials provided
  25. with the distribution.
  26. 3. Neither the name of the copyright holder nor the names of its
  27. contributors may be used to endorse or promote products derived
  28. from this software without specific prior written permission.
  29. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  30. "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  31. LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  32. A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  33. OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  34. SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  35. LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  36. DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  37. THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  38. (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  39. OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  40. """
# -------------------------------------------------------------------------------
#
#  Define classes for (uni/multi)-variate kernel density estimation.
#
#  Currently, only Gaussian kernels are implemented.
#
#  Written by: Robert Kern
#
#  Date: 2004-08-09
#
#  Modified: 2005-02-10 by Robert Kern.
#              Contributed to SciPy
#            2005-10-07 by Robert Kern.
#              Some fixes to match the new scipy_core
#
#  Copyright 2004-2005 by Enthought, Inc.
#
# -------------------------------------------------------------------------------

import numpy as np
from numpy import (asarray, atleast_2d, reshape, zeros, newaxis, dot, exp, pi,
                   sqrt, power, atleast_1d, sum, ones, cov)
from numpy import linalg


__all__ = ['gaussian_kde']

class gaussian_kde:
    """Representation of a kernel-density estimate using Gaussian kernels.

    Kernel density estimation is a way to estimate the probability density
    function (PDF) of a random variable in a non-parametric way.
    `gaussian_kde` works for both uni-variate and multi-variate data. It
    includes automatic bandwidth determination. The estimation works best for
    a unimodal distribution; bimodal or multi-modal distributions tend to be
    oversmoothed.

    Parameters
    ----------
    dataset : array_like
        Datapoints to estimate from. In case of univariate data this is a 1-D
        array, otherwise a 2-D array with shape (# of dims, # of data).
    bw_method : str, scalar or callable, optional
        The method used to calculate the estimator bandwidth. This can be
        'scott', 'silverman', a scalar constant or a callable. If a scalar,
        this will be used directly as `kde.factor`. If a callable, it should
        take a `gaussian_kde` instance as only parameter and return a scalar.
        If None (default), 'scott' is used. See Notes for more details.
    weights : array_like, optional
        Weights of datapoints. This must be the same shape as dataset.
        If None (default), the samples are assumed to be equally weighted.

    Attributes
    ----------
    dataset : ndarray
        The dataset with which `gaussian_kde` was initialized.
    d : int
        Number of dimensions.
    n : int
        Number of datapoints.
    neff : int
        Effective number of datapoints.

        .. versionadded:: 1.2.0
    factor : float
        The bandwidth factor, obtained from `kde.covariance_factor`, with which
        the covariance matrix is multiplied.
    covariance : ndarray
        The covariance matrix of `dataset`, scaled by the calculated bandwidth
        (`kde.factor`).
    inv_cov : ndarray
        The inverse of `covariance`.

    Methods
    -------
    evaluate
    __call__
    pdf
    set_bandwidth
    covariance_factor

    Notes
    -----
    Bandwidth selection strongly influences the estimate obtained from the KDE
    (much more so than the actual shape of the kernel). Bandwidth selection
    can be done by a "rule of thumb", by cross-validation, by "plug-in
    methods" or by other means; see [3]_, [4]_ for reviews. `gaussian_kde`
    uses a rule of thumb, the default is Scott's Rule.

    Scott's Rule [1]_, implemented as `scotts_factor`, is::

        n**(-1./(d+4)),

    with ``n`` the number of data points and ``d`` the number of dimensions.
    In the case of unequally weighted points, `scotts_factor` becomes::

        neff**(-1./(d+4)),

    with ``neff`` the effective number of datapoints.

    Silverman's Rule [2]_, implemented as `silverman_factor`, is::

        (n * (d + 2) / 4.)**(-1. / (d + 4)),

    or in the case of unequally weighted points::

        (neff * (d + 2) / 4.)**(-1. / (d + 4)).

    Good general descriptions of kernel density estimation can be found in [1]_
    and [2]_; the mathematics for this multi-dimensional implementation can be
    found in [1]_.

    With a set of weighted samples, the effective number of datapoints ``neff``
    is defined by::

        neff = sum(weights)^2 / sum(weights^2)

    as detailed in [5]_.

    References
    ----------
    .. [1] D.W. Scott, "Multivariate Density Estimation: Theory, Practice, and
           Visualization", John Wiley & Sons, New York, Chicester, 1992.
    .. [2] B.W. Silverman, "Density Estimation for Statistics and Data
           Analysis", Vol. 26, Monographs on Statistics and Applied Probability,
           Chapman and Hall, London, 1986.
    .. [3] B.A. Turlach, "Bandwidth Selection in Kernel Density Estimation: A
           Review", CORE and Institut de Statistique, Vol. 19, pp. 1-33, 1993.
    .. [4] D.M. Bashtannyk and R.J. Hyndman, "Bandwidth selection for kernel
           conditional density estimation", Computational Statistics & Data
           Analysis, Vol. 36, pp. 279-298, 2001.
    .. [5] Gray P. G., 1969, Journal of the Royal Statistical Society.
           Series A (General), 132, 272

    """

    def __init__(self, dataset, bw_method=None, weights=None):
        self.dataset = atleast_2d(asarray(dataset))
        if not self.dataset.size > 1:
            raise ValueError("`dataset` input should have multiple elements.")

        self.d, self.n = self.dataset.shape

        if weights is not None:
            self._weights = atleast_1d(weights).astype(float)
            self._weights /= sum(self._weights)
            if self.weights.ndim != 1:
                raise ValueError("`weights` input should be one-dimensional.")
            if len(self._weights) != self.n:
                raise ValueError("`weights` input should be of length n")
            self._neff = 1/sum(self._weights**2)

        self.set_bandwidth(bw_method=bw_method)

    def evaluate(self, points):
        """Evaluate the estimated pdf on a set of points.

        Parameters
        ----------
        points : (# of dimensions, # of points)-array
            Alternatively, a (# of dimensions,) vector can be passed in and
            treated as a single point.

        Returns
        -------
        values : (# of points,)-array
            The values at each point.

        Raises
        ------
        ValueError : if the dimensionality of the input points is different than
                     the dimensionality of the KDE.

        """
        points = atleast_2d(asarray(points))

        d, m = points.shape
        if d != self.d:
            if d == 1 and m == self.d:
                # points was passed in as a row vector
                points = reshape(points, (self.d, 1))
                m = 1
            else:
                msg = f"points have dimension {d}, dataset has dimension {self.d}"
                raise ValueError(msg)

        output_dtype = np.common_type(self.covariance, points)
        result = zeros((m,), dtype=output_dtype)

        whitening = linalg.cholesky(self.inv_cov)
        scaled_dataset = dot(whitening, self.dataset)
        scaled_points = dot(whitening, points)

        if m >= self.n:
            # there are more points than data, so loop over data
            for i in range(self.n):
                diff = scaled_dataset[:, i, newaxis] - scaled_points
                energy = sum(diff * diff, axis=0) / 2.0
                result += self.weights[i]*exp(-energy)
        else:
            # loop over points
            for i in range(m):
                diff = scaled_dataset - scaled_points[:, i, newaxis]
                energy = sum(diff * diff, axis=0) / 2.0
                result[i] = sum(exp(-energy)*self.weights, axis=0)

        result = result / self._norm_factor

        return result

    __call__ = evaluate

    def scotts_factor(self):
        """Compute Scott's factor.

        Returns
        -------
        s : float
            Scott's factor.
        """
        return power(self.neff, -1./(self.d+4))

    def silverman_factor(self):
        """Compute the Silverman factor.

        Returns
        -------
        s : float
            The silverman factor.
        """
        return power(self.neff*(self.d+2.0)/4.0, -1./(self.d+4))

    #  Default method to calculate bandwidth, can be overwritten by subclass
    covariance_factor = scotts_factor
    covariance_factor.__doc__ = """Computes the coefficient (`kde.factor`) that
        multiplies the data covariance matrix to obtain the kernel covariance
        matrix. The default is `scotts_factor`. A subclass can overwrite this
        method to provide a different method, or set it through a call to
        `kde.set_bandwidth`."""

    def set_bandwidth(self, bw_method=None):
        """Compute the estimator bandwidth with given method.

        The new bandwidth calculated after a call to `set_bandwidth` is used
        for subsequent evaluations of the estimated density.

        Parameters
        ----------
        bw_method : str, scalar or callable, optional
            The method used to calculate the estimator bandwidth. This can be
            'scott', 'silverman', a scalar constant or a callable. If a
            scalar, this will be used directly as `kde.factor`. If a callable,
            it should take a `gaussian_kde` instance as only parameter and
            return a scalar. If None (default), nothing happens; the current
            `kde.covariance_factor` method is kept.

        Notes
        -----
        .. versionadded:: 0.11

        """
        if bw_method is None:
            pass
        elif bw_method == 'scott':
            self.covariance_factor = self.scotts_factor
        elif bw_method == 'silverman':
            self.covariance_factor = self.silverman_factor
        elif np.isscalar(bw_method) and not isinstance(bw_method, str):
            self._bw_method = 'use constant'
            self.covariance_factor = lambda: bw_method
        elif callable(bw_method):
            self._bw_method = bw_method
            self.covariance_factor = lambda: self._bw_method(self)
        else:
            msg = "`bw_method` should be 'scott', 'silverman', a scalar " \
                  "or a callable."
            raise ValueError(msg)

        self._compute_covariance()
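    # Illustrative calls, not part of the original source, showing the three
    # accepted `bw_method` forms handled above (each one ends by refreshing
    # the kernel covariance via `_compute_covariance`):
    #
    #     kde.set_bandwidth(bw_method='silverman')               # named rule
    #     kde.set_bandwidth(bw_method=0.3)                       # scalar used directly as kde.factor
    #     kde.set_bandwidth(bw_method=lambda k: k.neff ** -0.5)  # callable(kde) -> scalar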

    def _compute_covariance(self):
        """Computes the covariance matrix for each Gaussian kernel using
        covariance_factor().
        """
        self.factor = self.covariance_factor()
        # Cache covariance and inverse covariance of the data
        if not hasattr(self, '_data_inv_cov'):
            self._data_covariance = atleast_2d(cov(self.dataset, rowvar=1,
                                                   bias=False,
                                                   aweights=self.weights))
            self._data_inv_cov = linalg.inv(self._data_covariance)

        self.covariance = self._data_covariance * self.factor**2
        self.inv_cov = self._data_inv_cov / self.factor**2
        self._norm_factor = sqrt(linalg.det(2*pi*self.covariance))

    def pdf(self, x):
        """
        Evaluate the estimated pdf on a provided set of points.

        Notes
        -----
        This is an alias for `gaussian_kde.evaluate`. See the ``evaluate``
        docstring for more details.
        """
        return self.evaluate(x)

    @property
    def weights(self):
        try:
            return self._weights
        except AttributeError:
            self._weights = ones(self.n)/self.n
            return self._weights

    @property
    def neff(self):
        try:
            return self._neff
        except AttributeError:
            self._neff = 1/sum(self.weights**2)
            return self._neff
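

# -----------------------------------------------------------------------------
# Minimal usage sketch, not part of the original scipy/seaborn module. It only
# exercises the class defined above and assumes nothing beyond numpy; the
# variable names (`samples`, `grid`, `data2d`) are illustrative.
if __name__ == "__main__":
    rng = np.random.default_rng(0)

    # Univariate data: a 1-D array is promoted to shape (1, n) internally.
    samples = rng.normal(loc=0.0, scale=1.0, size=200)
    kde = gaussian_kde(samples, bw_method="scott")
    grid = np.linspace(-4, 4, 9)
    print("bandwidth factor:", kde.factor)
    print("density on grid:", kde(grid))   # __call__ is an alias for evaluate()

    # Weighted bivariate data: dataset shape is (# of dims, # of points).
    data2d = rng.normal(size=(2, 500))
    weights = rng.uniform(size=500)
    kde2d = gaussian_kde(data2d, weights=weights)
    print("effective sample size:", kde2d.neff)
    print("pdf at the origin:", kde2d.pdf(np.zeros((2, 1))))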