hist.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Literal,
  5. )
  6. import numpy as np
  7. from pandas._typing import PlottingOrientation
  8. from pandas.core.dtypes.common import (
  9. is_integer,
  10. is_list_like,
  11. )
  12. from pandas.core.dtypes.generic import (
  13. ABCDataFrame,
  14. ABCIndex,
  15. )
  16. from pandas.core.dtypes.missing import (
  17. isna,
  18. remove_na_arraylike,
  19. )
  20. from pandas.io.formats.printing import pprint_thing
  21. from pandas.plotting._matplotlib.core import (
  22. LinePlot,
  23. MPLPlot,
  24. )
  25. from pandas.plotting._matplotlib.groupby import (
  26. create_iter_data_given_by,
  27. reformat_hist_y_given_by,
  28. )
  29. from pandas.plotting._matplotlib.misc import unpack_single_str_list
  30. from pandas.plotting._matplotlib.tools import (
  31. create_subplots,
  32. flatten_axes,
  33. maybe_adjust_figure,
  34. set_ticks_props,
  35. )
  36. if TYPE_CHECKING:
  37. from matplotlib.axes import Axes
  38. from pandas import DataFrame
class HistPlot(LinePlot):
    """
    Histogram plot for the matplotlib plotting backend.

    An integer ``bins`` is converted into explicit bin edges before drawing:
    one shared set of edges for the whole frame, or one set per group when
    ``by`` is given. Each (label, data) pair is then drawn with ``Axes.hist``,
    with bars stacked through the base-class stacker helpers.
    """

    @property
    def _kind(self) -> Literal["hist", "kde"]:
        return "hist"

    def __init__(
        self,
        data,
        bins: int | np.ndarray | list[np.ndarray] = 10,
        bottom: int | np.ndarray = 0,
        **kwargs,
    ) -> None:
        # An integer bin count is replaced by explicit bin edges in _args_adjust.
        self.bins = bins  # bin count (default 10) or explicit bin edges
        self.bottom = bottom  # bar baseline(s); forwarded to Axes.hist as ``bottom``
        self.xlabel = kwargs.get("xlabel")
        self.ylabel = kwargs.get("ylabel")
        # Do not call LinePlot.__init__ which may fill nan
        MPLPlot.__init__(self, data, **kwargs)  # pylint: disable=non-parent-init-called

    def _args_adjust(self) -> None:
        """
        Turn an integer ``bins`` into explicit bin edges and make a list-like
        ``bottom`` an ndarray.

        NOTE(review): presumably invoked by the MPLPlot machinery before
        ``_make_plot``; the call site is not visible in this file section.
        """
        # calculate bin number separately in different subplots
        # where subplots are created based on by argument
        if is_integer(self.bins):
            if self.by is not None:
                by_modified = unpack_single_str_list(self.by)
                # One array of edges per group; _make_plot picks kwds["bins"][i]
                # by iteration position.
                grouped = self.data.groupby(by_modified)[self.columns]
                self.bins = [self._calculate_bins(group) for key, group in grouped]
            else:
                # A single set of edges shared by every column.
                self.bins = self._calculate_bins(self.data)

        if is_list_like(self.bottom):
            self.bottom = np.array(self.bottom)

    def _calculate_bins(self, data: DataFrame) -> np.ndarray:
        """
        Calculate bin edges for ``data``.

        The data returned by ``_get_numeric_data`` (after ``infer_objects``) is
        flattened into one array and its missing values are dropped. The edges
        are then computed with ``np.histogram`` using ``self.bins`` and any
        user-supplied ``range`` keyword.

        Returns
        -------
        np.ndarray
            The bin edges. The histogram counts are discarded.
        """
        nd_values = data.infer_objects(copy=False)._get_numeric_data()
        values = np.ravel(nd_values)
        values = values[~isna(values)]

        hist, bins = np.histogram(
            values, bins=self.bins, range=self.kwds.get("range", None)
        )
        return bins

    # error: Signature of "_plot" incompatible with supertype "LinePlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax,
        y,
        style=None,
        bottom: int | np.ndarray = 0,
        column_num: int = 0,
        stacking_id=None,
        *,
        bins,
        **kwds,
    ):
        """
        Draw one histogram call on ``ax`` and return the created patches.

        ``style`` is accepted for signature compatibility but ignored.
        ``kwds`` must contain ``"label"``, which is used for the stacked-value
        lookup. The remaining ``kwds`` are forwarded to ``Axes.hist``.
        """
        if column_num == 0:
            # First column: (re)initialise the stacker with one slot per bin.
            cls._initialize_stacker(ax, stacking_id, len(bins) - 1)

        base = np.zeros(len(bins) - 1)
        # Presumably adds the heights already stacked for this stacking_id
        # (zero when stacking is off); helper is defined in the base class.
        bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
        # ignore style
        n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
        # Record this call's counts so later columns can stack on top of them.
        cls._update_stacker(ax, stacking_id, n)
        return patches

    def _make_plot(self) -> None:
        """
        Call ``_plot`` once per (label, y) pair yielded by ``_iter_data``.

        When ``by`` is set, the data is regrouped first. Each iteration then uses
        its own entry of ``bins``, and each subplot is titled with its label.
        """
        colors = self._get_colors()
        stacking_id = self._get_stacking_id()

        # Re-create iterated data if `by` is assigned by users
        data = (
            create_iter_data_given_by(self.data, self._kind)
            if self.by is not None
            else self.data
        )

        for i, (label, y) in enumerate(self._iter_data(data=data)):
            ax = self._get_ax(i)

            kwds = self.kwds.copy()

            label = pprint_thing(label)
            label = self._mark_right_label(label, index=i)
            kwds["label"] = label

            style, kwds = self._apply_style_colors(colors, kwds, i, label)
            if style is not None:
                kwds["style"] = style

            kwds = self._make_plot_keywords(kwds, y)

            # the bins is multi-dimension array now and each plot need only 1-d and
            # when by is applied, label should be columns that are grouped
            if self.by is not None:
                kwds["bins"] = kwds["bins"][i]
                kwds["label"] = self.columns
                kwds.pop("color")

            # We allow weights to be a multi-dimensional array, e.g. a (10, 2) array,
            # and each sub-array (10,) will be called in each iteration. If users only
            # provide 1D array, we assume the same weights is used for all iterations
            weights = kwds.get("weights", None)
            if weights is not None:
                if np.ndim(weights) != 1 and np.shape(weights)[-1] != 1:
                    try:
                        weights = weights[:, i]
                    except IndexError as err:
                        raise ValueError(
                            "weights must have the same shape as data, "
                            "or be a single column"
                        ) from err
                # Keep only the weights whose matching y value is not missing.
                weights = weights[~isna(y)]
                kwds["weights"] = weights

            y = reformat_hist_y_given_by(y, self.by)

            artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)

            # when by is applied, show title for subplots to know which group it is
            if self.by is not None:
                ax.set_title(pprint_thing(label))

            self._append_legend_handles_labels(artists[0], label)

    def _make_plot_keywords(self, kwds, y):
        """Add this histogram's ``bottom`` and ``bins`` to ``kwds`` and return it."""
        # y is unused here; it is part of the signature because KdePlot's
        # override needs the data to choose its evaluation points.
        kwds["bottom"] = self.bottom
        kwds["bins"] = self.bins
        return kwds

    def _post_plot_logic(self, ax: Axes, data) -> None:
        """Label the count axis "Frequency" unless the user gave a label for it."""
        if self.orientation == "horizontal":
            ax.set_xlabel("Frequency" if self.xlabel is None else self.xlabel)
            ax.set_ylabel(self.ylabel)
        else:
            ax.set_xlabel(self.xlabel)
            ax.set_ylabel("Frequency" if self.ylabel is None else self.ylabel)

    @property
    def orientation(self) -> PlottingOrientation:
        """"horizontal" only if the user passed orientation="horizontal"; otherwise "vertical"."""
        if self.kwds.get("orientation", None) == "horizontal":
            return "horizontal"
        else:
            return "vertical"
  164. class KdePlot(HistPlot):
  165. @property
  166. def _kind(self) -> Literal["kde"]:
  167. return "kde"
  168. @property
  169. def orientation(self) -> Literal["vertical"]:
  170. return "vertical"
  171. def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
  172. # Do not call LinePlot.__init__ which may fill nan
  173. MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
  174. self.bw_method = bw_method
  175. self.ind = ind
  176. def _args_adjust(self) -> None:
  177. pass
  178. def _get_ind(self, y):
  179. if self.ind is None:
  180. # np.nanmax() and np.nanmin() ignores the missing values
  181. sample_range = np.nanmax(y) - np.nanmin(y)
  182. ind = np.linspace(
  183. np.nanmin(y) - 0.5 * sample_range,
  184. np.nanmax(y) + 0.5 * sample_range,
  185. 1000,
  186. )
  187. elif is_integer(self.ind):
  188. sample_range = np.nanmax(y) - np.nanmin(y)
  189. ind = np.linspace(
  190. np.nanmin(y) - 0.5 * sample_range,
  191. np.nanmax(y) + 0.5 * sample_range,
  192. self.ind,
  193. )
  194. else:
  195. ind = self.ind
  196. return ind
  197. @classmethod
  198. def _plot(
  199. cls,
  200. ax,
  201. y,
  202. style=None,
  203. bw_method=None,
  204. ind=None,
  205. column_num=None,
  206. stacking_id=None,
  207. **kwds,
  208. ):
  209. from scipy.stats import gaussian_kde
  210. y = remove_na_arraylike(y)
  211. gkde = gaussian_kde(y, bw_method=bw_method)
  212. y = gkde.evaluate(ind)
  213. lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
  214. return lines
  215. def _make_plot_keywords(self, kwds, y):
  216. kwds["bw_method"] = self.bw_method
  217. kwds["ind"] = self._get_ind(y)
  218. return kwds
  219. def _post_plot_logic(self, ax, data) -> None:
  220. ax.set_ylabel("Density")
  221. def _grouped_plot(
  222. plotf,
  223. data,
  224. column=None,
  225. by=None,
  226. numeric_only: bool = True,
  227. figsize=None,
  228. sharex: bool = True,
  229. sharey: bool = True,
  230. layout=None,
  231. rot: float = 0,
  232. ax=None,
  233. **kwargs,
  234. ):
  235. if figsize == "default":
  236. # allowed to specify mpl default with 'default'
  237. raise ValueError(
  238. "figsize='default' is no longer supported. "
  239. "Specify figure size by tuple instead"
  240. )
  241. grouped = data.groupby(by)
  242. if column is not None:
  243. grouped = grouped[column]
  244. naxes = len(grouped)
  245. fig, axes = create_subplots(
  246. naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
  247. )
  248. _axes = flatten_axes(axes)
  249. for i, (key, group) in enumerate(grouped):
  250. ax = _axes[i]
  251. if numeric_only and isinstance(group, ABCDataFrame):
  252. group = group._get_numeric_data()
  253. plotf(group, ax, **kwargs)
  254. ax.set_title(pprint_thing(key))
  255. return fig, axes
  256. def _grouped_hist(
  257. data,
  258. column=None,
  259. by=None,
  260. ax=None,
  261. bins: int = 50,
  262. figsize=None,
  263. layout=None,
  264. sharex: bool = False,
  265. sharey: bool = False,
  266. rot: float = 90,
  267. grid: bool = True,
  268. xlabelsize=None,
  269. xrot=None,
  270. ylabelsize=None,
  271. yrot=None,
  272. legend: bool = False,
  273. **kwargs,
  274. ):
  275. """
  276. Grouped histogram
  277. Parameters
  278. ----------
  279. data : Series/DataFrame
  280. column : object, optional
  281. by : object, optional
  282. ax : axes, optional
  283. bins : int, default 50
  284. figsize : tuple, optional
  285. layout : optional
  286. sharex : bool, default False
  287. sharey : bool, default False
  288. rot : float, default 90
  289. grid : bool, default True
  290. legend: : bool, default False
  291. kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
  292. Returns
  293. -------
  294. collection of Matplotlib Axes
  295. """
  296. if legend:
  297. assert "label" not in kwargs
  298. if data.ndim == 1:
  299. kwargs["label"] = data.name
  300. elif column is None:
  301. kwargs["label"] = data.columns
  302. else:
  303. kwargs["label"] = column
  304. def plot_group(group, ax) -> None:
  305. ax.hist(group.dropna().values, bins=bins, **kwargs)
  306. if legend:
  307. ax.legend()
  308. if xrot is None:
  309. xrot = rot
  310. fig, axes = _grouped_plot(
  311. plot_group,
  312. data,
  313. column=column,
  314. by=by,
  315. sharex=sharex,
  316. sharey=sharey,
  317. ax=ax,
  318. figsize=figsize,
  319. layout=layout,
  320. rot=rot,
  321. )
  322. set_ticks_props(
  323. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  324. )
  325. maybe_adjust_figure(
  326. fig, bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
  327. )
  328. return axes
  329. def hist_series(
  330. self,
  331. by=None,
  332. ax=None,
  333. grid: bool = True,
  334. xlabelsize=None,
  335. xrot=None,
  336. ylabelsize=None,
  337. yrot=None,
  338. figsize=None,
  339. bins: int = 10,
  340. legend: bool = False,
  341. **kwds,
  342. ):
  343. import matplotlib.pyplot as plt
  344. if legend and "label" in kwds:
  345. raise ValueError("Cannot use both legend and label")
  346. if by is None:
  347. if kwds.get("layout", None) is not None:
  348. raise ValueError("The 'layout' keyword is not supported when 'by' is None")
  349. # hack until the plotting interface is a bit more unified
  350. fig = kwds.pop(
  351. "figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
  352. )
  353. if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
  354. fig.set_size_inches(*figsize, forward=True)
  355. if ax is None:
  356. ax = fig.gca()
  357. elif ax.get_figure() != fig:
  358. raise AssertionError("passed axis not bound to passed figure")
  359. values = self.dropna().values
  360. if legend:
  361. kwds["label"] = self.name
  362. ax.hist(values, bins=bins, **kwds)
  363. if legend:
  364. ax.legend()
  365. ax.grid(grid)
  366. axes = np.array([ax])
  367. set_ticks_props(
  368. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  369. )
  370. else:
  371. if "figure" in kwds:
  372. raise ValueError(
  373. "Cannot pass 'figure' when using the "
  374. "'by' argument, since a new 'Figure' instance will be created"
  375. )
  376. axes = _grouped_hist(
  377. self,
  378. by=by,
  379. ax=ax,
  380. grid=grid,
  381. figsize=figsize,
  382. bins=bins,
  383. xlabelsize=xlabelsize,
  384. xrot=xrot,
  385. ylabelsize=ylabelsize,
  386. yrot=yrot,
  387. legend=legend,
  388. **kwds,
  389. )
  390. if hasattr(axes, "ndim"):
  391. if axes.ndim == 1 and len(axes) == 1:
  392. return axes[0]
  393. return axes
  394. def hist_frame(
  395. data,
  396. column=None,
  397. by=None,
  398. grid: bool = True,
  399. xlabelsize=None,
  400. xrot=None,
  401. ylabelsize=None,
  402. yrot=None,
  403. ax=None,
  404. sharex: bool = False,
  405. sharey: bool = False,
  406. figsize=None,
  407. layout=None,
  408. bins: int = 10,
  409. legend: bool = False,
  410. **kwds,
  411. ):
  412. if legend and "label" in kwds:
  413. raise ValueError("Cannot use both legend and label")
  414. if by is not None:
  415. axes = _grouped_hist(
  416. data,
  417. column=column,
  418. by=by,
  419. ax=ax,
  420. grid=grid,
  421. figsize=figsize,
  422. sharex=sharex,
  423. sharey=sharey,
  424. layout=layout,
  425. bins=bins,
  426. xlabelsize=xlabelsize,
  427. xrot=xrot,
  428. ylabelsize=ylabelsize,
  429. yrot=yrot,
  430. legend=legend,
  431. **kwds,
  432. )
  433. return axes
  434. if column is not None:
  435. if not isinstance(column, (list, np.ndarray, ABCIndex)):
  436. column = [column]
  437. data = data[column]
  438. # GH32590
  439. data = data.select_dtypes(
  440. include=(np.number, "datetime64", "datetimetz"), exclude="timedelta"
  441. )
  442. naxes = len(data.columns)
  443. if naxes == 0:
  444. raise ValueError(
  445. "hist method requires numerical or datetime columns, nothing to plot."
  446. )
  447. fig, axes = create_subplots(
  448. naxes=naxes,
  449. ax=ax,
  450. squeeze=False,
  451. sharex=sharex,
  452. sharey=sharey,
  453. figsize=figsize,
  454. layout=layout,
  455. )
  456. _axes = flatten_axes(axes)
  457. can_set_label = "label" not in kwds
  458. for i, col in enumerate(data.columns):
  459. ax = _axes[i]
  460. if legend and can_set_label:
  461. kwds["label"] = col
  462. ax.hist(data[col].dropna().values, bins=bins, **kwds)
  463. ax.set_title(col)
  464. ax.grid(grid)
  465. if legend:
  466. ax.legend()
  467. set_ticks_props(
  468. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  469. )
  470. maybe_adjust_figure(fig, wspace=0.3, hspace=0.3)
  471. return axes