core.py 63 KB


  1. from __future__ import annotations
  2. from abc import (
  3. ABC,
  4. abstractmethod,
  5. )
  6. from typing import (
  7. TYPE_CHECKING,
  8. Hashable,
  9. Iterable,
  10. Literal,
  11. Sequence,
  12. )
  13. import warnings
  14. import matplotlib as mpl
  15. from matplotlib.artist import Artist
  16. import numpy as np
  17. from pandas._typing import (
  18. IndexLabel,
  19. PlottingOrientation,
  20. npt,
  21. )
  22. from pandas.errors import AbstractMethodError
  23. from pandas.util._decorators import cache_readonly
  24. from pandas.util._exceptions import find_stack_level
  25. from pandas.core.dtypes.common import (
  26. is_any_real_numeric_dtype,
  27. is_categorical_dtype,
  28. is_extension_array_dtype,
  29. is_float,
  30. is_float_dtype,
  31. is_hashable,
  32. is_integer,
  33. is_integer_dtype,
  34. is_iterator,
  35. is_list_like,
  36. is_number,
  37. is_numeric_dtype,
  38. )
  39. from pandas.core.dtypes.generic import (
  40. ABCDataFrame,
  41. ABCIndex,
  42. ABCMultiIndex,
  43. ABCPeriodIndex,
  44. ABCSeries,
  45. )
  46. from pandas.core.dtypes.missing import (
  47. isna,
  48. notna,
  49. )
  50. import pandas.core.common as com
  51. from pandas.core.frame import DataFrame
  52. from pandas.util.version import Version
  53. from pandas.io.formats.printing import pprint_thing
  54. from pandas.plotting._matplotlib import tools
  55. from pandas.plotting._matplotlib.converter import register_pandas_matplotlib_converters
  56. from pandas.plotting._matplotlib.groupby import reconstruct_data_with_by
  57. from pandas.plotting._matplotlib.misc import unpack_single_str_list
  58. from pandas.plotting._matplotlib.style import get_standard_colors
  59. from pandas.plotting._matplotlib.timeseries import (
  60. decorate_axes,
  61. format_dateaxis,
  62. maybe_convert_index,
  63. maybe_resample,
  64. use_dynamic_x,
  65. )
  66. from pandas.plotting._matplotlib.tools import (
  67. create_subplots,
  68. flatten_axes,
  69. format_date_labels,
  70. get_all_lines,
  71. get_xlim,
  72. handle_shared_axes,
  73. )
  74. if TYPE_CHECKING:
  75. from matplotlib.axes import Axes
  76. from matplotlib.axis import Axis
  77. def _color_in_style(style: str) -> bool:
  78. """
  79. Check if there is a color letter in the style string.
  80. """
  81. from matplotlib.colors import BASE_COLORS
  82. return not set(BASE_COLORS).isdisjoint(style)
  83. class MPLPlot(ABC):
  84. """
  85. Base class for assembling a pandas plot using matplotlib
  86. Parameters
  87. ----------
  88. data :
  89. """
  90. @property
  91. @abstractmethod
  92. def _kind(self) -> str:
  93. """Specify kind str. Must be overridden in child class"""
  94. raise NotImplementedError
  95. _layout_type = "vertical"
  96. _default_rot = 0
  97. @property
  98. def orientation(self) -> str | None:
  99. return None
  100. axes: np.ndarray # of Axes objects
    def __init__(
        self,
        data,
        kind=None,
        by: IndexLabel | None = None,
        subplots: bool | Sequence[Sequence[str]] = False,
        sharex=None,
        sharey: bool = False,
        use_index: bool = True,
        figsize=None,
        grid=None,
        legend: bool | str = True,
        rot=None,
        ax=None,
        fig=None,
        title=None,
        xlim=None,
        ylim=None,
        xticks=None,
        yticks=None,
        xlabel: Hashable | None = None,
        ylabel: Hashable | None = None,
        fontsize=None,
        secondary_y: bool | tuple | list | np.ndarray = False,
        colormap=None,
        table: bool = False,
        layout=None,
        include_bool: bool = False,
        column: IndexLabel | None = None,
        **kwds,
    ) -> None:
        """
        Store the plotting options and run the validation that needs only inputs.

        Parameters
        ----------
        data : Series or DataFrame
            Object to plot, stored as ``self.data``.
        kind : optional
            Stored as ``self.kind``.
        by : label or list of labels, optional
            Grouping key(s). An empty list or tuple raises ``ValueError``.
        subplots : bool or sequence of sequences of column labels
            Checked and normalised by ``_validate_subplots_kwarg``.
        sharex : bool, optional
            When omitted, True only if neither ``ax`` nor ``by`` is given.
        rot : optional
            Tick-label rotation; ``_default_rot`` is used when omitted.
        grid : bool, optional
            When omitted, False if ``secondary_y`` is truthy, else
            ``rcParams["axes.grid"]``.
        secondary_y : bool or list-like of labels
            Any other scalar is wrapped in a list.
        colormap : optional
            May instead be passed as ``cmap`` in ``kwds``; passing both raises.
        column : label or list of labels, optional
            Columns used for DataFrame input; by default the numeric columns that
            are not in ``by``.
        **kwds
            ``logx``, ``logy``, ``loglog``, ``label``, ``style``, ``mark_right``,
            ``stacked``, ``xerr``, ``yerr`` and ``cmap`` are consumed here; what
            remains is kept in ``self.kwds``.

        The other parameters are stored unchanged under the same attribute name.

        Raises
        ------
        ValueError
            If ``by`` is an empty list or tuple (the validators called at the end
            may raise too).
        TypeError
            If both ``cmap`` and ``colormap`` are given.
        """
        import matplotlib.pyplot as plt

        self.data = data

        # if users assign an empty list or tuple, raise `ValueError`
        # similar to current `df.box` and `df.hist` APIs.
        if by in ([], ()):
            raise ValueError("No group keys passed!")
        self.by = com.maybe_make_list(by)

        # Assign the rest of columns into self.columns if by is explicitly defined
        # while column is not, only need `columns` in hist/box plot when it's DF
        # TODO: Might deprecate `column` argument in future PR (#28373)
        # NOTE: ``self.columns`` is only set for DataFrame input.
        if isinstance(data, DataFrame):
            if column:
                self.columns = com.maybe_make_list(column)
            else:
                if self.by is None:
                    self.columns = [
                        col for col in data.columns if is_numeric_dtype(data[col])
                    ]
                else:
                    self.columns = [
                        col
                        for col in data.columns
                        if col not in self.by and is_numeric_dtype(data[col])
                    ]

        # For `hist` plot, need to get grouped original data before `self.data` is
        # updated later
        if self.by is not None and self._kind == "hist":
            self._grouped = data.groupby(unpack_single_str_list(self.by))

        self.kind = kind

        # bool, or a list of tuples of column positions (one tuple per subplot)
        self.subplots = self._validate_subplots_kwarg(subplots)

        if sharex is None:
            # if by is defined, subplots are used and sharex should be False
            if ax is None and by is None:
                self.sharex = True
            else:
                # if we get an axis, the users should do the visibility
                # setting...
                self.sharex = False
        else:
            self.sharex = sharex

        self.sharey = sharey
        self.figsize = figsize
        self.layout = layout

        self.xticks = xticks
        self.yticks = yticks
        self.xlim = xlim
        self.ylim = ylim
        self.title = title
        self.use_index = use_index
        self.xlabel = xlabel
        self.ylabel = ylabel

        self.fontsize = fontsize

        if rot is not None:
            self.rot = rot
            # need to know for format_date_labels since it's rotated to 30 by
            # default
            self._rot_set = True
        else:
            self._rot_set = False
            self.rot = self._default_rot

        if grid is None:
            # grid is off by default when a secondary y-axis is requested,
            # otherwise it follows matplotlib's ``axes.grid`` rcParam
            grid = False if secondary_y else plt.rcParams["axes.grid"]

        self.grid = grid
        self.legend = legend
        # filled through ``_append_legend_handles_labels``; read by ``_make_legend``
        self.legend_handles: list[Artist] = []
        self.legend_labels: list[Hashable] = []

        # Options popped from ``kwds`` so that they do not remain in ``self.kwds``.
        self.logx = kwds.pop("logx", False)
        self.logy = kwds.pop("logy", False)
        self.loglog = kwds.pop("loglog", False)
        self.label = kwds.pop("label", None)
        self.style = kwds.pop("style", None)
        self.mark_right = kwds.pop("mark_right", True)
        self.stacked = kwds.pop("stacked", False)

        self.ax = ax
        self.fig = fig
        self.axes = np.array([], dtype=object)  # "real" version get set in `generate`

        # parse errorbar input if given
        # NOTE: ``_parse_errorbars`` may replace ``self.data`` (it drops the column
        # named by a string ``xerr``/``yerr``), which changes ``nseries`` for the
        # steps below such as ``_validate_color_args``.
        xerr = kwds.pop("xerr", None)
        yerr = kwds.pop("yerr", None)
        self.errors = {
            kw: self._parse_errorbars(kw, err)
            for kw, err in zip(["xerr", "yerr"], [xerr, yerr])
        }

        # normalise a scalar label (e.g. a single column name) to a list
        if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
            secondary_y = [secondary_y]
        self.secondary_y = secondary_y

        # ugly TypeError if user passes matplotlib's `cmap` name.
        # Probably better to accept either.
        if "cmap" in kwds and colormap:
            raise TypeError("Only specify one of `cmap` and `colormap`.")
        if "cmap" in kwds:
            self.colormap = kwds.pop("cmap")
        else:
            self.colormap = colormap

        self.table = table
        self.include_bool = include_bool

        self.kwds = kwds

        # must run last: reads self.kwds, self.style, self.colormap and nseries
        self._validate_color_args()
  230. def _validate_subplots_kwarg(
  231. self, subplots: bool | Sequence[Sequence[str]]
  232. ) -> bool | list[tuple[int, ...]]:
  233. """
  234. Validate the subplots parameter
  235. - check type and content
  236. - check for duplicate columns
  237. - check for invalid column names
  238. - convert column names into indices
  239. - add missing columns in a group of their own
  240. See comments in code below for more details.
  241. Parameters
  242. ----------
  243. subplots : subplots parameters as passed to PlotAccessor
  244. Returns
  245. -------
  246. validated subplots : a bool or a list of tuples of column indices. Columns
  247. in the same tuple will be grouped together in the resulting plot.
  248. """
  249. if isinstance(subplots, bool):
  250. return subplots
  251. elif not isinstance(subplots, Iterable):
  252. raise ValueError("subplots should be a bool or an iterable")
  253. supported_kinds = (
  254. "line",
  255. "bar",
  256. "barh",
  257. "hist",
  258. "kde",
  259. "density",
  260. "area",
  261. "pie",
  262. )
  263. if self._kind not in supported_kinds:
  264. raise ValueError(
  265. "When subplots is an iterable, kind must be "
  266. f"one of {', '.join(supported_kinds)}. Got {self._kind}."
  267. )
  268. if isinstance(self.data, ABCSeries):
  269. raise NotImplementedError(
  270. "An iterable subplots for a Series is not supported."
  271. )
  272. columns = self.data.columns
  273. if isinstance(columns, ABCMultiIndex):
  274. raise NotImplementedError(
  275. "An iterable subplots for a DataFrame with a MultiIndex column "
  276. "is not supported."
  277. )
  278. if columns.nunique() != len(columns):
  279. raise NotImplementedError(
  280. "An iterable subplots for a DataFrame with non-unique column "
  281. "labels is not supported."
  282. )
  283. # subplots is a list of tuples where each tuple is a group of
  284. # columns to be grouped together (one ax per group).
  285. # we consolidate the subplots list such that:
  286. # - the tuples contain indices instead of column names
  287. # - the columns that aren't yet in the list are added in a group
  288. # of their own.
  289. # For example with columns from a to g, and
  290. # subplots = [(a, c), (b, f, e)],
  291. # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)]
  292. # This way, we can handle self.subplots in a homogeneous manner
  293. # later.
  294. # TODO: also accept indices instead of just names?
  295. out = []
  296. seen_columns: set[Hashable] = set()
  297. for group in subplots:
  298. if not is_list_like(group):
  299. raise ValueError(
  300. "When subplots is an iterable, each entry "
  301. "should be a list/tuple of column names."
  302. )
  303. idx_locs = columns.get_indexer_for(group)
  304. if (idx_locs == -1).any():
  305. bad_labels = np.extract(idx_locs == -1, group)
  306. raise ValueError(
  307. f"Column label(s) {list(bad_labels)} not found in the DataFrame."
  308. )
  309. unique_columns = set(group)
  310. duplicates = seen_columns.intersection(unique_columns)
  311. if duplicates:
  312. raise ValueError(
  313. "Each column should be in only one subplot. "
  314. f"Columns {duplicates} were found in multiple subplots."
  315. )
  316. seen_columns = seen_columns.union(unique_columns)
  317. out.append(tuple(idx_locs))
  318. unseen_columns = columns.difference(seen_columns)
  319. for column in unseen_columns:
  320. idx_loc = columns.get_loc(column)
  321. out.append((idx_loc,))
  322. return out
  323. def _validate_color_args(self):
  324. if (
  325. "color" in self.kwds
  326. and self.nseries == 1
  327. and not is_list_like(self.kwds["color"])
  328. ):
  329. # support series.plot(color='green')
  330. self.kwds["color"] = [self.kwds["color"]]
  331. if (
  332. "color" in self.kwds
  333. and isinstance(self.kwds["color"], tuple)
  334. and self.nseries == 1
  335. and len(self.kwds["color"]) in (3, 4)
  336. ):
  337. # support RGB and RGBA tuples in series plot
  338. self.kwds["color"] = [self.kwds["color"]]
  339. if (
  340. "color" in self.kwds or "colors" in self.kwds
  341. ) and self.colormap is not None:
  342. warnings.warn(
  343. "'color' and 'colormap' cannot be used simultaneously. Using 'color'",
  344. stacklevel=find_stack_level(),
  345. )
  346. if "color" in self.kwds and self.style is not None:
  347. if is_list_like(self.style):
  348. styles = self.style
  349. else:
  350. styles = [self.style]
  351. # need only a single match
  352. for s in styles:
  353. if _color_in_style(s):
  354. raise ValueError(
  355. "Cannot pass 'style' string with a color symbol and "
  356. "'color' keyword argument. Please use one or the "
  357. "other or pass 'style' without a color symbol"
  358. )
  359. def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
  360. if data is None:
  361. data = self.data
  362. if fillna is not None:
  363. data = data.fillna(fillna)
  364. for col, values in data.items():
  365. if keep_index is True:
  366. yield col, values
  367. else:
  368. yield col, values.values
  369. @property
  370. def nseries(self) -> int:
  371. # When `by` is explicitly assigned, grouped data size will be defined, and
  372. # this will determine number of subplots to have, aka `self.nseries`
  373. if self.data.ndim == 1:
  374. return 1
  375. elif self.by is not None and self._kind == "hist":
  376. return len(self._grouped)
  377. elif self.by is not None and self._kind == "box":
  378. return len(self.columns)
  379. else:
  380. return self.data.shape[1]
  381. def draw(self) -> None:
  382. self.plt.draw_if_interactive()
    def generate(self) -> None:
        """
        Build the complete plot by running every construction step in order.

        Each step relies on state left by the previous ones (for example
        ``_compute_plot_data`` may set ``self.subplots``, which
        ``_setup_subplots`` reads, and ``_setup_subplots`` creates the
        ``self.axes`` iterated at the end), so the order must be kept.
        """
        self._args_adjust()
        self._compute_plot_data()
        self._setup_subplots()
        self._make_plot()
        self._add_table()
        self._make_legend()
        self._adorn_subplots()

        # per-axes hooks: the shared one first, then the subclass-specific one
        for ax in self.axes:
            self._post_plot_logic_common(ax, self.data)
            self._post_plot_logic(ax, self.data)
  394. @abstractmethod
  395. def _args_adjust(self) -> None:
  396. pass
  397. def _has_plotted_object(self, ax: Axes) -> bool:
  398. """check whether ax has data"""
  399. return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0
  400. def _maybe_right_yaxis(self, ax: Axes, axes_num):
  401. if not self.on_right(axes_num):
  402. # secondary axes may be passed via ax kw
  403. return self._get_ax_layer(ax)
  404. if hasattr(ax, "right_ax"):
  405. # if it has right_ax property, ``ax`` must be left axes
  406. return ax.right_ax
  407. elif hasattr(ax, "left_ax"):
  408. # if it has left_ax property, ``ax`` must be right axes
  409. return ax
  410. else:
  411. # otherwise, create twin axes
  412. orig_ax, new_ax = ax, ax.twinx()
  413. # TODO: use Matplotlib public API when available
  414. new_ax._get_lines = orig_ax._get_lines
  415. new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
  416. orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
  417. if not self._has_plotted_object(orig_ax): # no data on left y
  418. orig_ax.get_yaxis().set_visible(False)
  419. if self.logy is True or self.loglog is True:
  420. new_ax.set_yscale("log")
  421. elif self.logy == "sym" or self.loglog == "sym":
  422. new_ax.set_yscale("symlog")
  423. return new_ax
  424. def _setup_subplots(self):
  425. if self.subplots:
  426. naxes = (
  427. self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
  428. )
  429. fig, axes = create_subplots(
  430. naxes=naxes,
  431. sharex=self.sharex,
  432. sharey=self.sharey,
  433. figsize=self.figsize,
  434. ax=self.ax,
  435. layout=self.layout,
  436. layout_type=self._layout_type,
  437. )
  438. else:
  439. if self.ax is None:
  440. fig = self.plt.figure(figsize=self.figsize)
  441. axes = fig.add_subplot(111)
  442. else:
  443. fig = self.ax.get_figure()
  444. if self.figsize is not None:
  445. fig.set_size_inches(self.figsize)
  446. axes = self.ax
  447. axes = flatten_axes(axes)
  448. valid_log = {False, True, "sym", None}
  449. input_log = {self.logx, self.logy, self.loglog}
  450. if input_log - valid_log:
  451. invalid_log = next(iter(input_log - valid_log))
  452. raise ValueError(
  453. f"Boolean, None and 'sym' are valid options, '{invalid_log}' is given."
  454. )
  455. if self.logx is True or self.loglog is True:
  456. [a.set_xscale("log") for a in axes]
  457. elif self.logx == "sym" or self.loglog == "sym":
  458. [a.set_xscale("symlog") for a in axes]
  459. if self.logy is True or self.loglog is True:
  460. [a.set_yscale("log") for a in axes]
  461. elif self.logy == "sym" or self.loglog == "sym":
  462. [a.set_yscale("symlog") for a in axes]
  463. self.fig = fig
  464. self.axes = axes
  465. @property
  466. def result(self):
  467. """
  468. Return result axes
  469. """
  470. if self.subplots:
  471. if self.layout is not None and not is_list_like(self.ax):
  472. return self.axes.reshape(*self.layout)
  473. else:
  474. return self.axes
  475. else:
  476. sec_true = isinstance(self.secondary_y, bool) and self.secondary_y
  477. # error: Argument 1 to "len" has incompatible type "Union[bool,
  478. # Tuple[Any, ...], List[Any], ndarray[Any, Any]]"; expected "Sized"
  479. all_sec = (
  480. is_list_like(self.secondary_y)
  481. and len(self.secondary_y) == self.nseries # type: ignore[arg-type]
  482. )
  483. if sec_true or all_sec:
  484. # if all data is plotted on secondary, return right axes
  485. return self._get_ax_layer(self.axes[0], primary=False)
  486. else:
  487. return self.axes[0]
  488. def _convert_to_ndarray(self, data):
  489. # GH31357: categorical columns are processed separately
  490. if is_categorical_dtype(data):
  491. return data
  492. # GH32073: cast to float if values contain nulled integers
  493. if (
  494. is_integer_dtype(data.dtype) or is_float_dtype(data.dtype)
  495. ) and is_extension_array_dtype(data.dtype):
  496. return data.to_numpy(dtype="float", na_value=np.nan)
  497. # GH25587: cast ExtensionArray of pandas (IntegerArray, etc.) to
  498. # np.ndarray before plot.
  499. if len(data) > 0:
  500. return np.asarray(data)
  501. return data
    def _compute_plot_data(self):
        """
        Select the plottable columns and store them back in ``self.data``.

        A Series is first turned into a one-column DataFrame (named after
        ``label``, or ``""`` when neither ``label`` nor the Series name is set).
        For hist/box plots only ``self.columns`` (plus ``by``) are kept. When
        ``by`` is given, ``self.subplots`` is forced to True and the data is
        rebuilt with ``reconstruct_data_with_by``. Columns are then filtered by
        dtype, and each remaining column goes through ``_convert_to_ndarray``.

        Raises
        ------
        TypeError
            If no column survives the dtype filter.
        """
        data = self.data

        if isinstance(data, ABCSeries):
            label = self.label
            if label is None and data.name is None:
                label = ""
            if label is None:
                # We'll end up with columns of [0] instead of [None]
                data = data.to_frame()
            else:
                data = data.to_frame(name=label)
        elif self._kind in ("hist", "box"):
            cols = self.columns if self.by is None else self.columns + self.by
            data = data.loc[:, cols]

        # GH15079 reconstruct data if by is defined
        # NOTE: this starts again from ``self.data``, not from the column
        # selection made just above.
        if self.by is not None:
            self.subplots = True
            data = reconstruct_data_with_by(self.data, by=self.by, cols=self.columns)

        # GH16953, infer_objects is needed as fallback, for ``Series``
        # with ``dtype == object``
        data = data.infer_objects(copy=False)
        include_type = [np.number, "datetime", "datetimetz", "timedelta"]

        # GH23719, allow plotting boolean
        if self.include_bool is True:
            include_type.append(np.bool_)

        # GH22799, exclude datetime-like type for boxplot
        exclude_type = None
        if self._kind == "box":
            # TODO: change after solving issue 27881
            include_type = [np.number]
            exclude_type = ["timedelta"]

        # GH 18755, include object and category type for scatter plot
        if self._kind == "scatter":
            include_type.extend(["object", "category"])

        numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type)

        # fall back to the length when the result has no ``columns`` attribute
        try:
            is_empty = numeric_data.columns.empty
        except AttributeError:
            is_empty = not len(numeric_data)

        # no non-numeric frames or series allowed
        if is_empty:
            raise TypeError("no numeric data to plot")

        self.data = numeric_data.apply(self._convert_to_ndarray)
  545. def _make_plot(self):
  546. raise AbstractMethodError(self)
  547. def _add_table(self) -> None:
  548. if self.table is False:
  549. return
  550. elif self.table is True:
  551. data = self.data.transpose()
  552. else:
  553. data = self.table
  554. ax = self._get_ax(0)
  555. tools.table(ax, data)
  556. def _post_plot_logic_common(self, ax, data):
  557. """Common post process for each axes"""
  558. if self.orientation == "vertical" or self.orientation is None:
  559. self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize)
  560. self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
  561. if hasattr(ax, "right_ax"):
  562. self._apply_axis_properties(ax.right_ax.yaxis, fontsize=self.fontsize)
  563. elif self.orientation == "horizontal":
  564. self._apply_axis_properties(ax.yaxis, rot=self.rot, fontsize=self.fontsize)
  565. self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
  566. if hasattr(ax, "right_ax"):
  567. self._apply_axis_properties(ax.right_ax.yaxis, fontsize=self.fontsize)
  568. else: # pragma no cover
  569. raise ValueError
  570. @abstractmethod
  571. def _post_plot_logic(self, ax, data) -> None:
  572. """Post process for each axes. Overridden in child classes"""
  573. def _adorn_subplots(self):
  574. """Common post process unrelated to data"""
  575. if len(self.axes) > 0:
  576. all_axes = self._get_subplots()
  577. nrows, ncols = self._get_axes_layout()
  578. handle_shared_axes(
  579. axarr=all_axes,
  580. nplots=len(all_axes),
  581. naxes=nrows * ncols,
  582. nrows=nrows,
  583. ncols=ncols,
  584. sharex=self.sharex,
  585. sharey=self.sharey,
  586. )
  587. for ax in self.axes:
  588. ax = getattr(ax, "right_ax", ax)
  589. if self.yticks is not None:
  590. ax.set_yticks(self.yticks)
  591. if self.xticks is not None:
  592. ax.set_xticks(self.xticks)
  593. if self.ylim is not None:
  594. ax.set_ylim(self.ylim)
  595. if self.xlim is not None:
  596. ax.set_xlim(self.xlim)
  597. # GH9093, currently Pandas does not show ylabel, so if users provide
  598. # ylabel will set it as ylabel in the plot.
  599. if self.ylabel is not None:
  600. ax.set_ylabel(pprint_thing(self.ylabel))
  601. ax.grid(self.grid)
  602. if self.title:
  603. if self.subplots:
  604. if is_list_like(self.title):
  605. if len(self.title) != self.nseries:
  606. raise ValueError(
  607. "The length of `title` must equal the number "
  608. "of columns if using `title` of type `list` "
  609. "and `subplots=True`.\n"
  610. f"length of title = {len(self.title)}\n"
  611. f"number of columns = {self.nseries}"
  612. )
  613. for ax, title in zip(self.axes, self.title):
  614. ax.set_title(title)
  615. else:
  616. self.fig.suptitle(self.title)
  617. else:
  618. if is_list_like(self.title):
  619. msg = (
  620. "Using `title` of type `list` is not supported "
  621. "unless `subplots=True` is passed"
  622. )
  623. raise ValueError(msg)
  624. self.axes[0].set_title(self.title)
  625. def _apply_axis_properties(self, axis: Axis, rot=None, fontsize=None) -> None:
  626. """
  627. Tick creation within matplotlib is reasonably expensive and is
  628. internally deferred until accessed as Ticks are created/destroyed
  629. multiple times per draw. It's therefore beneficial for us to avoid
  630. accessing unless we will act on the Tick.
  631. """
  632. if rot is not None or fontsize is not None:
  633. # rot=0 is a valid setting, hence the explicit None check
  634. labels = axis.get_majorticklabels() + axis.get_minorticklabels()
  635. for label in labels:
  636. if rot is not None:
  637. label.set_rotation(rot)
  638. if fontsize is not None:
  639. label.set_fontsize(fontsize)
  640. @property
  641. def legend_title(self) -> str | None:
  642. if not isinstance(self.data.columns, ABCMultiIndex):
  643. name = self.data.columns.name
  644. if name is not None:
  645. name = pprint_thing(name)
  646. return name
  647. else:
  648. stringified = map(pprint_thing, self.data.columns.names)
  649. return ",".join(stringified)
  650. def _mark_right_label(self, label: str, index: int) -> str:
  651. """
  652. Append ``(right)`` to the label of a line if it's plotted on the right axis.
  653. Note that ``(right)`` is only appended when ``subplots=False``.
  654. """
  655. if not self.subplots and self.mark_right and self.on_right(index):
  656. label += " (right)"
  657. return label
  658. def _append_legend_handles_labels(self, handle: Artist, label: str) -> None:
  659. """
  660. Append current handle and label to ``legend_handles`` and ``legend_labels``.
  661. These will be used to make the legend.
  662. """
  663. self.legend_handles.append(handle)
  664. self.legend_labels.append(label)
  665. def _make_legend(self) -> None:
  666. ax, leg = self._get_ax_legend(self.axes[0])
  667. handles = []
  668. labels = []
  669. title = ""
  670. if not self.subplots:
  671. if leg is not None:
  672. title = leg.get_title().get_text()
  673. # Replace leg.legend_handles because it misses marker info
  674. if Version(mpl.__version__) < Version("3.7"):
  675. handles = leg.legendHandles
  676. else:
  677. handles = leg.legend_handles
  678. labels = [x.get_text() for x in leg.get_texts()]
  679. if self.legend:
  680. if self.legend == "reverse":
  681. handles += reversed(self.legend_handles)
  682. labels += reversed(self.legend_labels)
  683. else:
  684. handles += self.legend_handles
  685. labels += self.legend_labels
  686. if self.legend_title is not None:
  687. title = self.legend_title
  688. if len(handles) > 0:
  689. ax.legend(handles, labels, loc="best", title=title)
  690. elif self.subplots and self.legend:
  691. for ax in self.axes:
  692. if ax.get_visible():
  693. ax.legend(loc="best")
  694. def _get_ax_legend(self, ax: Axes):
  695. """
  696. Take in axes and return ax and legend under different scenarios
  697. """
  698. leg = ax.get_legend()
  699. other_ax = getattr(ax, "left_ax", None) or getattr(ax, "right_ax", None)
  700. other_leg = None
  701. if other_ax is not None:
  702. other_leg = other_ax.get_legend()
  703. if leg is None and other_leg is not None:
  704. leg = other_leg
  705. ax = other_ax
  706. return ax, leg
  707. @cache_readonly
  708. def plt(self):
  709. import matplotlib.pyplot as plt
  710. return plt
  711. _need_to_set_index = False
  712. def _get_xticks(self, convert_period: bool = False):
  713. index = self.data.index
  714. is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
  715. if self.use_index:
  716. if convert_period and isinstance(index, ABCPeriodIndex):
  717. self.data = self.data.reindex(index=index.sort_values())
  718. x = self.data.index.to_timestamp()._mpl_repr()
  719. elif is_any_real_numeric_dtype(index):
  720. # Matplotlib supports numeric values or datetime objects as
  721. # xaxis values. Taking LBYL approach here, by the time
  722. # matplotlib raises exception when using non numeric/datetime
  723. # values for xaxis, several actions are already taken by plt.
  724. x = index._mpl_repr()
  725. elif is_datetype:
  726. self.data = self.data[notna(self.data.index)]
  727. self.data = self.data.sort_index()
  728. x = self.data.index._mpl_repr()
  729. else:
  730. self._need_to_set_index = True
  731. x = list(range(len(index)))
  732. else:
  733. x = list(range(len(index)))
  734. return x
  735. @classmethod
  736. @register_pandas_matplotlib_converters
  737. def _plot(
  738. cls, ax: Axes, x, y: np.ndarray, style=None, is_errorbar: bool = False, **kwds
  739. ):
  740. mask = isna(y)
  741. if mask.any():
  742. y = np.ma.array(y)
  743. y = np.ma.masked_where(mask, y)
  744. if isinstance(x, ABCIndex):
  745. x = x._mpl_repr()
  746. if is_errorbar:
  747. if "xerr" in kwds:
  748. kwds["xerr"] = np.array(kwds.get("xerr"))
  749. if "yerr" in kwds:
  750. kwds["yerr"] = np.array(kwds.get("yerr"))
  751. return ax.errorbar(x, y, **kwds)
  752. else:
  753. # prevent style kwarg from going to errorbar, where it is unsupported
  754. args = (x, y, style) if style is not None else (x, y)
  755. return ax.plot(*args, **kwds)
  756. def _get_custom_index_name(self):
  757. """Specify whether xlabel/ylabel should be used to override index name"""
  758. return self.xlabel
  759. def _get_index_name(self) -> str | None:
  760. if isinstance(self.data.index, ABCMultiIndex):
  761. name = self.data.index.names
  762. if com.any_not_none(*name):
  763. name = ",".join([pprint_thing(x) for x in name])
  764. else:
  765. name = None
  766. else:
  767. name = self.data.index.name
  768. if name is not None:
  769. name = pprint_thing(name)
  770. # GH 45145, override the default axis label if one is provided.
  771. index_name = self._get_custom_index_name()
  772. if index_name is not None:
  773. name = pprint_thing(index_name)
  774. return name
  775. @classmethod
  776. def _get_ax_layer(cls, ax, primary: bool = True):
  777. """get left (primary) or right (secondary) axes"""
  778. if primary:
  779. return getattr(ax, "left_ax", ax)
  780. else:
  781. return getattr(ax, "right_ax", ax)
  782. def _col_idx_to_axis_idx(self, col_idx: int) -> int:
  783. """Return the index of the axis where the column at col_idx should be plotted"""
  784. if isinstance(self.subplots, list):
  785. # Subplots is a list: some columns will be grouped together in the same ax
  786. return next(
  787. group_idx
  788. for (group_idx, group) in enumerate(self.subplots)
  789. if col_idx in group
  790. )
  791. else:
  792. # subplots is True: one ax per column
  793. return col_idx
  794. def _get_ax(self, i: int):
  795. # get the twinx ax if appropriate
  796. if self.subplots:
  797. i = self._col_idx_to_axis_idx(i)
  798. ax = self.axes[i]
  799. ax = self._maybe_right_yaxis(ax, i)
  800. self.axes[i] = ax
  801. else:
  802. ax = self.axes[0]
  803. ax = self._maybe_right_yaxis(ax, i)
  804. ax.get_yaxis().set_visible(True)
  805. return ax
  806. @classmethod
  807. def get_default_ax(cls, ax) -> None:
  808. import matplotlib.pyplot as plt
  809. if ax is None and len(plt.get_fignums()) > 0:
  810. with plt.rc_context():
  811. ax = plt.gca()
  812. ax = cls._get_ax_layer(ax)
  813. def on_right(self, i):
  814. if isinstance(self.secondary_y, bool):
  815. return self.secondary_y
  816. if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)):
  817. return self.data.columns[i] in self.secondary_y
  818. def _apply_style_colors(self, colors, kwds, col_num, label):
  819. """
  820. Manage style and color based on column number and its label.
  821. Returns tuple of appropriate style and kwds which "color" may be added.
  822. """
  823. style = None
  824. if self.style is not None:
  825. if isinstance(self.style, list):
  826. try:
  827. style = self.style[col_num]
  828. except IndexError:
  829. pass
  830. elif isinstance(self.style, dict):
  831. style = self.style.get(label, style)
  832. else:
  833. style = self.style
  834. has_color = "color" in kwds or self.colormap is not None
  835. nocolor_style = style is None or not _color_in_style(style)
  836. if (has_color or self.subplots) and nocolor_style:
  837. if isinstance(colors, dict):
  838. kwds["color"] = colors[label]
  839. else:
  840. kwds["color"] = colors[col_num % len(colors)]
  841. return style, kwds
  842. def _get_colors(
  843. self,
  844. num_colors: int | None = None,
  845. color_kwds: str = "color",
  846. ):
  847. if num_colors is None:
  848. num_colors = self.nseries
  849. return get_standard_colors(
  850. num_colors=num_colors,
  851. colormap=self.colormap,
  852. color=self.kwds.get(color_kwds),
  853. )
    def _parse_errorbars(self, label, err):
        """
        Look for error keyword arguments and return the actual errorbar data
        or return the error DataFrame/dict

        Error bars can be specified in several ways:
            Series: the user provides a pandas.Series object of the same
                    length as the data
            ndarray: provides a np.ndarray of the same length as the data
            DataFrame/dict: error values are paired with keys matching the
                    key in the plotted DataFrame
            str: the name of the column within the plotted DataFrame

        Asymmetrical error bars are also supported, however raw error values
        must be provided in this case. For a ``N`` length :class:`Series`, a
        ``2xN`` array should be provided indicating lower and upper (or left
        and right) errors. For a ``MxN`` :class:`DataFrame`, asymmetrical errors
        should be in a ``Mx2xN`` array.

        Parameters
        ----------
        label : str
            Name of the keyword being parsed (``"xerr"`` or ``"yerr"`` when
            called from ``__init__``); only used in the error message.
        err : None, DataFrame, dict, Series, str, list-like or number
            The user-supplied error specification.

        Returns
        -------
        None when ``err`` is None. A DataFrame reindexed to the data, or the dict
        unchanged, for key-matched input. Otherwise a numpy array, tiled across
        series when a Series, column, scalar or single-row input was given.

        Raises
        ------
        ValueError
            If asymmetrical errors have the wrong shape, or ``err`` has an
            unsupported type.

        Notes
        -----
        When ``err`` is a column name, that column is removed from ``self.data``.
        """
        if err is None:
            return None

        # align error values to the plotted data's index
        def match_labels(data, e):
            e = e.reindex(data.index)
            return e

        # key-matched DataFrame
        if isinstance(err, ABCDataFrame):
            err = match_labels(self.data, err)
        # key-matched dict
        elif isinstance(err, dict):
            pass

        # Series of error values
        elif isinstance(err, ABCSeries):
            # broadcast error series across data
            err = match_labels(self.data, err)
            err = np.atleast_2d(err)
            err = np.tile(err, (self.nseries, 1))

        # errors are a column in the dataframe
        elif isinstance(err, str):
            evalues = self.data[err].values
            # side effect: the error column is no longer plotted as data
            self.data = self.data[self.data.columns.drop(err)]
            err = np.atleast_2d(evalues)
            err = np.tile(err, (self.nseries, 1))

        elif is_list_like(err):
            if is_iterator(err):
                err = np.atleast_2d(list(err))
            else:
                # raw error values
                err = np.atleast_2d(err)

            err_shape = err.shape

            # asymmetrical error bars
            if isinstance(self.data, ABCSeries) and err_shape[0] == 2:
                # (2, N) -> (1, 2, N) so it matches the DataFrame layout
                err = np.expand_dims(err, 0)
                err_shape = err.shape
                if err_shape[2] != len(self.data):
                    raise ValueError(
                        "Asymmetrical error bars should be provided "
                        f"with the shape (2, {len(self.data)})"
                    )
            elif isinstance(self.data, ABCDataFrame) and err.ndim == 3:
                if (
                    (err_shape[0] != self.nseries)
                    or (err_shape[1] != 2)
                    or (err_shape[2] != len(self.data))
                ):
                    raise ValueError(
                        "Asymmetrical error bars should be provided "
                        f"with the shape ({self.nseries}, 2, {len(self.data)})"
                    )

            # broadcast errors to each data series
            if len(err) == 1:
                err = np.tile(err, (self.nseries, 1))

        elif is_number(err):
            # scalar error: same value for every point of every series
            err = np.tile([err], (self.nseries, len(self.data)))

        else:
            msg = f"No valid {label} detected"
            raise ValueError(msg)

        return err
  929. def _get_errorbars(
  930. self, label=None, index=None, xerr: bool = True, yerr: bool = True
  931. ):
  932. errors = {}
  933. for kw, flag in zip(["xerr", "yerr"], [xerr, yerr]):
  934. if flag:
  935. err = self.errors[kw]
  936. # user provided label-matched dataframe of errors
  937. if isinstance(err, (ABCDataFrame, dict)):
  938. if label is not None and label in err.keys():
  939. err = err[label]
  940. else:
  941. err = None
  942. elif index is not None and err is not None:
  943. err = err[index]
  944. if err is not None:
  945. errors[kw] = err
  946. return errors
  947. def _get_subplots(self):
  948. from matplotlib.axes import Subplot
  949. return [
  950. ax
  951. for ax in self.fig.get_axes()
  952. if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
  953. ]
  954. def _get_axes_layout(self) -> tuple[int, int]:
  955. axes = self._get_subplots()
  956. x_set = set()
  957. y_set = set()
  958. for ax in axes:
  959. # check axes coordinates to estimate layout
  960. points = ax.get_position().get_points()
  961. x_set.add(points[0][0])
  962. y_set.add(points[0][1])
  963. return (len(y_set), len(x_set))
  964. class PlanePlot(MPLPlot, ABC):
  965. """
  966. Abstract class for plotting on plane, currently scatter and hexbin.
  967. """
  968. _layout_type = "single"
  969. def __init__(self, data, x, y, **kwargs) -> None:
  970. MPLPlot.__init__(self, data, **kwargs)
  971. if x is None or y is None:
  972. raise ValueError(self._kind + " requires an x and y column")
  973. if is_integer(x) and not self.data.columns._holds_integer():
  974. x = self.data.columns[x]
  975. if is_integer(y) and not self.data.columns._holds_integer():
  976. y = self.data.columns[y]
  977. # Scatter plot allows to plot objects data
  978. if self._kind == "hexbin":
  979. if len(self.data[x]._get_numeric_data()) == 0:
  980. raise ValueError(self._kind + " requires x column to be numeric")
  981. if len(self.data[y]._get_numeric_data()) == 0:
  982. raise ValueError(self._kind + " requires y column to be numeric")
  983. self.x = x
  984. self.y = y
  985. @property
  986. def nseries(self) -> int:
  987. return 1
  988. def _post_plot_logic(self, ax: Axes, data) -> None:
  989. x, y = self.x, self.y
  990. xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x)
  991. ylabel = self.ylabel if self.ylabel is not None else pprint_thing(y)
  992. ax.set_xlabel(xlabel)
  993. ax.set_ylabel(ylabel)
  994. def _plot_colorbar(self, ax: Axes, **kwds):
  995. # Addresses issues #10611 and #10678:
  996. # When plotting scatterplots and hexbinplots in IPython
  997. # inline backend the colorbar axis height tends not to
  998. # exactly match the parent axis height.
  999. # The difference is due to small fractional differences
  1000. # in floating points with similar representation.
  1001. # To deal with this, this method forces the colorbar
  1002. # height to take the height of the parent axes.
  1003. # For a more detailed description of the issue
  1004. # see the following link:
  1005. # https://github.com/ipython/ipython/issues/11215
  1006. # GH33389, if ax is used multiple times, we should always
  1007. # use the last one which contains the latest information
  1008. # about the ax
  1009. img = ax.collections[-1]
  1010. return self.fig.colorbar(img, ax=ax, **kwds)
  1011. class ScatterPlot(PlanePlot):
  1012. @property
  1013. def _kind(self) -> Literal["scatter"]:
  1014. return "scatter"
  1015. def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
  1016. if s is None:
  1017. # hide the matplotlib default for size, in case we want to change
  1018. # the handling of this argument later
  1019. s = 20
  1020. elif is_hashable(s) and s in data.columns:
  1021. s = data[s]
  1022. super().__init__(data, x, y, s=s, **kwargs)
  1023. if is_integer(c) and not self.data.columns._holds_integer():
  1024. c = self.data.columns[c]
  1025. self.c = c
  1026. def _make_plot(self):
  1027. x, y, c, data = self.x, self.y, self.c, self.data
  1028. ax = self.axes[0]
  1029. c_is_column = is_hashable(c) and c in self.data.columns
  1030. color_by_categorical = c_is_column and is_categorical_dtype(self.data[c])
  1031. color = self.kwds.pop("color", None)
  1032. if c is not None and color is not None:
  1033. raise TypeError("Specify exactly one of `c` and `color`")
  1034. if c is None and color is None:
  1035. c_values = self.plt.rcParams["patch.facecolor"]
  1036. elif color is not None:
  1037. c_values = color
  1038. elif color_by_categorical:
  1039. c_values = self.data[c].cat.codes
  1040. elif c_is_column:
  1041. c_values = self.data[c].values
  1042. else:
  1043. c_values = c
  1044. if self.colormap is not None:
  1045. cmap = mpl.colormaps.get_cmap(self.colormap)
  1046. else:
  1047. # cmap is only used if c_values are integers, otherwise UserWarning
  1048. if is_integer_dtype(c_values):
  1049. # pandas uses colormap, matplotlib uses cmap.
  1050. cmap = "Greys"
  1051. cmap = mpl.colormaps[cmap]
  1052. else:
  1053. cmap = None
  1054. if color_by_categorical:
  1055. from matplotlib import colors
  1056. n_cats = len(self.data[c].cat.categories)
  1057. cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
  1058. bounds = np.linspace(0, n_cats, n_cats + 1)
  1059. norm = colors.BoundaryNorm(bounds, cmap.N)
  1060. else:
  1061. norm = self.kwds.pop("norm", None)
  1062. # plot colorbar if
  1063. # 1. colormap is assigned, and
  1064. # 2.`c` is a column containing only numeric values
  1065. plot_colorbar = self.colormap or c_is_column
  1066. cb = self.kwds.pop("colorbar", is_numeric_dtype(c_values) and plot_colorbar)
  1067. if self.legend and hasattr(self, "label"):
  1068. label = self.label
  1069. else:
  1070. label = None
  1071. scatter = ax.scatter(
  1072. data[x].values,
  1073. data[y].values,
  1074. c=c_values,
  1075. label=label,
  1076. cmap=cmap,
  1077. norm=norm,
  1078. **self.kwds,
  1079. )
  1080. if cb:
  1081. cbar_label = c if c_is_column else ""
  1082. cbar = self._plot_colorbar(ax, label=cbar_label)
  1083. if color_by_categorical:
  1084. cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
  1085. cbar.ax.set_yticklabels(self.data[c].cat.categories)
  1086. if label is not None:
  1087. self._append_legend_handles_labels(scatter, label)
  1088. else:
  1089. self.legend = False
  1090. errors_x = self._get_errorbars(label=x, index=0, yerr=False)
  1091. errors_y = self._get_errorbars(label=y, index=0, xerr=False)
  1092. if len(errors_x) > 0 or len(errors_y) > 0:
  1093. err_kwds = dict(errors_x, **errors_y)
  1094. err_kwds["ecolor"] = scatter.get_facecolor()[0]
  1095. ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)
  1096. def _args_adjust(self) -> None:
  1097. pass
  1098. class HexBinPlot(PlanePlot):
  1099. @property
  1100. def _kind(self) -> Literal["hexbin"]:
  1101. return "hexbin"
  1102. def __init__(self, data, x, y, C=None, **kwargs) -> None:
  1103. super().__init__(data, x, y, **kwargs)
  1104. if is_integer(C) and not self.data.columns._holds_integer():
  1105. C = self.data.columns[C]
  1106. self.C = C
  1107. def _make_plot(self) -> None:
  1108. x, y, data, C = self.x, self.y, self.data, self.C
  1109. ax = self.axes[0]
  1110. # pandas uses colormap, matplotlib uses cmap.
  1111. cmap = self.colormap or "BuGn"
  1112. cmap = mpl.colormaps.get_cmap(cmap)
  1113. cb = self.kwds.pop("colorbar", True)
  1114. if C is None:
  1115. c_values = None
  1116. else:
  1117. c_values = data[C].values
  1118. ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds)
  1119. if cb:
  1120. self._plot_colorbar(ax)
  1121. def _make_legend(self) -> None:
  1122. pass
  1123. def _args_adjust(self) -> None:
  1124. pass
  1125. class LinePlot(MPLPlot):
  1126. _default_rot = 0
  1127. @property
  1128. def orientation(self) -> PlottingOrientation:
  1129. return "vertical"
  1130. @property
  1131. def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
  1132. return "line"
  1133. def __init__(self, data, **kwargs) -> None:
  1134. from pandas.plotting import plot_params
  1135. MPLPlot.__init__(self, data, **kwargs)
  1136. if self.stacked:
  1137. self.data = self.data.fillna(value=0)
  1138. self.x_compat = plot_params["x_compat"]
  1139. if "x_compat" in self.kwds:
  1140. self.x_compat = bool(self.kwds.pop("x_compat"))
  1141. def _is_ts_plot(self) -> bool:
  1142. # this is slightly deceptive
  1143. return not self.x_compat and self.use_index and self._use_dynamic_x()
  1144. def _use_dynamic_x(self):
  1145. return use_dynamic_x(self._get_ax(0), self.data)
  1146. def _make_plot(self) -> None:
  1147. if self._is_ts_plot():
  1148. data = maybe_convert_index(self._get_ax(0), self.data)
  1149. x = data.index # dummy, not used
  1150. plotf = self._ts_plot
  1151. it = self._iter_data(data=data, keep_index=True)
  1152. else:
  1153. x = self._get_xticks(convert_period=True)
  1154. # error: Incompatible types in assignment (expression has type
  1155. # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
  1156. # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
  1157. plotf = self._plot # type: ignore[assignment]
  1158. it = self._iter_data()
  1159. stacking_id = self._get_stacking_id()
  1160. is_errorbar = com.any_not_none(*self.errors.values())
  1161. colors = self._get_colors()
  1162. for i, (label, y) in enumerate(it):
  1163. ax = self._get_ax(i)
  1164. kwds = self.kwds.copy()
  1165. style, kwds = self._apply_style_colors(colors, kwds, i, label)
  1166. errors = self._get_errorbars(label=label, index=i)
  1167. kwds = dict(kwds, **errors)
  1168. label = pprint_thing(label) # .encode('utf-8')
  1169. label = self._mark_right_label(label, index=i)
  1170. kwds["label"] = label
  1171. newlines = plotf(
  1172. ax,
  1173. x,
  1174. y,
  1175. style=style,
  1176. column_num=i,
  1177. stacking_id=stacking_id,
  1178. is_errorbar=is_errorbar,
  1179. **kwds,
  1180. )
  1181. self._append_legend_handles_labels(newlines[0], label)
  1182. if self._is_ts_plot():
  1183. # reset of xlim should be used for ts data
  1184. # TODO: GH28021, should find a way to change view limit on xaxis
  1185. lines = get_all_lines(ax)
  1186. left, right = get_xlim(lines)
  1187. ax.set_xlim(left, right)
  1188. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  1189. @classmethod
  1190. def _plot( # type: ignore[override]
  1191. cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds
  1192. ):
  1193. # column_num is used to get the target column from plotf in line and
  1194. # area plots
  1195. if column_num == 0:
  1196. cls._initialize_stacker(ax, stacking_id, len(y))
  1197. y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"])
  1198. lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds)
  1199. cls._update_stacker(ax, stacking_id, y)
  1200. return lines
  1201. def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
  1202. # accept x to be consistent with normal plot func,
  1203. # x is not passed to tsplot as it uses data.index as x coordinate
  1204. # column_num must be in kwds for stacking purpose
  1205. freq, data = maybe_resample(data, ax, kwds)
  1206. # Set ax with freq info
  1207. decorate_axes(ax, freq, kwds)
  1208. # digging deeper
  1209. if hasattr(ax, "left_ax"):
  1210. decorate_axes(ax.left_ax, freq, kwds)
  1211. if hasattr(ax, "right_ax"):
  1212. decorate_axes(ax.right_ax, freq, kwds)
  1213. ax._plot_data.append((data, self._kind, kwds))
  1214. lines = self._plot(ax, data.index, data.values, style=style, **kwds)
  1215. # set date formatter, locators and rescale limits
  1216. format_dateaxis(ax, ax.freq, data.index)
  1217. return lines
  1218. def _get_stacking_id(self):
  1219. if self.stacked:
  1220. return id(self.data)
  1221. else:
  1222. return None
  1223. @classmethod
  1224. def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
  1225. if stacking_id is None:
  1226. return
  1227. if not hasattr(ax, "_stacker_pos_prior"):
  1228. ax._stacker_pos_prior = {}
  1229. if not hasattr(ax, "_stacker_neg_prior"):
  1230. ax._stacker_neg_prior = {}
  1231. ax._stacker_pos_prior[stacking_id] = np.zeros(n)
  1232. ax._stacker_neg_prior[stacking_id] = np.zeros(n)
  1233. @classmethod
  1234. def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
  1235. if stacking_id is None:
  1236. return values
  1237. if not hasattr(ax, "_stacker_pos_prior"):
  1238. # stacker may not be initialized for subplots
  1239. cls._initialize_stacker(ax, stacking_id, len(values))
  1240. if (values >= 0).all():
  1241. return ax._stacker_pos_prior[stacking_id] + values
  1242. elif (values <= 0).all():
  1243. return ax._stacker_neg_prior[stacking_id] + values
  1244. raise ValueError(
  1245. "When stacked is True, each column must be either "
  1246. "all positive or all negative. "
  1247. f"Column '{label}' contains both positive and negative values"
  1248. )
  1249. @classmethod
  1250. def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
  1251. if stacking_id is None:
  1252. return
  1253. if (values >= 0).all():
  1254. ax._stacker_pos_prior[stacking_id] += values
  1255. elif (values <= 0).all():
  1256. ax._stacker_neg_prior[stacking_id] += values
  1257. def _args_adjust(self) -> None:
  1258. pass
  1259. def _post_plot_logic(self, ax: Axes, data) -> None:
  1260. from matplotlib.ticker import FixedLocator
  1261. def get_label(i):
  1262. if is_float(i) and i.is_integer():
  1263. i = int(i)
  1264. try:
  1265. return pprint_thing(data.index[i])
  1266. except Exception:
  1267. return ""
  1268. if self._need_to_set_index:
  1269. xticks = ax.get_xticks()
  1270. xticklabels = [get_label(x) for x in xticks]
  1271. ax.xaxis.set_major_locator(FixedLocator(xticks))
  1272. ax.set_xticklabels(xticklabels)
  1273. # If the index is an irregular time series, then by default
  1274. # we rotate the tick labels. The exception is if there are
  1275. # subplots which don't share their x-axes, in which we case
  1276. # we don't rotate the ticklabels as by default the subplots
  1277. # would be too close together.
  1278. condition = (
  1279. not self._use_dynamic_x()
  1280. and (data.index._is_all_dates and self.use_index)
  1281. and (not self.subplots or (self.subplots and self.sharex))
  1282. )
  1283. index_name = self._get_index_name()
  1284. if condition:
  1285. # irregular TS rotated 30 deg. by default
  1286. # probably a better place to check / set this.
  1287. if not self._rot_set:
  1288. self.rot = 30
  1289. format_date_labels(ax, rot=self.rot)
  1290. if index_name is not None and self.use_index:
  1291. ax.set_xlabel(index_name)
  1292. class AreaPlot(LinePlot):
  1293. @property
  1294. def _kind(self) -> Literal["area"]:
  1295. return "area"
  1296. def __init__(self, data, **kwargs) -> None:
  1297. kwargs.setdefault("stacked", True)
  1298. data = data.fillna(value=0)
  1299. LinePlot.__init__(self, data, **kwargs)
  1300. if not self.stacked:
  1301. # use smaller alpha to distinguish overlap
  1302. self.kwds.setdefault("alpha", 0.5)
  1303. if self.logy or self.loglog:
  1304. raise ValueError("Log-y scales are not supported in area plot")
  1305. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  1306. @classmethod
  1307. def _plot( # type: ignore[override]
  1308. cls,
  1309. ax: Axes,
  1310. x,
  1311. y,
  1312. style=None,
  1313. column_num=None,
  1314. stacking_id=None,
  1315. is_errorbar: bool = False,
  1316. **kwds,
  1317. ):
  1318. if column_num == 0:
  1319. cls._initialize_stacker(ax, stacking_id, len(y))
  1320. y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"])
  1321. # need to remove label, because subplots uses mpl legend as it is
  1322. line_kwds = kwds.copy()
  1323. line_kwds.pop("label")
  1324. lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds)
  1325. # get data from the line to get coordinates for fill_between
  1326. xdata, y_values = lines[0].get_data(orig=False)
  1327. # unable to use ``_get_stacked_values`` here to get starting point
  1328. if stacking_id is None:
  1329. start = np.zeros(len(y))
  1330. elif (y >= 0).all():
  1331. start = ax._stacker_pos_prior[stacking_id]
  1332. elif (y <= 0).all():
  1333. start = ax._stacker_neg_prior[stacking_id]
  1334. else:
  1335. start = np.zeros(len(y))
  1336. if "color" not in kwds:
  1337. kwds["color"] = lines[0].get_color()
  1338. rect = ax.fill_between(xdata, start, y_values, **kwds)
  1339. cls._update_stacker(ax, stacking_id, y)
  1340. # LinePlot expects list of artists
  1341. res = [rect]
  1342. return res
  1343. def _args_adjust(self) -> None:
  1344. pass
  1345. def _post_plot_logic(self, ax: Axes, data) -> None:
  1346. LinePlot._post_plot_logic(self, ax, data)
  1347. is_shared_y = len(list(ax.get_shared_y_axes())) > 0
  1348. # do not override the default axis behaviour in case of shared y axes
  1349. if self.ylim is None and not is_shared_y:
  1350. if (data >= 0).all().all():
  1351. ax.set_ylim(0, None)
  1352. elif (data <= 0).all().all():
  1353. ax.set_ylim(None, 0)
  1354. class BarPlot(MPLPlot):
  1355. @property
  1356. def _kind(self) -> Literal["bar", "barh"]:
  1357. return "bar"
  1358. _default_rot = 90
  1359. @property
  1360. def orientation(self) -> PlottingOrientation:
  1361. return "vertical"
  1362. def __init__(self, data, **kwargs) -> None:
  1363. # we have to treat a series differently than a
  1364. # 1-column DataFrame w.r.t. color handling
  1365. self._is_series = isinstance(data, ABCSeries)
  1366. self.bar_width = kwargs.pop("width", 0.5)
  1367. pos = kwargs.pop("position", 0.5)
  1368. kwargs.setdefault("align", "center")
  1369. self.tick_pos = np.arange(len(data))
  1370. self.bottom = kwargs.pop("bottom", 0)
  1371. self.left = kwargs.pop("left", 0)
  1372. self.log = kwargs.pop("log", False)
  1373. MPLPlot.__init__(self, data, **kwargs)
  1374. if self.stacked or self.subplots:
  1375. self.tickoffset = self.bar_width * pos
  1376. if kwargs["align"] == "edge":
  1377. self.lim_offset = self.bar_width / 2
  1378. else:
  1379. self.lim_offset = 0
  1380. else:
  1381. if kwargs["align"] == "edge":
  1382. w = self.bar_width / self.nseries
  1383. self.tickoffset = self.bar_width * (pos - 0.5) + w * 0.5
  1384. self.lim_offset = w * 0.5
  1385. else:
  1386. self.tickoffset = self.bar_width * pos
  1387. self.lim_offset = 0
  1388. self.ax_pos = self.tick_pos - self.tickoffset
  1389. def _args_adjust(self) -> None:
  1390. if is_list_like(self.bottom):
  1391. self.bottom = np.array(self.bottom)
  1392. if is_list_like(self.left):
  1393. self.left = np.array(self.left)
  1394. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  1395. @classmethod
  1396. def _plot( # type: ignore[override]
  1397. cls,
  1398. ax: Axes,
  1399. x,
  1400. y,
  1401. w,
  1402. start: int | npt.NDArray[np.intp] = 0,
  1403. log: bool = False,
  1404. **kwds,
  1405. ):
  1406. return ax.bar(x, y, w, bottom=start, log=log, **kwds)
  1407. @property
  1408. def _start_base(self):
  1409. return self.bottom
  1410. def _make_plot(self) -> None:
  1411. colors = self._get_colors()
  1412. ncolors = len(colors)
  1413. pos_prior = neg_prior = np.zeros(len(self.data))
  1414. K = self.nseries
  1415. for i, (label, y) in enumerate(self._iter_data(fillna=0)):
  1416. ax = self._get_ax(i)
  1417. kwds = self.kwds.copy()
  1418. if self._is_series:
  1419. kwds["color"] = colors
  1420. elif isinstance(colors, dict):
  1421. kwds["color"] = colors[label]
  1422. else:
  1423. kwds["color"] = colors[i % ncolors]
  1424. errors = self._get_errorbars(label=label, index=i)
  1425. kwds = dict(kwds, **errors)
  1426. label = pprint_thing(label)
  1427. label = self._mark_right_label(label, index=i)
  1428. if (("yerr" in kwds) or ("xerr" in kwds)) and (kwds.get("ecolor") is None):
  1429. kwds["ecolor"] = mpl.rcParams["xtick.color"]
  1430. start = 0
  1431. if self.log and (y >= 1).all():
  1432. start = 1
  1433. start = start + self._start_base
  1434. if self.subplots:
  1435. w = self.bar_width / 2
  1436. rect = self._plot(
  1437. ax,
  1438. self.ax_pos + w,
  1439. y,
  1440. self.bar_width,
  1441. start=start,
  1442. label=label,
  1443. log=self.log,
  1444. **kwds,
  1445. )
  1446. ax.set_title(label)
  1447. elif self.stacked:
  1448. mask = y > 0
  1449. start = np.where(mask, pos_prior, neg_prior) + self._start_base
  1450. w = self.bar_width / 2
  1451. rect = self._plot(
  1452. ax,
  1453. self.ax_pos + w,
  1454. y,
  1455. self.bar_width,
  1456. start=start,
  1457. label=label,
  1458. log=self.log,
  1459. **kwds,
  1460. )
  1461. pos_prior = pos_prior + np.where(mask, y, 0)
  1462. neg_prior = neg_prior + np.where(mask, 0, y)
  1463. else:
  1464. w = self.bar_width / K
  1465. rect = self._plot(
  1466. ax,
  1467. self.ax_pos + (i + 0.5) * w,
  1468. y,
  1469. w,
  1470. start=start,
  1471. label=label,
  1472. log=self.log,
  1473. **kwds,
  1474. )
  1475. self._append_legend_handles_labels(rect, label)
  1476. def _post_plot_logic(self, ax: Axes, data) -> None:
  1477. if self.use_index:
  1478. str_index = [pprint_thing(key) for key in data.index]
  1479. else:
  1480. str_index = [pprint_thing(key) for key in range(data.shape[0])]
  1481. s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
  1482. e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset
  1483. self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge)
  1484. def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
  1485. ax.set_xlim((start_edge, end_edge))
  1486. if self.xticks is not None:
  1487. ax.set_xticks(np.array(self.xticks))
  1488. else:
  1489. ax.set_xticks(self.tick_pos)
  1490. ax.set_xticklabels(ticklabels)
  1491. if name is not None and self.use_index:
  1492. ax.set_xlabel(name)
  1493. class BarhPlot(BarPlot):
  1494. @property
  1495. def _kind(self) -> Literal["barh"]:
  1496. return "barh"
  1497. _default_rot = 0
  1498. @property
  1499. def orientation(self) -> Literal["horizontal"]:
  1500. return "horizontal"
  1501. @property
  1502. def _start_base(self):
  1503. return self.left
  1504. # error: Signature of "_plot" incompatible with supertype "MPLPlot"
  1505. @classmethod
  1506. def _plot( # type: ignore[override]
  1507. cls,
  1508. ax: Axes,
  1509. x,
  1510. y,
  1511. w,
  1512. start: int | npt.NDArray[np.intp] = 0,
  1513. log: bool = False,
  1514. **kwds,
  1515. ):
  1516. return ax.barh(x, y, w, left=start, log=log, **kwds)
  1517. def _get_custom_index_name(self):
  1518. return self.ylabel
  1519. def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
  1520. # horizontal bars
  1521. ax.set_ylim((start_edge, end_edge))
  1522. ax.set_yticks(self.tick_pos)
  1523. ax.set_yticklabels(ticklabels)
  1524. if name is not None and self.use_index:
  1525. ax.set_ylabel(name)
  1526. ax.set_xlabel(self.xlabel)
  1527. class PiePlot(MPLPlot):
  1528. @property
  1529. def _kind(self) -> Literal["pie"]:
  1530. return "pie"
  1531. _layout_type = "horizontal"
  1532. def __init__(self, data, kind=None, **kwargs) -> None:
  1533. data = data.fillna(value=0)
  1534. if (data < 0).any().any():
  1535. raise ValueError(f"{self._kind} plot doesn't allow negative values")
  1536. MPLPlot.__init__(self, data, kind=kind, **kwargs)
  1537. def _args_adjust(self) -> None:
  1538. self.grid = False
  1539. self.logy = False
  1540. self.logx = False
  1541. self.loglog = False
  1542. def _validate_color_args(self) -> None:
  1543. pass
  1544. def _make_plot(self) -> None:
  1545. colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
  1546. self.kwds.setdefault("colors", colors)
  1547. for i, (label, y) in enumerate(self._iter_data()):
  1548. ax = self._get_ax(i)
  1549. if label is not None:
  1550. label = pprint_thing(label)
  1551. ax.set_ylabel(label)
  1552. kwds = self.kwds.copy()
  1553. def blank_labeler(label, value):
  1554. if value == 0:
  1555. return ""
  1556. else:
  1557. return label
  1558. idx = [pprint_thing(v) for v in self.data.index]
  1559. labels = kwds.pop("labels", idx)
  1560. # labels is used for each wedge's labels
  1561. # Blank out labels for values of 0 so they don't overlap
  1562. # with nonzero wedges
  1563. if labels is not None:
  1564. blabels = [blank_labeler(left, value) for left, value in zip(labels, y)]
  1565. else:
  1566. blabels = None
  1567. results = ax.pie(y, labels=blabels, **kwds)
  1568. if kwds.get("autopct", None) is not None:
  1569. patches, texts, autotexts = results
  1570. else:
  1571. patches, texts = results
  1572. autotexts = []
  1573. if self.fontsize is not None:
  1574. for t in texts + autotexts:
  1575. t.set_fontsize(self.fontsize)
  1576. # leglabels is used for legend labels
  1577. leglabels = labels if labels is not None else idx
  1578. for _patch, _leglabel in zip(patches, leglabels):
  1579. self._append_legend_handles_labels(_patch, _leglabel)
  1580. def _post_plot_logic(self, ax: Axes, data) -> None:
  1581. pass