boxplot.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550
  1. from __future__ import annotations
  2. from typing import (
  3. TYPE_CHECKING,
  4. Collection,
  5. Literal,
  6. NamedTuple,
  7. )
  8. import warnings
  9. from matplotlib.artist import setp
  10. import numpy as np
  11. from pandas._typing import MatplotlibColor
  12. from pandas.util._exceptions import find_stack_level
  13. from pandas.core.dtypes.common import is_dict_like
  14. from pandas.core.dtypes.missing import remove_na_arraylike
  15. import pandas as pd
  16. import pandas.core.common as com
  17. from pandas.io.formats.printing import pprint_thing
  18. from pandas.plotting._matplotlib.core import (
  19. LinePlot,
  20. MPLPlot,
  21. )
  22. from pandas.plotting._matplotlib.groupby import create_iter_data_given_by
  23. from pandas.plotting._matplotlib.style import get_standard_colors
  24. from pandas.plotting._matplotlib.tools import (
  25. create_subplots,
  26. flatten_axes,
  27. maybe_adjust_figure,
  28. )
  29. if TYPE_CHECKING:
  30. from matplotlib.axes import Axes
  31. from matplotlib.lines import Line2D
  32. class BoxPlot(LinePlot):
  33. @property
  34. def _kind(self) -> Literal["box"]:
  35. return "box"
  36. _layout_type = "horizontal"
  37. _valid_return_types = (None, "axes", "dict", "both")
  38. class BP(NamedTuple):
  39. # namedtuple to hold results
  40. ax: Axes
  41. lines: dict[str, list[Line2D]]
  42. def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
  43. if return_type not in self._valid_return_types:
  44. raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}")
  45. self.return_type = return_type
  46. # Do not call LinePlot.__init__ which may fill nan
  47. MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called
  48. def _args_adjust(self) -> None:
  49. if self.subplots:
  50. # Disable label ax sharing. Otherwise, all subplots shows last
  51. # column label
  52. if self.orientation == "vertical":
  53. self.sharex = False
  54. else:
  55. self.sharey = False
  56. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  57. @classmethod
  58. def _plot( # type: ignore[override]
  59. cls, ax, y, column_num=None, return_type: str = "axes", **kwds
  60. ):
  61. if y.ndim == 2:
  62. y = [remove_na_arraylike(v) for v in y]
  63. # Boxplot fails with empty arrays, so need to add a NaN
  64. # if any cols are empty
  65. # GH 8181
  66. y = [v if v.size > 0 else np.array([np.nan]) for v in y]
  67. else:
  68. y = remove_na_arraylike(y)
  69. bp = ax.boxplot(y, **kwds)
  70. if return_type == "dict":
  71. return bp, bp
  72. elif return_type == "both":
  73. return cls.BP(ax=ax, lines=bp), bp
  74. else:
  75. return ax, bp
  76. def _validate_color_args(self):
  77. if "color" in self.kwds:
  78. if self.colormap is not None:
  79. warnings.warn(
  80. "'color' and 'colormap' cannot be used "
  81. "simultaneously. Using 'color'",
  82. stacklevel=find_stack_level(),
  83. )
  84. self.color = self.kwds.pop("color")
  85. if isinstance(self.color, dict):
  86. valid_keys = ["boxes", "whiskers", "medians", "caps"]
  87. for key in self.color:
  88. if key not in valid_keys:
  89. raise ValueError(
  90. f"color dict contains invalid key '{key}'. "
  91. f"The key must be either {valid_keys}"
  92. )
  93. else:
  94. self.color = None
  95. # get standard colors for default
  96. colors = get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
  97. # use 2 colors by default, for box/whisker and median
  98. # flier colors isn't needed here
  99. # because it can be specified by ``sym`` kw
  100. self._boxes_c = colors[0]
  101. self._whiskers_c = colors[0]
  102. self._medians_c = colors[2]
  103. self._caps_c = colors[0]
  104. def _get_colors(
  105. self,
  106. num_colors=None,
  107. color_kwds: dict[str, MatplotlibColor]
  108. | MatplotlibColor
  109. | Collection[MatplotlibColor]
  110. | None = "color",
  111. ) -> None:
  112. pass
  113. def maybe_color_bp(self, bp) -> None:
  114. if isinstance(self.color, dict):
  115. boxes = self.color.get("boxes", self._boxes_c)
  116. whiskers = self.color.get("whiskers", self._whiskers_c)
  117. medians = self.color.get("medians", self._medians_c)
  118. caps = self.color.get("caps", self._caps_c)
  119. else:
  120. # Other types are forwarded to matplotlib
  121. # If None, use default colors
  122. boxes = self.color or self._boxes_c
  123. whiskers = self.color or self._whiskers_c
  124. medians = self.color or self._medians_c
  125. caps = self.color or self._caps_c
  126. # GH 30346, when users specifying those arguments explicitly, our defaults
  127. # for these four kwargs should be overridden; if not, use Pandas settings
  128. if not self.kwds.get("boxprops"):
  129. setp(bp["boxes"], color=boxes, alpha=1)
  130. if not self.kwds.get("whiskerprops"):
  131. setp(bp["whiskers"], color=whiskers, alpha=1)
  132. if not self.kwds.get("medianprops"):
  133. setp(bp["medians"], color=medians, alpha=1)
  134. if not self.kwds.get("capprops"):
  135. setp(bp["caps"], color=caps, alpha=1)
  136. def _make_plot(self) -> None:
  137. if self.subplots:
  138. self._return_obj = pd.Series(dtype=object)
  139. # Re-create iterated data if `by` is assigned by users
  140. data = (
  141. create_iter_data_given_by(self.data, self._kind)
  142. if self.by is not None
  143. else self.data
  144. )
  145. for i, (label, y) in enumerate(self._iter_data(data=data)):
  146. ax = self._get_ax(i)
  147. kwds = self.kwds.copy()
  148. # When by is applied, show title for subplots to know which group it is
  149. # just like df.boxplot, and need to apply T on y to provide right input
  150. if self.by is not None:
  151. y = y.T
  152. ax.set_title(pprint_thing(label))
  153. # When `by` is assigned, the ticklabels will become unique grouped
  154. # values, instead of label which is used as subtitle in this case.
  155. ticklabels = [
  156. pprint_thing(col) for col in self.data.columns.levels[0]
  157. ]
  158. else:
  159. ticklabels = [pprint_thing(label)]
  160. ret, bp = self._plot(
  161. ax, y, column_num=i, return_type=self.return_type, **kwds
  162. )
  163. self.maybe_color_bp(bp)
  164. self._return_obj[label] = ret
  165. self._set_ticklabels(ax, ticklabels)
  166. else:
  167. y = self.data.values.T
  168. ax = self._get_ax(0)
  169. kwds = self.kwds.copy()
  170. ret, bp = self._plot(
  171. ax, y, column_num=0, return_type=self.return_type, **kwds
  172. )
  173. self.maybe_color_bp(bp)
  174. self._return_obj = ret
  175. labels = [left for left, _ in self._iter_data()]
  176. labels = [pprint_thing(left) for left in labels]
  177. if not self.use_index:
  178. labels = [pprint_thing(key) for key in range(len(labels))]
  179. self._set_ticklabels(ax, labels)
  180. def _set_ticklabels(self, ax: Axes, labels) -> None:
  181. if self.orientation == "vertical":
  182. ax.set_xticklabels(labels)
  183. else:
  184. ax.set_yticklabels(labels)
  185. def _make_legend(self) -> None:
  186. pass
  187. def _post_plot_logic(self, ax, data) -> None:
  188. # GH 45465: make sure that the boxplot doesn't ignore xlabel/ylabel
  189. if self.xlabel:
  190. ax.set_xlabel(pprint_thing(self.xlabel))
  191. if self.ylabel:
  192. ax.set_ylabel(pprint_thing(self.ylabel))
  193. @property
  194. def orientation(self) -> Literal["horizontal", "vertical"]:
  195. if self.kwds.get("vert", True):
  196. return "vertical"
  197. else:
  198. return "horizontal"
  199. @property
  200. def result(self):
  201. if self.return_type is None:
  202. return super().result
  203. else:
  204. return self._return_obj
  205. def _grouped_plot_by_column(
  206. plotf,
  207. data,
  208. columns=None,
  209. by=None,
  210. numeric_only: bool = True,
  211. grid: bool = False,
  212. figsize=None,
  213. ax=None,
  214. layout=None,
  215. return_type=None,
  216. **kwargs,
  217. ):
  218. grouped = data.groupby(by)
  219. if columns is None:
  220. if not isinstance(by, (list, tuple)):
  221. by = [by]
  222. columns = data._get_numeric_data().columns.difference(by)
  223. naxes = len(columns)
  224. fig, axes = create_subplots(
  225. naxes=naxes,
  226. sharex=kwargs.pop("sharex", True),
  227. sharey=kwargs.pop("sharey", True),
  228. figsize=figsize,
  229. ax=ax,
  230. layout=layout,
  231. )
  232. _axes = flatten_axes(axes)
  233. # GH 45465: move the "by" label based on "vert"
  234. xlabel, ylabel = kwargs.pop("xlabel", None), kwargs.pop("ylabel", None)
  235. if kwargs.get("vert", True):
  236. xlabel = xlabel or by
  237. else:
  238. ylabel = ylabel or by
  239. ax_values = []
  240. for i, col in enumerate(columns):
  241. ax = _axes[i]
  242. gp_col = grouped[col]
  243. keys, values = zip(*gp_col)
  244. re_plotf = plotf(keys, values, ax, xlabel=xlabel, ylabel=ylabel, **kwargs)
  245. ax.set_title(col)
  246. ax_values.append(re_plotf)
  247. ax.grid(grid)
  248. result = pd.Series(ax_values, index=columns, copy=False)
  249. # Return axes in multiplot case, maybe revisit later # 985
  250. if return_type is None:
  251. result = axes
  252. byline = by[0] if len(by) == 1 else by
  253. fig.suptitle(f"Boxplot grouped by {byline}")
  254. maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  255. return result
  256. def boxplot(
  257. data,
  258. column=None,
  259. by=None,
  260. ax=None,
  261. fontsize=None,
  262. rot: int = 0,
  263. grid: bool = True,
  264. figsize=None,
  265. layout=None,
  266. return_type=None,
  267. **kwds,
  268. ):
  269. import matplotlib.pyplot as plt
  270. # validate return_type:
  271. if return_type not in BoxPlot._valid_return_types:
  272. raise ValueError("return_type must be {'axes', 'dict', 'both'}")
  273. if isinstance(data, pd.Series):
  274. data = data.to_frame("x")
  275. column = "x"
  276. def _get_colors():
  277. # num_colors=3 is required as method maybe_color_bp takes the colors
  278. # in positions 0 and 2.
  279. # if colors not provided, use same defaults as DataFrame.plot.box
  280. result = get_standard_colors(num_colors=3)
  281. result = np.take(result, [0, 0, 2])
  282. result = np.append(result, "k")
  283. colors = kwds.pop("color", None)
  284. if colors:
  285. if is_dict_like(colors):
  286. # replace colors in result array with user-specified colors
  287. # taken from the colors dict parameter
  288. # "boxes" value placed in position 0, "whiskers" in 1, etc.
  289. valid_keys = ["boxes", "whiskers", "medians", "caps"]
  290. key_to_index = dict(zip(valid_keys, range(4)))
  291. for key, value in colors.items():
  292. if key in valid_keys:
  293. result[key_to_index[key]] = value
  294. else:
  295. raise ValueError(
  296. f"color dict contains invalid key '{key}'. "
  297. f"The key must be either {valid_keys}"
  298. )
  299. else:
  300. result.fill(colors)
  301. return result
  302. def maybe_color_bp(bp, **kwds) -> None:
  303. # GH 30346, when users specifying those arguments explicitly, our defaults
  304. # for these four kwargs should be overridden; if not, use Pandas settings
  305. if not kwds.get("boxprops"):
  306. setp(bp["boxes"], color=colors[0], alpha=1)
  307. if not kwds.get("whiskerprops"):
  308. setp(bp["whiskers"], color=colors[1], alpha=1)
  309. if not kwds.get("medianprops"):
  310. setp(bp["medians"], color=colors[2], alpha=1)
  311. if not kwds.get("capprops"):
  312. setp(bp["caps"], color=colors[3], alpha=1)
  313. def plot_group(keys, values, ax: Axes, **kwds):
  314. # GH 45465: xlabel/ylabel need to be popped out before plotting happens
  315. xlabel, ylabel = kwds.pop("xlabel", None), kwds.pop("ylabel", None)
  316. if xlabel:
  317. ax.set_xlabel(pprint_thing(xlabel))
  318. if ylabel:
  319. ax.set_ylabel(pprint_thing(ylabel))
  320. keys = [pprint_thing(x) for x in keys]
  321. values = [np.asarray(remove_na_arraylike(v), dtype=object) for v in values]
  322. bp = ax.boxplot(values, **kwds)
  323. if fontsize is not None:
  324. ax.tick_params(axis="both", labelsize=fontsize)
  325. # GH 45465: x/y are flipped when "vert" changes
  326. is_vertical = kwds.get("vert", True)
  327. ticks = ax.get_xticks() if is_vertical else ax.get_yticks()
  328. if len(ticks) != len(keys):
  329. i, remainder = divmod(len(ticks), len(keys))
  330. assert remainder == 0, remainder
  331. keys *= i
  332. if is_vertical:
  333. ax.set_xticklabels(keys, rotation=rot)
  334. else:
  335. ax.set_yticklabels(keys, rotation=rot)
  336. maybe_color_bp(bp, **kwds)
  337. # Return axes in multiplot case, maybe revisit later # 985
  338. if return_type == "dict":
  339. return bp
  340. elif return_type == "both":
  341. return BoxPlot.BP(ax=ax, lines=bp)
  342. else:
  343. return ax
  344. colors = _get_colors()
  345. if column is None:
  346. columns = None
  347. else:
  348. if isinstance(column, (list, tuple)):
  349. columns = column
  350. else:
  351. columns = [column]
  352. if by is not None:
  353. # Prefer array return type for 2-D plots to match the subplot layout
  354. # https://github.com/pandas-dev/pandas/pull/12216#issuecomment-241175580
  355. result = _grouped_plot_by_column(
  356. plot_group,
  357. data,
  358. columns=columns,
  359. by=by,
  360. grid=grid,
  361. figsize=figsize,
  362. ax=ax,
  363. layout=layout,
  364. return_type=return_type,
  365. **kwds,
  366. )
  367. else:
  368. if return_type is None:
  369. return_type = "axes"
  370. if layout is not None:
  371. raise ValueError("The 'layout' keyword is not supported when 'by' is None")
  372. if ax is None:
  373. rc = {"figure.figsize": figsize} if figsize is not None else {}
  374. with plt.rc_context(rc):
  375. ax = plt.gca()
  376. data = data._get_numeric_data()
  377. naxes = len(data.columns)
  378. if naxes == 0:
  379. raise ValueError(
  380. "boxplot method requires numerical columns, nothing to plot."
  381. )
  382. if columns is None:
  383. columns = data.columns
  384. else:
  385. data = data[columns]
  386. result = plot_group(columns, data.values.T, ax, **kwds)
  387. ax.grid(grid)
  388. return result
  389. def boxplot_frame(
  390. self,
  391. column=None,
  392. by=None,
  393. ax=None,
  394. fontsize=None,
  395. rot: int = 0,
  396. grid: bool = True,
  397. figsize=None,
  398. layout=None,
  399. return_type=None,
  400. **kwds,
  401. ):
  402. import matplotlib.pyplot as plt
  403. ax = boxplot(
  404. self,
  405. column=column,
  406. by=by,
  407. ax=ax,
  408. fontsize=fontsize,
  409. grid=grid,
  410. rot=rot,
  411. figsize=figsize,
  412. layout=layout,
  413. return_type=return_type,
  414. **kwds,
  415. )
  416. plt.draw_if_interactive()
  417. return ax
  418. def boxplot_frame_groupby(
  419. grouped,
  420. subplots: bool = True,
  421. column=None,
  422. fontsize=None,
  423. rot: int = 0,
  424. grid: bool = True,
  425. ax=None,
  426. figsize=None,
  427. layout=None,
  428. sharex: bool = False,
  429. sharey: bool = True,
  430. **kwds,
  431. ):
  432. if subplots is True:
  433. naxes = len(grouped)
  434. fig, axes = create_subplots(
  435. naxes=naxes,
  436. squeeze=False,
  437. ax=ax,
  438. sharex=sharex,
  439. sharey=sharey,
  440. figsize=figsize,
  441. layout=layout,
  442. )
  443. axes = flatten_axes(axes)
  444. ret = pd.Series(dtype=object)
  445. for (key, group), ax in zip(grouped, axes):
  446. d = group.boxplot(
  447. ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
  448. )
  449. ax.set_title(pprint_thing(key))
  450. ret.loc[key] = d
  451. maybe_adjust_figure(fig, bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  452. else:
  453. keys, frames = zip(*grouped)
  454. if grouped.axis == 0:
  455. df = pd.concat(frames, keys=keys, axis=1)
  456. else:
  457. if len(frames) > 1:
  458. df = frames[0].join(frames[1::])
  459. else:
  460. df = frames[0]
  461. # GH 16748, DataFrameGroupby fails when subplots=False and `column` argument
  462. # is assigned, and in this case, since `df` here becomes MI after groupby,
  463. # so we need to couple the keys (grouped values) and column (original df
  464. # column) together to search for subset to plot
  465. if column is not None:
  466. column = com.convert_to_list_like(column)
  467. multi_key = pd.MultiIndex.from_product([keys, column])
  468. column = list(multi_key.values)
  469. ret = df.boxplot(
  470. column=column,
  471. fontsize=fontsize,
  472. rot=rot,
  473. grid=grid,
  474. ax=ax,
  475. figsize=figsize,
  476. layout=layout,
  477. **kwds,
  478. )
  479. return ret