_misc.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618
  1. from __future__ import annotations
  2. from contextlib import contextmanager
  3. from typing import (
  4. TYPE_CHECKING,
  5. Generator,
  6. )
  7. from pandas.plotting._core import _get_plot_backend
  8. if TYPE_CHECKING:
  9. from matplotlib.axes import Axes
  10. from matplotlib.figure import Figure
  11. import numpy as np
  12. from pandas import (
  13. DataFrame,
  14. Series,
  15. )
  16. def table(ax, data, **kwargs):
  17. """
  18. Helper function to convert DataFrame and Series to matplotlib.table.
  19. Parameters
  20. ----------
  21. ax : Matplotlib axes object
  22. data : DataFrame or Series
  23. Data for table contents.
  24. **kwargs
  25. Keyword arguments to be passed to matplotlib.table.table.
  26. If `rowLabels` or `colLabels` is not specified, data index or column
  27. name will be used.
  28. Returns
  29. -------
  30. matplotlib table object
  31. """
  32. plot_backend = _get_plot_backend("matplotlib")
  33. return plot_backend.table(
  34. ax=ax, data=data, rowLabels=None, colLabels=None, **kwargs
  35. )
  36. def register() -> None:
  37. """
  38. Register pandas formatters and converters with matplotlib.
  39. This function modifies the global ``matplotlib.units.registry``
  40. dictionary. pandas adds custom converters for
  41. * pd.Timestamp
  42. * pd.Period
  43. * np.datetime64
  44. * datetime.datetime
  45. * datetime.date
  46. * datetime.time
  47. See Also
  48. --------
  49. deregister_matplotlib_converters : Remove pandas formatters and converters.
  50. """
  51. plot_backend = _get_plot_backend("matplotlib")
  52. plot_backend.register()
  53. def deregister() -> None:
  54. """
  55. Remove pandas formatters and converters.
  56. Removes the custom converters added by :func:`register`. This
  57. attempts to set the state of the registry back to the state before
  58. pandas registered its own units. Converters for pandas' own types like
  59. Timestamp and Period are removed completely. Converters for types
  60. pandas overwrites, like ``datetime.datetime``, are restored to their
  61. original value.
  62. See Also
  63. --------
  64. register_matplotlib_converters : Register pandas formatters and converters
  65. with matplotlib.
  66. """
  67. plot_backend = _get_plot_backend("matplotlib")
  68. plot_backend.deregister()
  69. def scatter_matrix(
  70. frame: DataFrame,
  71. alpha: float = 0.5,
  72. figsize: tuple[float, float] | None = None,
  73. ax: Axes | None = None,
  74. grid: bool = False,
  75. diagonal: str = "hist",
  76. marker: str = ".",
  77. density_kwds=None,
  78. hist_kwds=None,
  79. range_padding: float = 0.05,
  80. **kwargs,
  81. ) -> np.ndarray:
  82. """
  83. Draw a matrix of scatter plots.
  84. Parameters
  85. ----------
  86. frame : DataFrame
  87. alpha : float, optional
  88. Amount of transparency applied.
  89. figsize : (float,float), optional
  90. A tuple (width, height) in inches.
  91. ax : Matplotlib axis object, optional
  92. grid : bool, optional
  93. Setting this to True will show the grid.
  94. diagonal : {'hist', 'kde'}
  95. Pick between 'kde' and 'hist' for either Kernel Density Estimation or
  96. Histogram plot in the diagonal.
  97. marker : str, optional
  98. Matplotlib marker type, default '.'.
  99. density_kwds : keywords
  100. Keyword arguments to be passed to kernel density estimate plot.
  101. hist_kwds : keywords
  102. Keyword arguments to be passed to hist function.
  103. range_padding : float, default 0.05
  104. Relative extension of axis range in x and y with respect to
  105. (x_max - x_min) or (y_max - y_min).
  106. **kwargs
  107. Keyword arguments to be passed to scatter function.
  108. Returns
  109. -------
  110. numpy.ndarray
  111. A matrix of scatter plots.
  112. Examples
  113. --------
  114. .. plot::
  115. :context: close-figs
  116. >>> df = pd.DataFrame(np.random.randn(1000, 4), columns=['A','B','C','D'])
  117. >>> pd.plotting.scatter_matrix(df, alpha=0.2)
  118. array([[<AxesSubplot: xlabel='A', ylabel='A'>,
  119. <AxesSubplot: xlabel='B', ylabel='A'>,
  120. <AxesSubplot: xlabel='C', ylabel='A'>,
  121. <AxesSubplot: xlabel='D', ylabel='A'>],
  122. [<AxesSubplot: xlabel='A', ylabel='B'>,
  123. <AxesSubplot: xlabel='B', ylabel='B'>,
  124. <AxesSubplot: xlabel='C', ylabel='B'>,
  125. <AxesSubplot: xlabel='D', ylabel='B'>],
  126. [<AxesSubplot: xlabel='A', ylabel='C'>,
  127. <AxesSubplot: xlabel='B', ylabel='C'>,
  128. <AxesSubplot: xlabel='C', ylabel='C'>,
  129. <AxesSubplot: xlabel='D', ylabel='C'>],
  130. [<AxesSubplot: xlabel='A', ylabel='D'>,
  131. <AxesSubplot: xlabel='B', ylabel='D'>,
  132. <AxesSubplot: xlabel='C', ylabel='D'>,
  133. <AxesSubplot: xlabel='D', ylabel='D'>]], dtype=object)
  134. """
  135. plot_backend = _get_plot_backend("matplotlib")
  136. return plot_backend.scatter_matrix(
  137. frame=frame,
  138. alpha=alpha,
  139. figsize=figsize,
  140. ax=ax,
  141. grid=grid,
  142. diagonal=diagonal,
  143. marker=marker,
  144. density_kwds=density_kwds,
  145. hist_kwds=hist_kwds,
  146. range_padding=range_padding,
  147. **kwargs,
  148. )
  149. def radviz(
  150. frame: DataFrame,
  151. class_column: str,
  152. ax: Axes | None = None,
  153. color: list[str] | tuple[str, ...] | None = None,
  154. colormap=None,
  155. **kwds,
  156. ) -> Axes:
  157. """
  158. Plot a multidimensional dataset in 2D.
  159. Each Series in the DataFrame is represented as a evenly distributed
  160. slice on a circle. Each data point is rendered in the circle according to
  161. the value on each Series. Highly correlated `Series` in the `DataFrame`
  162. are placed closer on the unit circle.
  163. RadViz allow to project a N-dimensional data set into a 2D space where the
  164. influence of each dimension can be interpreted as a balance between the
  165. influence of all dimensions.
  166. More info available at the `original article
  167. <https://doi.org/10.1145/331770.331775>`_
  168. describing RadViz.
  169. Parameters
  170. ----------
  171. frame : `DataFrame`
  172. Object holding the data.
  173. class_column : str
  174. Column name containing the name of the data point category.
  175. ax : :class:`matplotlib.axes.Axes`, optional
  176. A plot instance to which to add the information.
  177. color : list[str] or tuple[str], optional
  178. Assign a color to each category. Example: ['blue', 'green'].
  179. colormap : str or :class:`matplotlib.colors.Colormap`, default None
  180. Colormap to select colors from. If string, load colormap with that
  181. name from matplotlib.
  182. **kwds
  183. Options to pass to matplotlib scatter plotting method.
  184. Returns
  185. -------
  186. :class:`matplotlib.axes.Axes`
  187. See Also
  188. --------
  189. pandas.plotting.andrews_curves : Plot clustering visualization.
  190. Examples
  191. --------
  192. .. plot::
  193. :context: close-figs
  194. >>> df = pd.DataFrame(
  195. ... {
  196. ... 'SepalLength': [6.5, 7.7, 5.1, 5.8, 7.6, 5.0, 5.4, 4.6, 6.7, 4.6],
  197. ... 'SepalWidth': [3.0, 3.8, 3.8, 2.7, 3.0, 2.3, 3.0, 3.2, 3.3, 3.6],
  198. ... 'PetalLength': [5.5, 6.7, 1.9, 5.1, 6.6, 3.3, 4.5, 1.4, 5.7, 1.0],
  199. ... 'PetalWidth': [1.8, 2.2, 0.4, 1.9, 2.1, 1.0, 1.5, 0.2, 2.1, 0.2],
  200. ... 'Category': [
  201. ... 'virginica',
  202. ... 'virginica',
  203. ... 'setosa',
  204. ... 'virginica',
  205. ... 'virginica',
  206. ... 'versicolor',
  207. ... 'versicolor',
  208. ... 'setosa',
  209. ... 'virginica',
  210. ... 'setosa'
  211. ... ]
  212. ... }
  213. ... )
  214. >>> pd.plotting.radviz(df, 'Category')
  215. <AxesSubplot: xlabel='y(t)', ylabel='y(t + 1)'>
  216. """
  217. plot_backend = _get_plot_backend("matplotlib")
  218. return plot_backend.radviz(
  219. frame=frame,
  220. class_column=class_column,
  221. ax=ax,
  222. color=color,
  223. colormap=colormap,
  224. **kwds,
  225. )
  226. def andrews_curves(
  227. frame: DataFrame,
  228. class_column: str,
  229. ax: Axes | None = None,
  230. samples: int = 200,
  231. color: list[str] | tuple[str, ...] | None = None,
  232. colormap=None,
  233. **kwargs,
  234. ) -> Axes:
  235. """
  236. Generate a matplotlib plot for visualising clusters of multivariate data.
  237. Andrews curves have the functional form:
  238. .. math::
  239. f(t) = \\frac{x_1}{\\sqrt{2}} + x_2 \\sin(t) + x_3 \\cos(t) +
  240. x_4 \\sin(2t) + x_5 \\cos(2t) + \\cdots
  241. Where :math:`x` coefficients correspond to the values of each dimension
  242. and :math:`t` is linearly spaced between :math:`-\\pi` and :math:`+\\pi`.
  243. Each row of frame then corresponds to a single curve.
  244. Parameters
  245. ----------
  246. frame : DataFrame
  247. Data to be plotted, preferably normalized to (0.0, 1.0).
  248. class_column : label
  249. Name of the column containing class names.
  250. ax : axes object, default None
  251. Axes to use.
  252. samples : int
  253. Number of points to plot in each curve.
  254. color : str, list[str] or tuple[str], optional
  255. Colors to use for the different classes. Colors can be strings
  256. or 3-element floating point RGB values.
  257. colormap : str or matplotlib colormap object, default None
  258. Colormap to select colors from. If a string, load colormap with that
  259. name from matplotlib.
  260. **kwargs
  261. Options to pass to matplotlib plotting method.
  262. Returns
  263. -------
  264. :class:`matplotlib.axes.Axes`
  265. Examples
  266. --------
  267. .. plot::
  268. :context: close-figs
  269. >>> df = pd.read_csv(
  270. ... 'https://raw.githubusercontent.com/pandas-dev/'
  271. ... 'pandas/main/pandas/tests/io/data/csv/iris.csv'
  272. ... )
  273. >>> pd.plotting.andrews_curves(df, 'Name')
  274. <AxesSubplot: title={'center': 'width'}>
  275. """
  276. plot_backend = _get_plot_backend("matplotlib")
  277. return plot_backend.andrews_curves(
  278. frame=frame,
  279. class_column=class_column,
  280. ax=ax,
  281. samples=samples,
  282. color=color,
  283. colormap=colormap,
  284. **kwargs,
  285. )
  286. def bootstrap_plot(
  287. series: Series,
  288. fig: Figure | None = None,
  289. size: int = 50,
  290. samples: int = 500,
  291. **kwds,
  292. ) -> Figure:
  293. """
  294. Bootstrap plot on mean, median and mid-range statistics.
  295. The bootstrap plot is used to estimate the uncertainty of a statistic
  296. by relying on random sampling with replacement [1]_. This function will
  297. generate bootstrapping plots for mean, median and mid-range statistics
  298. for the given number of samples of the given size.
  299. .. [1] "Bootstrapping (statistics)" in \
  300. https://en.wikipedia.org/wiki/Bootstrapping_%28statistics%29
  301. Parameters
  302. ----------
  303. series : pandas.Series
  304. Series from where to get the samplings for the bootstrapping.
  305. fig : matplotlib.figure.Figure, default None
  306. If given, it will use the `fig` reference for plotting instead of
  307. creating a new one with default parameters.
  308. size : int, default 50
  309. Number of data points to consider during each sampling. It must be
  310. less than or equal to the length of the `series`.
  311. samples : int, default 500
  312. Number of times the bootstrap procedure is performed.
  313. **kwds
  314. Options to pass to matplotlib plotting method.
  315. Returns
  316. -------
  317. matplotlib.figure.Figure
  318. Matplotlib figure.
  319. See Also
  320. --------
  321. pandas.DataFrame.plot : Basic plotting for DataFrame objects.
  322. pandas.Series.plot : Basic plotting for Series objects.
  323. Examples
  324. --------
  325. This example draws a basic bootstrap plot for a Series.
  326. .. plot::
  327. :context: close-figs
  328. >>> s = pd.Series(np.random.uniform(size=100))
  329. >>> pd.plotting.bootstrap_plot(s)
  330. <Figure size 640x480 with 6 Axes>
  331. """
  332. plot_backend = _get_plot_backend("matplotlib")
  333. return plot_backend.bootstrap_plot(
  334. series=series, fig=fig, size=size, samples=samples, **kwds
  335. )
  336. def parallel_coordinates(
  337. frame: DataFrame,
  338. class_column: str,
  339. cols: list[str] | None = None,
  340. ax: Axes | None = None,
  341. color: list[str] | tuple[str, ...] | None = None,
  342. use_columns: bool = False,
  343. xticks: list | tuple | None = None,
  344. colormap=None,
  345. axvlines: bool = True,
  346. axvlines_kwds=None,
  347. sort_labels: bool = False,
  348. **kwargs,
  349. ) -> Axes:
  350. """
  351. Parallel coordinates plotting.
  352. Parameters
  353. ----------
  354. frame : DataFrame
  355. class_column : str
  356. Column name containing class names.
  357. cols : list, optional
  358. A list of column names to use.
  359. ax : matplotlib.axis, optional
  360. Matplotlib axis object.
  361. color : list or tuple, optional
  362. Colors to use for the different classes.
  363. use_columns : bool, optional
  364. If true, columns will be used as xticks.
  365. xticks : list or tuple, optional
  366. A list of values to use for xticks.
  367. colormap : str or matplotlib colormap, default None
  368. Colormap to use for line colors.
  369. axvlines : bool, optional
  370. If true, vertical lines will be added at each xtick.
  371. axvlines_kwds : keywords, optional
  372. Options to be passed to axvline method for vertical lines.
  373. sort_labels : bool, default False
  374. Sort class_column labels, useful when assigning colors.
  375. **kwargs
  376. Options to pass to matplotlib plotting method.
  377. Returns
  378. -------
  379. matplotlib.axes.Axes
  380. Examples
  381. --------
  382. .. plot::
  383. :context: close-figs
  384. >>> df = pd.read_csv(
  385. ... 'https://raw.githubusercontent.com/pandas-dev/'
  386. ... 'pandas/main/pandas/tests/io/data/csv/iris.csv'
  387. ... )
  388. >>> pd.plotting.parallel_coordinates(
  389. ... df, 'Name', color=('#556270', '#4ECDC4', '#C7F464')
  390. ... )
  391. <AxesSubplot: xlabel='y(t)', ylabel='y(t + 1)'>
  392. """
  393. plot_backend = _get_plot_backend("matplotlib")
  394. return plot_backend.parallel_coordinates(
  395. frame=frame,
  396. class_column=class_column,
  397. cols=cols,
  398. ax=ax,
  399. color=color,
  400. use_columns=use_columns,
  401. xticks=xticks,
  402. colormap=colormap,
  403. axvlines=axvlines,
  404. axvlines_kwds=axvlines_kwds,
  405. sort_labels=sort_labels,
  406. **kwargs,
  407. )
  408. def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes:
  409. """
  410. Lag plot for time series.
  411. Parameters
  412. ----------
  413. series : Series
  414. The time series to visualize.
  415. lag : int, default 1
  416. Lag length of the scatter plot.
  417. ax : Matplotlib axis object, optional
  418. The matplotlib axis object to use.
  419. **kwds
  420. Matplotlib scatter method keyword arguments.
  421. Returns
  422. -------
  423. matplotlib.axes.Axes
  424. Examples
  425. --------
  426. Lag plots are most commonly used to look for patterns in time series data.
  427. Given the following time series
  428. .. plot::
  429. :context: close-figs
  430. >>> np.random.seed(5)
  431. >>> x = np.cumsum(np.random.normal(loc=1, scale=5, size=50))
  432. >>> s = pd.Series(x)
  433. >>> s.plot()
  434. <AxesSubplot: xlabel='Midrange'>
  435. A lag plot with ``lag=1`` returns
  436. .. plot::
  437. :context: close-figs
  438. >>> pd.plotting.lag_plot(s, lag=1)
  439. <AxesSubplot: xlabel='y(t)', ylabel='y(t + 1)'>
  440. """
  441. plot_backend = _get_plot_backend("matplotlib")
  442. return plot_backend.lag_plot(series=series, lag=lag, ax=ax, **kwds)
  443. def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwargs) -> Axes:
  444. """
  445. Autocorrelation plot for time series.
  446. Parameters
  447. ----------
  448. series : Series
  449. The time series to visualize.
  450. ax : Matplotlib axis object, optional
  451. The matplotlib axis object to use.
  452. **kwargs
  453. Options to pass to matplotlib plotting method.
  454. Returns
  455. -------
  456. matplotlib.axes.Axes
  457. Examples
  458. --------
  459. The horizontal lines in the plot correspond to 95% and 99% confidence bands.
  460. The dashed line is 99% confidence band.
  461. .. plot::
  462. :context: close-figs
  463. >>> spacing = np.linspace(-9 * np.pi, 9 * np.pi, num=1000)
  464. >>> s = pd.Series(0.7 * np.random.rand(1000) + 0.3 * np.sin(spacing))
  465. >>> pd.plotting.autocorrelation_plot(s)
  466. <AxesSubplot: title={'center': 'width'}, xlabel='Lag', ylabel='Autocorrelation'>
  467. """
  468. plot_backend = _get_plot_backend("matplotlib")
  469. return plot_backend.autocorrelation_plot(series=series, ax=ax, **kwargs)
  470. class _Options(dict):
  471. """
  472. Stores pandas plotting options.
  473. Allows for parameter aliasing so you can just use parameter names that are
  474. the same as the plot function parameters, but is stored in a canonical
  475. format that makes it easy to breakdown into groups later.
  476. """
  477. # alias so the names are same as plotting method parameter names
  478. _ALIASES = {"x_compat": "xaxis.compat"}
  479. _DEFAULT_KEYS = ["xaxis.compat"]
  480. def __init__(self, deprecated: bool = False) -> None:
  481. self._deprecated = deprecated
  482. super().__setitem__("xaxis.compat", False)
  483. def __getitem__(self, key):
  484. key = self._get_canonical_key(key)
  485. if key not in self:
  486. raise ValueError(f"{key} is not a valid pandas plotting option")
  487. return super().__getitem__(key)
  488. def __setitem__(self, key, value) -> None:
  489. key = self._get_canonical_key(key)
  490. super().__setitem__(key, value)
  491. def __delitem__(self, key) -> None:
  492. key = self._get_canonical_key(key)
  493. if key in self._DEFAULT_KEYS:
  494. raise ValueError(f"Cannot remove default parameter {key}")
  495. super().__delitem__(key)
  496. def __contains__(self, key) -> bool:
  497. key = self._get_canonical_key(key)
  498. return super().__contains__(key)
  499. def reset(self) -> None:
  500. """
  501. Reset the option store to its initial state
  502. Returns
  503. -------
  504. None
  505. """
  506. # error: Cannot access "__init__" directly
  507. self.__init__() # type: ignore[misc]
  508. def _get_canonical_key(self, key):
  509. return self._ALIASES.get(key, key)
  510. @contextmanager
  511. def use(self, key, value) -> Generator[_Options, None, None]:
  512. """
  513. Temporarily set a parameter value using the with statement.
  514. Aliasing allowed.
  515. """
  516. old_value = self[key]
  517. try:
  518. self[key] = value
  519. yield self
  520. finally:
  521. self[key] = old_value
  522. plot_params = _Options()