12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877 |
- from __future__ import annotations
- from abc import (
- ABC,
- abstractmethod,
- )
- from typing import (
- TYPE_CHECKING,
- Hashable,
- Iterable,
- Literal,
- Sequence,
- )
- import warnings
- import matplotlib as mpl
- from matplotlib.artist import Artist
- import numpy as np
- from pandas._typing import (
- IndexLabel,
- PlottingOrientation,
- npt,
- )
- from pandas.errors import AbstractMethodError
- from pandas.util._decorators import cache_readonly
- from pandas.util._exceptions import find_stack_level
- from pandas.core.dtypes.common import (
- is_any_real_numeric_dtype,
- is_categorical_dtype,
- is_extension_array_dtype,
- is_float,
- is_float_dtype,
- is_hashable,
- is_integer,
- is_integer_dtype,
- is_iterator,
- is_list_like,
- is_number,
- is_numeric_dtype,
- )
- from pandas.core.dtypes.generic import (
- ABCDataFrame,
- ABCIndex,
- ABCMultiIndex,
- ABCPeriodIndex,
- ABCSeries,
- )
- from pandas.core.dtypes.missing import (
- isna,
- notna,
- )
- import pandas.core.common as com
- from pandas.core.frame import DataFrame
- from pandas.util.version import Version
- from pandas.io.formats.printing import pprint_thing
- from pandas.plotting._matplotlib import tools
- from pandas.plotting._matplotlib.converter import register_pandas_matplotlib_converters
- from pandas.plotting._matplotlib.groupby import reconstruct_data_with_by
- from pandas.plotting._matplotlib.misc import unpack_single_str_list
- from pandas.plotting._matplotlib.style import get_standard_colors
- from pandas.plotting._matplotlib.timeseries import (
- decorate_axes,
- format_dateaxis,
- maybe_convert_index,
- maybe_resample,
- use_dynamic_x,
- )
- from pandas.plotting._matplotlib.tools import (
- create_subplots,
- flatten_axes,
- format_date_labels,
- get_all_lines,
- get_xlim,
- handle_shared_axes,
- )
- if TYPE_CHECKING:
- from matplotlib.axes import Axes
- from matplotlib.axis import Axis
- def _color_in_style(style: str) -> bool:
- """
- Check if there is a color letter in the style string.
- """
- from matplotlib.colors import BASE_COLORS
- return not set(BASE_COLORS).isdisjoint(style)
class MPLPlot(ABC):
    """
    Base class for assembling a pandas plot using matplotlib

    Concrete plot classes provide ``_kind`` and implement ``_args_adjust``,
    ``_make_plot`` and ``_post_plot_logic``. ``generate()`` runs the whole
    pipeline: argument adjustment, data preparation, axes setup, drawing,
    optional table, legend and axes decoration.

    Parameters
    ----------
    data : Series or DataFrame
        The object being plotted. The remaining constructor arguments are
        listed in the ``__init__`` signature.
    """

    @property
    @abstractmethod
    def _kind(self) -> str:
        """Specify kind str. Must be overridden in child class"""
        raise NotImplementedError

    # Passed to ``create_subplots`` as ``layout_type`` in ``_setup_subplots``.
    _layout_type = "vertical"
    # Tick-label rotation used by ``__init__`` when the caller passes no ``rot``.
    _default_rot = 0

    @property
    def orientation(self) -> str | None:
        # ``_post_plot_logic_common`` treats ``None`` like "vertical" and also
        # understands "horizontal"; any other value makes it raise ValueError.
        return None

    # Flattened array of the Axes being drawn on. ``__init__`` sets an empty
    # object array; ``_setup_subplots`` replaces it with the real axes.
    axes: np.ndarray  # of Axes objects
    def __init__(
        self,
        data,
        kind=None,
        by: IndexLabel | None = None,
        subplots: bool | Sequence[Sequence[str]] = False,
        sharex=None,
        sharey: bool = False,
        use_index: bool = True,
        figsize=None,
        grid=None,
        legend: bool | str = True,
        rot=None,
        ax=None,
        fig=None,
        title=None,
        xlim=None,
        ylim=None,
        xticks=None,
        yticks=None,
        xlabel: Hashable | None = None,
        ylabel: Hashable | None = None,
        fontsize=None,
        secondary_y: bool | tuple | list | np.ndarray = False,
        colormap=None,
        table: bool = False,
        layout=None,
        include_bool: bool = False,
        column: IndexLabel | None = None,
        **kwds,
    ) -> None:
        """
        Store and validate the plotting options.

        Parameters
        ----------
        data : Series or DataFrame
            The data to plot.
        by : label or list of labels, optional
            Grouping column(s). An empty list or tuple is rejected.
        subplots : bool or sequence of sequences of labels
            Normalised by ``_validate_subplots_kwarg``.
        sharex : bool, optional
            When None, defaults to True only if neither ``ax`` nor ``by`` is
            given.
        rot : number, optional
            Tick-label rotation; falls back to ``_default_rot``.
        grid : bool, optional
            When None, False if ``secondary_y`` is truthy, otherwise
            matplotlib's ``axes.grid`` rcParam.
        secondary_y : bool or list-like, optional
            Any other scalar is wrapped into a one-element list.
        column : label or list of labels, optional
            Columns to plot for a DataFrame. When omitted, every numeric
            column not listed in ``by`` is used.
        **kwds
            ``logx``, ``logy``, ``loglog``, ``label``, ``style``,
            ``mark_right``, ``stacked``, ``xerr``, ``yerr`` and ``cmap`` are
            popped here; everything left is stored in ``self.kwds``.

        Raises
        ------
        ValueError
            If ``by`` is an empty list or tuple. ``_validate_subplots_kwarg``
            and ``_validate_color_args`` may also raise.
        TypeError
            If both ``cmap`` (in ``kwds``) and ``colormap`` are given.
        """
        import matplotlib.pyplot as plt

        self.data = data

        # if users assign an empty list or tuple, raise `ValueError`
        # similar to current `df.box` and `df.hist` APIs.
        if by in ([], ()):
            raise ValueError("No group keys passed!")
        self.by = com.maybe_make_list(by)

        # Assign the rest of columns into self.columns if by is explicitly defined
        # while column is not, only need `columns` in hist/box plot when it's DF
        # TODO: Might deprecate `column` argument in future PR (#28373)
        if isinstance(data, DataFrame):
            if column:
                self.columns = com.maybe_make_list(column)
            else:
                if self.by is None:
                    self.columns = [
                        col for col in data.columns if is_numeric_dtype(data[col])
                    ]
                else:
                    self.columns = [
                        col
                        for col in data.columns
                        if col not in self.by and is_numeric_dtype(data[col])
                    ]

        # For `hist` plot, need to get grouped original data before `self.data` is
        # updated later
        if self.by is not None and self._kind == "hist":
            self._grouped = data.groupby(unpack_single_str_list(self.by))

        self.kind = kind

        # Needs self.data to be set already (it inspects the columns).
        self.subplots = self._validate_subplots_kwarg(subplots)

        if sharex is None:
            # if by is defined, subplots are used and sharex should be False
            if ax is None and by is None:
                self.sharex = True
            else:
                # if we get an axis, the users should do the visibility
                # setting...
                self.sharex = False
        else:
            self.sharex = sharex

        self.sharey = sharey
        self.figsize = figsize
        self.layout = layout

        self.xticks = xticks
        self.yticks = yticks
        self.xlim = xlim
        self.ylim = ylim
        self.title = title
        self.use_index = use_index
        self.xlabel = xlabel
        self.ylabel = ylabel

        self.fontsize = fontsize

        if rot is not None:
            self.rot = rot
            # need to know for format_date_labels since it's rotated to 30 by
            # default
            self._rot_set = True
        else:
            self._rot_set = False
            self.rot = self._default_rot

        if grid is None:
            grid = False if secondary_y else plt.rcParams["axes.grid"]

        self.grid = grid
        self.legend = legend
        self.legend_handles: list[Artist] = []
        self.legend_labels: list[Hashable] = []

        # Options consumed here so they are not forwarded to matplotlib
        # through self.kwds.
        self.logx = kwds.pop("logx", False)
        self.logy = kwds.pop("logy", False)
        self.loglog = kwds.pop("loglog", False)
        self.label = kwds.pop("label", None)
        self.style = kwds.pop("style", None)
        self.mark_right = kwds.pop("mark_right", True)
        self.stacked = kwds.pop("stacked", False)

        self.ax = ax
        self.fig = fig
        self.axes = np.array([], dtype=object)  # "real" version get set in `generate`

        # parse errorbar input if given
        xerr = kwds.pop("xerr", None)
        yerr = kwds.pop("yerr", None)
        self.errors = {
            kw: self._parse_errorbars(kw, err)
            for kw, err in zip(["xerr", "yerr"], [xerr, yerr])
        }

        # A single non-list-like label becomes a one-element list.
        if not isinstance(secondary_y, (bool, tuple, list, np.ndarray, ABCIndex)):
            secondary_y = [secondary_y]
        self.secondary_y = secondary_y

        # ugly TypeError if user passes matplotlib's `cmap` name.
        # Probably better to accept either.
        if "cmap" in kwds and colormap:
            raise TypeError("Only specify one of `cmap` and `colormap`.")
        if "cmap" in kwds:
            self.colormap = kwds.pop("cmap")
        else:
            self.colormap = colormap

        self.table = table
        self.include_bool = include_bool

        self.kwds = kwds

        # Must run last: reads self.kwds, self.nseries, self.colormap and
        # self.style.
        self._validate_color_args()
- def _validate_subplots_kwarg(
- self, subplots: bool | Sequence[Sequence[str]]
- ) -> bool | list[tuple[int, ...]]:
- """
- Validate the subplots parameter
- - check type and content
- - check for duplicate columns
- - check for invalid column names
- - convert column names into indices
- - add missing columns in a group of their own
- See comments in code below for more details.
- Parameters
- ----------
- subplots : subplots parameters as passed to PlotAccessor
- Returns
- -------
- validated subplots : a bool or a list of tuples of column indices. Columns
- in the same tuple will be grouped together in the resulting plot.
- """
- if isinstance(subplots, bool):
- return subplots
- elif not isinstance(subplots, Iterable):
- raise ValueError("subplots should be a bool or an iterable")
- supported_kinds = (
- "line",
- "bar",
- "barh",
- "hist",
- "kde",
- "density",
- "area",
- "pie",
- )
- if self._kind not in supported_kinds:
- raise ValueError(
- "When subplots is an iterable, kind must be "
- f"one of {', '.join(supported_kinds)}. Got {self._kind}."
- )
- if isinstance(self.data, ABCSeries):
- raise NotImplementedError(
- "An iterable subplots for a Series is not supported."
- )
- columns = self.data.columns
- if isinstance(columns, ABCMultiIndex):
- raise NotImplementedError(
- "An iterable subplots for a DataFrame with a MultiIndex column "
- "is not supported."
- )
- if columns.nunique() != len(columns):
- raise NotImplementedError(
- "An iterable subplots for a DataFrame with non-unique column "
- "labels is not supported."
- )
- # subplots is a list of tuples where each tuple is a group of
- # columns to be grouped together (one ax per group).
- # we consolidate the subplots list such that:
- # - the tuples contain indices instead of column names
- # - the columns that aren't yet in the list are added in a group
- # of their own.
- # For example with columns from a to g, and
- # subplots = [(a, c), (b, f, e)],
- # we end up with [(ai, ci), (bi, fi, ei), (di,), (gi,)]
- # This way, we can handle self.subplots in a homogeneous manner
- # later.
- # TODO: also accept indices instead of just names?
- out = []
- seen_columns: set[Hashable] = set()
- for group in subplots:
- if not is_list_like(group):
- raise ValueError(
- "When subplots is an iterable, each entry "
- "should be a list/tuple of column names."
- )
- idx_locs = columns.get_indexer_for(group)
- if (idx_locs == -1).any():
- bad_labels = np.extract(idx_locs == -1, group)
- raise ValueError(
- f"Column label(s) {list(bad_labels)} not found in the DataFrame."
- )
- unique_columns = set(group)
- duplicates = seen_columns.intersection(unique_columns)
- if duplicates:
- raise ValueError(
- "Each column should be in only one subplot. "
- f"Columns {duplicates} were found in multiple subplots."
- )
- seen_columns = seen_columns.union(unique_columns)
- out.append(tuple(idx_locs))
- unseen_columns = columns.difference(seen_columns)
- for column in unseen_columns:
- idx_loc = columns.get_loc(column)
- out.append((idx_loc,))
- return out
- def _validate_color_args(self):
- if (
- "color" in self.kwds
- and self.nseries == 1
- and not is_list_like(self.kwds["color"])
- ):
- # support series.plot(color='green')
- self.kwds["color"] = [self.kwds["color"]]
- if (
- "color" in self.kwds
- and isinstance(self.kwds["color"], tuple)
- and self.nseries == 1
- and len(self.kwds["color"]) in (3, 4)
- ):
- # support RGB and RGBA tuples in series plot
- self.kwds["color"] = [self.kwds["color"]]
- if (
- "color" in self.kwds or "colors" in self.kwds
- ) and self.colormap is not None:
- warnings.warn(
- "'color' and 'colormap' cannot be used simultaneously. Using 'color'",
- stacklevel=find_stack_level(),
- )
- if "color" in self.kwds and self.style is not None:
- if is_list_like(self.style):
- styles = self.style
- else:
- styles = [self.style]
- # need only a single match
- for s in styles:
- if _color_in_style(s):
- raise ValueError(
- "Cannot pass 'style' string with a color symbol and "
- "'color' keyword argument. Please use one or the "
- "other or pass 'style' without a color symbol"
- )
- def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
- if data is None:
- data = self.data
- if fillna is not None:
- data = data.fillna(fillna)
- for col, values in data.items():
- if keep_index is True:
- yield col, values
- else:
- yield col, values.values
- @property
- def nseries(self) -> int:
- # When `by` is explicitly assigned, grouped data size will be defined, and
- # this will determine number of subplots to have, aka `self.nseries`
- if self.data.ndim == 1:
- return 1
- elif self.by is not None and self._kind == "hist":
- return len(self._grouped)
- elif self.by is not None and self._kind == "box":
- return len(self.columns)
- else:
- return self.data.shape[1]
- def draw(self) -> None:
- self.plt.draw_if_interactive()
    def generate(self) -> None:
        """
        Build the complete plot.

        The steps run in a fixed order:
        1. adjust arguments;
        2. compute the data to plot;
        3. create the axes;
        4. draw;
        5. add the optional table;
        6. build the legend;
        7. decorate the axes;
        8. run the per-axes post-processing hooks.
        """
        self._args_adjust()
        self._compute_plot_data()
        self._setup_subplots()
        self._make_plot()
        self._add_table()
        self._make_legend()
        self._adorn_subplots()

        # Per-axes hooks: shared logic first, then the subclass-specific one.
        for ax in self.axes:
            self._post_plot_logic_common(ax, self.data)
            self._post_plot_logic(ax, self.data)
    @abstractmethod
    def _args_adjust(self) -> None:
        """
        Abstract hook called first by ``generate()``.

        Subclasses must implement it. Judging by its name and position it
        adjusts plotting arguments before any data is computed (confirm in
        the subclasses).
        """
        pass
- def _has_plotted_object(self, ax: Axes) -> bool:
- """check whether ax has data"""
- return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0
    def _maybe_right_yaxis(self, ax: Axes, axes_num):
        """
        Return the axes a series should be drawn on.

        If ``self.on_right(axes_num)`` is falsy, the primary (left) layer of
        ``ax`` is returned. Otherwise the secondary (right) axes is returned.
        It is created with ``twinx()`` on first use, and the pair is linked
        through ``right_ax`` / ``left_ax`` attributes.
        """
        if not self.on_right(axes_num):
            # secondary axes may be passed via ax kw
            return self._get_ax_layer(ax)

        if hasattr(ax, "right_ax"):
            # if it has right_ax property, ``ax`` must be left axes
            return ax.right_ax
        elif hasattr(ax, "left_ax"):
            # if it has left_ax property, ``ax`` must be right axes
            return ax
        else:
            # otherwise, create twin axes
            orig_ax, new_ax = ax, ax.twinx()
            # TODO: use Matplotlib public API when available
            # Share matplotlib's private line/patch property generators between
            # the twins -- presumably so series on the right axes continue the
            # left axes' colour cycle instead of restarting it.
            new_ax._get_lines = orig_ax._get_lines
            new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
            orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax

            if not self._has_plotted_object(orig_ax):  # no data on left y
                orig_ax.get_yaxis().set_visible(False)

            # Log settings are applied to the freshly created twin's y-axis.
            if self.logy is True or self.loglog is True:
                new_ax.set_yscale("log")
            elif self.logy == "sym" or self.loglog == "sym":
                new_ax.set_yscale("symlog")
            return new_ax
- def _setup_subplots(self):
- if self.subplots:
- naxes = (
- self.nseries if isinstance(self.subplots, bool) else len(self.subplots)
- )
- fig, axes = create_subplots(
- naxes=naxes,
- sharex=self.sharex,
- sharey=self.sharey,
- figsize=self.figsize,
- ax=self.ax,
- layout=self.layout,
- layout_type=self._layout_type,
- )
- else:
- if self.ax is None:
- fig = self.plt.figure(figsize=self.figsize)
- axes = fig.add_subplot(111)
- else:
- fig = self.ax.get_figure()
- if self.figsize is not None:
- fig.set_size_inches(self.figsize)
- axes = self.ax
- axes = flatten_axes(axes)
- valid_log = {False, True, "sym", None}
- input_log = {self.logx, self.logy, self.loglog}
- if input_log - valid_log:
- invalid_log = next(iter(input_log - valid_log))
- raise ValueError(
- f"Boolean, None and 'sym' are valid options, '{invalid_log}' is given."
- )
- if self.logx is True or self.loglog is True:
- [a.set_xscale("log") for a in axes]
- elif self.logx == "sym" or self.loglog == "sym":
- [a.set_xscale("symlog") for a in axes]
- if self.logy is True or self.loglog is True:
- [a.set_yscale("log") for a in axes]
- elif self.logy == "sym" or self.loglog == "sym":
- [a.set_yscale("symlog") for a in axes]
- self.fig = fig
- self.axes = axes
- @property
- def result(self):
- """
- Return result axes
- """
- if self.subplots:
- if self.layout is not None and not is_list_like(self.ax):
- return self.axes.reshape(*self.layout)
- else:
- return self.axes
- else:
- sec_true = isinstance(self.secondary_y, bool) and self.secondary_y
- # error: Argument 1 to "len" has incompatible type "Union[bool,
- # Tuple[Any, ...], List[Any], ndarray[Any, Any]]"; expected "Sized"
- all_sec = (
- is_list_like(self.secondary_y)
- and len(self.secondary_y) == self.nseries # type: ignore[arg-type]
- )
- if sec_true or all_sec:
- # if all data is plotted on secondary, return right axes
- return self._get_ax_layer(self.axes[0], primary=False)
- else:
- return self.axes[0]
- def _convert_to_ndarray(self, data):
- # GH31357: categorical columns are processed separately
- if is_categorical_dtype(data):
- return data
- # GH32073: cast to float if values contain nulled integers
- if (
- is_integer_dtype(data.dtype) or is_float_dtype(data.dtype)
- ) and is_extension_array_dtype(data.dtype):
- return data.to_numpy(dtype="float", na_value=np.nan)
- # GH25587: cast ExtensionArray of pandas (IntegerArray, etc.) to
- # np.ndarray before plot.
- if len(data) > 0:
- return np.asarray(data)
- return data
    def _compute_plot_data(self):
        """
        Reduce ``self.data`` to the DataFrame that will actually be plotted.

        - A Series becomes a one-column DataFrame named after ``label``, or
          ``""`` when neither ``label`` nor the Series name is set.
        - For ``hist``/``box`` only ``self.columns`` (plus any ``by`` columns)
          are kept.
        - When ``by`` is given, the data is rebuilt with
          ``reconstruct_data_with_by`` and ``self.subplots`` is forced to
          True.
        - Columns are then filtered by dtype for the plot kind, and each one
          is passed through ``_convert_to_ndarray``. The result replaces
          ``self.data``.

        Raises
        ------
        TypeError
            If no column of a plottable dtype remains.
        """
        data = self.data

        if isinstance(data, ABCSeries):
            label = self.label
            if label is None and data.name is None:
                label = ""
            if label is None:
                # We'll end up with columns of [0] instead of [None]
                data = data.to_frame()
            else:
                data = data.to_frame(name=label)
        elif self._kind in ("hist", "box"):
            cols = self.columns if self.by is None else self.columns + self.by
            data = data.loc[:, cols]

        # GH15079 reconstruct data if by is defined
        # (note: this reads self.data, not the locally narrowed ``data``)
        if self.by is not None:
            self.subplots = True
            data = reconstruct_data_with_by(self.data, by=self.by, cols=self.columns)

        # GH16953, infer_objects is needed as fallback, for ``Series``
        # with ``dtype == object``
        data = data.infer_objects(copy=False)
        include_type = [np.number, "datetime", "datetimetz", "timedelta"]

        # GH23719, allow plotting boolean
        if self.include_bool is True:
            include_type.append(np.bool_)

        # GH22799, exclude datetime-like type for boxplot
        exclude_type = None
        if self._kind == "box":
            # TODO: change after solving issue 27881
            include_type = [np.number]
            exclude_type = ["timedelta"]

        # GH 18755, include object and category type for scatter plot
        if self._kind == "scatter":
            include_type.extend(["object", "category"])

        numeric_data = data.select_dtypes(include=include_type, exclude=exclude_type)

        # Fallback for objects without ``.columns``: emptiness by length.
        try:
            is_empty = numeric_data.columns.empty
        except AttributeError:
            is_empty = not len(numeric_data)

        # no non-numeric frames or series allowed
        if is_empty:
            raise TypeError("no numeric data to plot")

        self.data = numeric_data.apply(self._convert_to_ndarray)
    def _make_plot(self):
        """
        Kind-specific drawing step, called by ``generate()``.

        The base class raises ``AbstractMethodError``, so subclasses must
        override it.
        """
        raise AbstractMethodError(self)
- def _add_table(self) -> None:
- if self.table is False:
- return
- elif self.table is True:
- data = self.data.transpose()
- else:
- data = self.table
- ax = self._get_ax(0)
- tools.table(ax, data)
- def _post_plot_logic_common(self, ax, data):
- """Common post process for each axes"""
- if self.orientation == "vertical" or self.orientation is None:
- self._apply_axis_properties(ax.xaxis, rot=self.rot, fontsize=self.fontsize)
- self._apply_axis_properties(ax.yaxis, fontsize=self.fontsize)
- if hasattr(ax, "right_ax"):
- self._apply_axis_properties(ax.right_ax.yaxis, fontsize=self.fontsize)
- elif self.orientation == "horizontal":
- self._apply_axis_properties(ax.yaxis, rot=self.rot, fontsize=self.fontsize)
- self._apply_axis_properties(ax.xaxis, fontsize=self.fontsize)
- if hasattr(ax, "right_ax"):
- self._apply_axis_properties(ax.right_ax.yaxis, fontsize=self.fontsize)
- else: # pragma no cover
- raise ValueError
    @abstractmethod
    def _post_plot_logic(self, ax, data) -> None:
        """Post process for each axes. Overridden in child classes

        ``generate()`` calls this once per entry of ``self.axes`` with
        ``self.data``, right after ``_post_plot_logic_common``.
        """
- def _adorn_subplots(self):
- """Common post process unrelated to data"""
- if len(self.axes) > 0:
- all_axes = self._get_subplots()
- nrows, ncols = self._get_axes_layout()
- handle_shared_axes(
- axarr=all_axes,
- nplots=len(all_axes),
- naxes=nrows * ncols,
- nrows=nrows,
- ncols=ncols,
- sharex=self.sharex,
- sharey=self.sharey,
- )
- for ax in self.axes:
- ax = getattr(ax, "right_ax", ax)
- if self.yticks is not None:
- ax.set_yticks(self.yticks)
- if self.xticks is not None:
- ax.set_xticks(self.xticks)
- if self.ylim is not None:
- ax.set_ylim(self.ylim)
- if self.xlim is not None:
- ax.set_xlim(self.xlim)
- # GH9093, currently Pandas does not show ylabel, so if users provide
- # ylabel will set it as ylabel in the plot.
- if self.ylabel is not None:
- ax.set_ylabel(pprint_thing(self.ylabel))
- ax.grid(self.grid)
- if self.title:
- if self.subplots:
- if is_list_like(self.title):
- if len(self.title) != self.nseries:
- raise ValueError(
- "The length of `title` must equal the number "
- "of columns if using `title` of type `list` "
- "and `subplots=True`.\n"
- f"length of title = {len(self.title)}\n"
- f"number of columns = {self.nseries}"
- )
- for ax, title in zip(self.axes, self.title):
- ax.set_title(title)
- else:
- self.fig.suptitle(self.title)
- else:
- if is_list_like(self.title):
- msg = (
- "Using `title` of type `list` is not supported "
- "unless `subplots=True` is passed"
- )
- raise ValueError(msg)
- self.axes[0].set_title(self.title)
- def _apply_axis_properties(self, axis: Axis, rot=None, fontsize=None) -> None:
- """
- Tick creation within matplotlib is reasonably expensive and is
- internally deferred until accessed as Ticks are created/destroyed
- multiple times per draw. It's therefore beneficial for us to avoid
- accessing unless we will act on the Tick.
- """
- if rot is not None or fontsize is not None:
- # rot=0 is a valid setting, hence the explicit None check
- labels = axis.get_majorticklabels() + axis.get_minorticklabels()
- for label in labels:
- if rot is not None:
- label.set_rotation(rot)
- if fontsize is not None:
- label.set_fontsize(fontsize)
- @property
- def legend_title(self) -> str | None:
- if not isinstance(self.data.columns, ABCMultiIndex):
- name = self.data.columns.name
- if name is not None:
- name = pprint_thing(name)
- return name
- else:
- stringified = map(pprint_thing, self.data.columns.names)
- return ",".join(stringified)
- def _mark_right_label(self, label: str, index: int) -> str:
- """
- Append ``(right)`` to the label of a line if it's plotted on the right axis.
- Note that ``(right)`` is only appended when ``subplots=False``.
- """
- if not self.subplots and self.mark_right and self.on_right(index):
- label += " (right)"
- return label
- def _append_legend_handles_labels(self, handle: Artist, label: str) -> None:
- """
- Append current handle and label to ``legend_handles`` and ``legend_labels``.
- These will be used to make the legend.
- """
- self.legend_handles.append(handle)
- self.legend_labels.append(label)
- def _make_legend(self) -> None:
- ax, leg = self._get_ax_legend(self.axes[0])
- handles = []
- labels = []
- title = ""
- if not self.subplots:
- if leg is not None:
- title = leg.get_title().get_text()
- # Replace leg.legend_handles because it misses marker info
- if Version(mpl.__version__) < Version("3.7"):
- handles = leg.legendHandles
- else:
- handles = leg.legend_handles
- labels = [x.get_text() for x in leg.get_texts()]
- if self.legend:
- if self.legend == "reverse":
- handles += reversed(self.legend_handles)
- labels += reversed(self.legend_labels)
- else:
- handles += self.legend_handles
- labels += self.legend_labels
- if self.legend_title is not None:
- title = self.legend_title
- if len(handles) > 0:
- ax.legend(handles, labels, loc="best", title=title)
- elif self.subplots and self.legend:
- for ax in self.axes:
- if ax.get_visible():
- ax.legend(loc="best")
- def _get_ax_legend(self, ax: Axes):
- """
- Take in axes and return ax and legend under different scenarios
- """
- leg = ax.get_legend()
- other_ax = getattr(ax, "left_ax", None) or getattr(ax, "right_ax", None)
- other_leg = None
- if other_ax is not None:
- other_leg = other_ax.get_legend()
- if leg is None and other_leg is not None:
- leg = other_leg
- ax = other_ax
- return ax, leg
- @cache_readonly
- def plt(self):
- import matplotlib.pyplot as plt
- return plt
- _need_to_set_index = False
- def _get_xticks(self, convert_period: bool = False):
- index = self.data.index
- is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
- if self.use_index:
- if convert_period and isinstance(index, ABCPeriodIndex):
- self.data = self.data.reindex(index=index.sort_values())
- x = self.data.index.to_timestamp()._mpl_repr()
- elif is_any_real_numeric_dtype(index):
- # Matplotlib supports numeric values or datetime objects as
- # xaxis values. Taking LBYL approach here, by the time
- # matplotlib raises exception when using non numeric/datetime
- # values for xaxis, several actions are already taken by plt.
- x = index._mpl_repr()
- elif is_datetype:
- self.data = self.data[notna(self.data.index)]
- self.data = self.data.sort_index()
- x = self.data.index._mpl_repr()
- else:
- self._need_to_set_index = True
- x = list(range(len(index)))
- else:
- x = list(range(len(index)))
- return x
- @classmethod
- @register_pandas_matplotlib_converters
- def _plot(
- cls, ax: Axes, x, y: np.ndarray, style=None, is_errorbar: bool = False, **kwds
- ):
- mask = isna(y)
- if mask.any():
- y = np.ma.array(y)
- y = np.ma.masked_where(mask, y)
- if isinstance(x, ABCIndex):
- x = x._mpl_repr()
- if is_errorbar:
- if "xerr" in kwds:
- kwds["xerr"] = np.array(kwds.get("xerr"))
- if "yerr" in kwds:
- kwds["yerr"] = np.array(kwds.get("yerr"))
- return ax.errorbar(x, y, **kwds)
- else:
- # prevent style kwarg from going to errorbar, where it is unsupported
- args = (x, y, style) if style is not None else (x, y)
- return ax.plot(*args, **kwds)
- def _get_custom_index_name(self):
- """Specify whether xlabel/ylabel should be used to override index name"""
- return self.xlabel
- def _get_index_name(self) -> str | None:
- if isinstance(self.data.index, ABCMultiIndex):
- name = self.data.index.names
- if com.any_not_none(*name):
- name = ",".join([pprint_thing(x) for x in name])
- else:
- name = None
- else:
- name = self.data.index.name
- if name is not None:
- name = pprint_thing(name)
- # GH 45145, override the default axis label if one is provided.
- index_name = self._get_custom_index_name()
- if index_name is not None:
- name = pprint_thing(index_name)
- return name
- @classmethod
- def _get_ax_layer(cls, ax, primary: bool = True):
- """get left (primary) or right (secondary) axes"""
- if primary:
- return getattr(ax, "left_ax", ax)
- else:
- return getattr(ax, "right_ax", ax)
- def _col_idx_to_axis_idx(self, col_idx: int) -> int:
- """Return the index of the axis where the column at col_idx should be plotted"""
- if isinstance(self.subplots, list):
- # Subplots is a list: some columns will be grouped together in the same ax
- return next(
- group_idx
- for (group_idx, group) in enumerate(self.subplots)
- if col_idx in group
- )
- else:
- # subplots is True: one ax per column
- return col_idx
- def _get_ax(self, i: int):
- # get the twinx ax if appropriate
- if self.subplots:
- i = self._col_idx_to_axis_idx(i)
- ax = self.axes[i]
- ax = self._maybe_right_yaxis(ax, i)
- self.axes[i] = ax
- else:
- ax = self.axes[0]
- ax = self._maybe_right_yaxis(ax, i)
- ax.get_yaxis().set_visible(True)
- return ax
@classmethod
def get_default_ax(cls, ax) -> None:
    """
    Resolve a default axes when ``ax`` is None and a figure is open.

    NOTE(review): the resolved axes is only bound to a local, and the method
    returns None. The only lasting effect is therefore whatever ``plt.gca()``
    does to the current figure. Confirm whether any caller relies on this.
    """
    import matplotlib.pyplot as plt

    if ax is not None or not plt.get_fignums():
        return
    with plt.rc_context():
        ax = plt.gca()
    ax = cls._get_ax_layer(ax)
def on_right(self, i):
    """
    Report whether column ``i`` belongs on the secondary (right) y-axis.

    A boolean ``secondary_y`` applies to every column. A tuple, list,
    ndarray or Index names the columns to move to the right. Any other value
    yields None.
    """
    secondary = self.secondary_y
    if isinstance(secondary, bool):
        return secondary
    if not isinstance(secondary, (tuple, list, np.ndarray, ABCIndex)):
        return None
    return self.data.columns[i] in secondary
- def _apply_style_colors(self, colors, kwds, col_num, label):
- """
- Manage style and color based on column number and its label.
- Returns tuple of appropriate style and kwds which "color" may be added.
- """
- style = None
- if self.style is not None:
- if isinstance(self.style, list):
- try:
- style = self.style[col_num]
- except IndexError:
- pass
- elif isinstance(self.style, dict):
- style = self.style.get(label, style)
- else:
- style = self.style
- has_color = "color" in kwds or self.colormap is not None
- nocolor_style = style is None or not _color_in_style(style)
- if (has_color or self.subplots) and nocolor_style:
- if isinstance(colors, dict):
- kwds["color"] = colors[label]
- else:
- kwds["color"] = colors[col_num % len(colors)]
- return style, kwds
def _get_colors(
    self,
    num_colors: int | None = None,
    color_kwds: str = "color",
):
    """
    Return the standard colour sequence for this plot.

    Parameters
    ----------
    num_colors : int, optional
        Number of colours wanted; defaults to ``self.nseries``.
    color_kwds : str, default "color"
        Key in ``self.kwds`` holding user-specified colours (if any).
    """
    count = self.nseries if num_colors is None else num_colors
    user_colors = self.kwds.get(color_kwds)
    return get_standard_colors(
        num_colors=count,
        colormap=self.colormap,
        color=user_colors,
    )
def _parse_errorbars(self, label, err):
    """
    Look for error keyword arguments and return the actual errorbar data
    or return the error DataFrame/dict

    Error bars can be specified in several ways:
        Series: the user provides a pandas.Series object of the same
                length as the data
        ndarray: provides a np.ndarray of the same length as the data
        DataFrame/dict: error values are paired with keys matching the
                key in the plotted DataFrame
        str: the name of the column within the plotted DataFrame

    Asymmetrical error bars are also supported, however raw error values
    must be provided in this case. For a ``N`` length :class:`Series`, a
    ``2xN`` array should be provided indicating lower and upper (or left
    and right) errors. For a ``MxN`` :class:`DataFrame`, asymmetrical errors
    should be in a ``Mx2xN`` array.

    Parameters
    ----------
    label : str
        Name of the keyword being parsed (e.g. "xerr"/"yerr"); used only in
        the error message.
    err : object
        The user-supplied error specification (see above).

    Returns
    -------
    None, DataFrame, dict or ndarray
        None when ``err`` is None. A DataFrame is reindexed to the data's
        index, and a dict is returned as given. Every other accepted form
        becomes an ndarray with at least two dimensions, tiled to one row
        per series where broadcasting applies.

    Raises
    ------
    ValueError
        If asymmetrical error bars have the wrong shape, or ``err`` is of an
        unsupported type.

    Notes
    -----
    When ``err`` is a column name, that column is removed from ``self.data``
    (it is consumed as error values rather than plotted).
    """
    if err is None:
        return None

    def match_labels(data, e):
        # align the error values with the rows of the plotted data
        e = e.reindex(data.index)
        return e

    # key-matched DataFrame
    if isinstance(err, ABCDataFrame):
        err = match_labels(self.data, err)
    # key-matched dict
    elif isinstance(err, dict):
        pass

    # Series of error values
    elif isinstance(err, ABCSeries):
        # broadcast error series across data
        err = match_labels(self.data, err)
        err = np.atleast_2d(err)
        err = np.tile(err, (self.nseries, 1))

    # errors are a column in the dataframe
    elif isinstance(err, str):
        evalues = self.data[err].values
        # drop the error column so it is not plotted as a series itself
        self.data = self.data[self.data.columns.drop(err)]
        err = np.atleast_2d(evalues)
        err = np.tile(err, (self.nseries, 1))

    elif is_list_like(err):
        if is_iterator(err):
            # materialise one-shot iterators before building the array
            err = np.atleast_2d(list(err))
        else:
            # raw error values
            err = np.atleast_2d(err)

        err_shape = err.shape

        # asymmetrical error bars
        if isinstance(self.data, ABCSeries) and err_shape[0] == 2:
            # a (2, N) array for a Series becomes (1, 2, N)
            err = np.expand_dims(err, 0)
            err_shape = err.shape
            if err_shape[2] != len(self.data):
                raise ValueError(
                    "Asymmetrical error bars should be provided "
                    f"with the shape (2, {len(self.data)})"
                )
        elif isinstance(self.data, ABCDataFrame) and err.ndim == 3:
            if (
                (err_shape[0] != self.nseries)
                or (err_shape[1] != 2)
                or (err_shape[2] != len(self.data))
            ):
                raise ValueError(
                    "Asymmetrical error bars should be provided "
                    f"with the shape ({self.nseries}, 2, {len(self.data)})"
                )

        # broadcast errors to each data series
        if len(err) == 1:
            err = np.tile(err, (self.nseries, 1))

    elif is_number(err):
        # a scalar error applies to every point of every series
        err = np.tile([err], (self.nseries, len(self.data)))

    else:
        msg = f"No valid {label} detected"
        raise ValueError(msg)

    return err
def _get_errorbars(
    self, label=None, index=None, xerr: bool = True, yerr: bool = True
):
    """
    Collect the error values for one series from ``self.errors``.

    Which value is used depends on the kind of error source:
    - DataFrame or dict: the entry is looked up by ``label``, and the source
      is ignored when ``label`` is absent or None;
    - other non-None sources: indexed positionally by ``index`` when
      ``index`` is given, otherwise used as-is.

    Returns
    -------
    dict
        Maps "xerr"/"yerr" to the selected error values. A key is omitted
        when its flag is False or no error value applies.
    """
    selected = {}
    for key, enabled in (("xerr", xerr), ("yerr", yerr)):
        if not enabled:
            continue
        source = self.errors[key]
        if isinstance(source, (ABCDataFrame, dict)):
            # user provided label-matched dataframe of errors
            found = label is not None and label in source.keys()
            value = source[label] if found else None
        elif index is not None and source is not None:
            value = source[index]
        else:
            value = source
        if value is not None:
            selected[key] = value
    return selected
def _get_subplots(self):
    """
    Return the axes of ``self.fig`` that were laid out as grid subplots.

    An axes qualifies when it is a matplotlib ``Axes`` and has a non-None
    subplot spec. Axes placed freely (e.g. via ``fig.add_axes``) are
    excluded.
    """
    # ``matplotlib.axes.Subplot`` is a legacy name (an alias of ``Axes``
    # since matplotlib 3.6) that is being phased out. Check ``Axes`` directly
    # and rely on the subplot spec to tell grid subplots apart.
    from matplotlib.axes import Axes

    subplots = []
    for ax in self.fig.get_axes():
        if not isinstance(ax, Axes):
            continue
        # Older matplotlib only defines ``get_subplotspec`` on subplot
        # classes, so probe for it rather than calling it unconditionally.
        get_spec = getattr(ax, "get_subplotspec", None)
        if get_spec is not None and get_spec() is not None:
            subplots.append(ax)
    return subplots
- def _get_axes_layout(self) -> tuple[int, int]:
- axes = self._get_subplots()
- x_set = set()
- y_set = set()
- for ax in axes:
- # check axes coordinates to estimate layout
- points = ax.get_position().get_points()
- x_set.add(points[0][0])
- y_set.add(points[0][1])
- return (len(y_set), len(x_set))
class PlanePlot(MPLPlot, ABC):
    """
    Abstract class for plotting on plane, currently scatter and hexbin.

    ``x`` and ``y`` name the columns plotted against each other. Integer
    positions are translated into column labels unless the columns
    themselves hold integers.
    """

    _layout_type = "single"

    def __init__(self, data, x, y, **kwargs) -> None:
        MPLPlot.__init__(self, data, **kwargs)
        if x is None or y is None:
            raise ValueError(self._kind + " requires an x and y column")
        columns = self.data.columns
        if is_integer(x) and not columns._holds_integer():
            x = columns[x]
        if is_integer(y) and not columns._holds_integer():
            y = columns[y]

        # Scatter plot allows to plot objects data; hexbin needs numbers.
        if self._kind == "hexbin":
            for axis_name, column in (("x", x), ("y", y)):
                if len(self.data[column]._get_numeric_data()) == 0:
                    raise ValueError(
                        f"{self._kind} requires {axis_name} column to be numeric"
                    )

        self.x = x
        self.y = y

    @property
    def nseries(self) -> int:
        # a plane plot always draws exactly one series
        return 1

    def _post_plot_logic(self, ax: Axes, data) -> None:
        # explicit xlabel/ylabel win over the column names
        ax.set_xlabel(pprint_thing(self.x) if self.xlabel is None else self.xlabel)
        ax.set_ylabel(pprint_thing(self.y) if self.ylabel is None else self.ylabel)

    def _plot_colorbar(self, ax: Axes, **kwds):
        # Addresses issues #10611 and #10678: in the IPython inline backend
        # the colorbar axis height tended not to exactly match the parent
        # axis height, due to small fractional floating-point differences.
        # Passing ``ax=ax`` makes the colorbar take the parent axes' height.
        # See https://github.com/ipython/ipython/issues/11215
        #
        # GH33389, if ax is used multiple times, always use the most recently
        # added collection, which carries the latest information about ax.
        mappable = ax.collections[-1]
        return self.fig.colorbar(mappable, ax=ax, **kwds)
class ScatterPlot(PlanePlot):
    """
    Scatter plot of column ``y`` against column ``x``.

    ``s`` (marker size) and ``c`` (marker colour) may each name a column of
    the data. ``c`` may also be a raw value or sequence.
    """

    @property
    def _kind(self) -> Literal["scatter"]:
        return "scatter"

    def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None:
        if s is None:
            # hide the matplotlib default for size, in case we want to change
            # the handling of this argument later
            s = 20
        elif is_hashable(s) and s in data.columns:
            # ``s`` names a column: use its values as marker sizes
            s = data[s]
        super().__init__(data, x, y, s=s, **kwargs)
        if is_integer(c) and not self.data.columns._holds_integer():
            # positional ``c`` is translated into the column label
            c = self.data.columns[c]
        self.c = c

    def _make_plot(self):
        x, y, c, data = self.x, self.y, self.c, self.data
        ax = self.axes[0]

        c_is_column = is_hashable(c) and c in self.data.columns

        color_by_categorical = c_is_column and is_categorical_dtype(self.data[c])

        # ``c`` and ``color`` are mutually exclusive ways of colouring markers
        color = self.kwds.pop("color", None)
        if c is not None and color is not None:
            raise TypeError("Specify exactly one of `c` and `color`")
        if c is None and color is None:
            c_values = self.plt.rcParams["patch.facecolor"]
        elif color is not None:
            c_values = color
        elif color_by_categorical:
            # categorical columns are coloured by their integer codes
            c_values = self.data[c].cat.codes
        elif c_is_column:
            c_values = self.data[c].values
        else:
            c_values = c

        if self.colormap is not None:
            cmap = mpl.colormaps.get_cmap(self.colormap)
        else:
            # cmap is only used if c_values are integers, otherwise UserWarning
            if is_integer_dtype(c_values):
                # pandas uses colormap, matplotlib uses cmap.
                cmap = "Greys"
                cmap = mpl.colormaps[cmap]
            else:
                cmap = None

        if color_by_categorical:
            from matplotlib import colors

            # one discrete colour band per category
            n_cats = len(self.data[c].cat.categories)
            cmap = colors.ListedColormap([cmap(i) for i in range(cmap.N)])
            bounds = np.linspace(0, n_cats, n_cats + 1)
            norm = colors.BoundaryNorm(bounds, cmap.N)
            # NOTE(review): a user-supplied ``norm`` is not popped on this
            # branch, so it stays in ``self.kwds`` and would reach
            # ``ax.scatter`` a second time alongside ``norm=norm`` below.
            # Confirm whether it should be popped/ignored here.
        else:
            norm = self.kwds.pop("norm", None)
        # plot colorbar if
        # 1. a colormap is assigned or `c` names a column, and
        # 2. the resulting colour values are numeric
        # (an explicit ``colorbar`` keyword overrides this default)
        plot_colorbar = self.colormap or c_is_column
        cb = self.kwds.pop("colorbar", is_numeric_dtype(c_values) and plot_colorbar)

        if self.legend and hasattr(self, "label"):
            label = self.label
        else:
            label = None
        scatter = ax.scatter(
            data[x].values,
            data[y].values,
            c=c_values,
            label=label,
            cmap=cmap,
            norm=norm,
            **self.kwds,
        )
        if cb:
            cbar_label = c if c_is_column else ""
            cbar = self._plot_colorbar(ax, label=cbar_label)
            if color_by_categorical:
                # centre one tick per category band and label it
                cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats))
                cbar.ax.set_yticklabels(self.data[c].cat.categories)

        if label is not None:
            self._append_legend_handles_labels(scatter, label)
        else:
            self.legend = False

        # error bars are looked up by the x/y column labels
        errors_x = self._get_errorbars(label=x, index=0, yerr=False)
        errors_y = self._get_errorbars(label=y, index=0, xerr=False)
        if len(errors_x) > 0 or len(errors_y) > 0:
            err_kwds = dict(errors_x, **errors_y)
            # draw error bars in the colour of the first marker
            err_kwds["ecolor"] = scatter.get_facecolor()[0]
            ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)

    def _args_adjust(self) -> None:
        pass
class HexBinPlot(PlanePlot):
    """
    Hexagonal binning of column ``y`` against column ``x``.

    ``C`` optionally names a column of values to aggregate per hexagon.
    """

    @property
    def _kind(self) -> Literal["hexbin"]:
        return "hexbin"

    def __init__(self, data, x, y, C=None, **kwargs) -> None:
        super().__init__(data, x, y, **kwargs)
        if is_integer(C) and not self.data.columns._holds_integer():
            # positional ``C`` is translated into the column label
            C = self.data.columns[C]
        self.C = C

    def _make_plot(self) -> None:
        ax = self.axes[0]
        # pandas uses colormap, matplotlib uses cmap.
        cmap = mpl.colormaps.get_cmap(self.colormap or "BuGn")
        show_colorbar = self.kwds.pop("colorbar", True)
        frame = self.data
        weights = None if self.C is None else frame[self.C].values
        ax.hexbin(
            frame[self.x].values,
            frame[self.y].values,
            C=weights,
            cmap=cmap,
            **self.kwds,
        )
        if show_colorbar:
            self._plot_colorbar(ax)

    def _make_legend(self) -> None:
        # hexbin plots have no legend
        pass

    def _args_adjust(self) -> None:
        pass
class LinePlot(MPLPlot):
    """
    Line plot of each column against the index.

    When the index supports it (see ``_is_ts_plot``), a time-series-aware
    path (``_ts_plot``) is used instead of plain x ticks. With
    ``stacked=True``, each column is drawn on top of the running total of
    the previous ones. Positive and negative columns are tracked separately.
    """

    _default_rot = 0

    @property
    def orientation(self) -> PlottingOrientation:
        return "vertical"

    @property
    def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
        return "line"

    def __init__(self, data, **kwargs) -> None:
        from pandas.plotting import plot_params

        MPLPlot.__init__(self, data, **kwargs)
        if self.stacked:
            # stacking needs every value to be a number
            self.data = self.data.fillna(value=0)
        self.x_compat = plot_params["x_compat"]
        if "x_compat" in self.kwds:
            self.x_compat = bool(self.kwds.pop("x_compat"))

    def _is_ts_plot(self) -> bool:
        # this is slightly deceptive
        return not self.x_compat and self.use_index and self._use_dynamic_x()

    def _use_dynamic_x(self):
        return use_dynamic_x(self._get_ax(0), self.data)

    def _make_plot(self) -> None:
        if self._is_ts_plot():
            data = maybe_convert_index(self._get_ax(0), self.data)

            x = data.index  # dummy, not used
            plotf = self._ts_plot
            it = self._iter_data(data=data, keep_index=True)
        else:
            x = self._get_xticks(convert_period=True)
            # error: Incompatible types in assignment (expression has type
            # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
            # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
            plotf = self._plot  # type: ignore[assignment]
            it = self._iter_data()

        stacking_id = self._get_stacking_id()
        is_errorbar = com.any_not_none(*self.errors.values())

        colors = self._get_colors()
        for i, (label, y) in enumerate(it):
            ax = self._get_ax(i)
            kwds = self.kwds.copy()
            style, kwds = self._apply_style_colors(colors, kwds, i, label)

            errors = self._get_errorbars(label=label, index=i)
            kwds = dict(kwds, **errors)

            label = pprint_thing(label)  # .encode('utf-8')
            label = self._mark_right_label(label, index=i)
            kwds["label"] = label

            newlines = plotf(
                ax,
                x,
                y,
                style=style,
                column_num=i,
                stacking_id=stacking_id,
                is_errorbar=is_errorbar,
                **kwds,
            )
            self._append_legend_handles_labels(newlines[0], label)

            if self._is_ts_plot():
                # reset of xlim should be used for ts data
                # TODO: GH28021, should find a way to change view limit on xaxis
                lines = get_all_lines(ax)
                left, right = get_xlim(lines)
                ax.set_xlim(left, right)

    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls, ax: Axes, x, y, style=None, column_num=None, stacking_id=None, **kwds
    ):
        # column_num is used to get the target column from plotf in line and
        # area plots
        if column_num == 0:
            cls._initialize_stacker(ax, stacking_id, len(y))
        y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"])
        lines = MPLPlot._plot(ax, x, y_values, style=style, **kwds)
        cls._update_stacker(ax, stacking_id, y)
        return lines

    def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
        # accept x to be consistent with normal plot func,
        # x is not passed to tsplot as it uses data.index as x coordinate
        # column_num must be in kwds for stacking purpose
        freq, data = maybe_resample(data, ax, kwds)

        # Set ax with freq info
        decorate_axes(ax, freq, kwds)
        # digging deeper
        if hasattr(ax, "left_ax"):
            decorate_axes(ax.left_ax, freq, kwds)
        if hasattr(ax, "right_ax"):
            decorate_axes(ax.right_ax, freq, kwds)
        ax._plot_data.append((data, self._kind, kwds))

        lines = self._plot(ax, data.index, data.values, style=style, **kwds)
        # set date formatter, locators and rescale limits
        format_dateaxis(ax, ax.freq, data.index)
        return lines

    def _get_stacking_id(self):
        # identifies the running totals of this plot on each axes
        if self.stacked:
            return id(self.data)
        else:
            return None

    @classmethod
    def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
        """Reset the positive/negative running totals of ``stacking_id`` to zeros."""
        if stacking_id is None:
            return
        if not hasattr(ax, "_stacker_pos_prior"):
            ax._stacker_pos_prior = {}
        if not hasattr(ax, "_stacker_neg_prior"):
            ax._stacker_neg_prior = {}
        ax._stacker_pos_prior[stacking_id] = np.zeros(n)
        ax._stacker_neg_prior[stacking_id] = np.zeros(n)

    @classmethod
    def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
        """
        Return ``values`` offset by the running total of their sign.

        Raises
        ------
        ValueError
            When stacking and ``values`` mixes positive and negative numbers.
        """
        if stacking_id is None:
            return values
        if stacking_id not in getattr(ax, "_stacker_pos_prior", {}):
            # The stacker may not be initialized for subplots (only column 0
            # triggers it). ``ax`` may also carry totals from an earlier
            # stacked plot under a different id, which must not cause a
            # KeyError here.
            cls._initialize_stacker(ax, stacking_id, len(values))

        if (values >= 0).all():
            return ax._stacker_pos_prior[stacking_id] + values
        elif (values <= 0).all():
            return ax._stacker_neg_prior[stacking_id] + values

        raise ValueError(
            "When stacked is True, each column must be either "
            "all positive or all negative. "
            f"Column '{label}' contains both positive and negative values"
        )

    @classmethod
    def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
        """Add ``values`` to the running total matching their sign."""
        if stacking_id is None:
            return
        if (values >= 0).all():
            ax._stacker_pos_prior[stacking_id] += values
        elif (values <= 0).all():
            ax._stacker_neg_prior[stacking_id] += values

    def _args_adjust(self) -> None:
        pass

    def _post_plot_logic(self, ax: Axes, data) -> None:
        from matplotlib.ticker import FixedLocator

        def get_label(i):
            # Map a tick position to the index label at that position, or ""
            # when no row sits there.
            if is_float(i) and i.is_integer():
                i = int(i)
            # Negative positions lie before the first row. Indexing with them
            # would wrap around and show labels from the end of the index.
            if is_integer(i) and i < 0:
                return ""
            try:
                return pprint_thing(data.index[i])
            except Exception:
                return ""

        if self._need_to_set_index:
            xticks = ax.get_xticks()
            xticklabels = [get_label(x) for x in xticks]
            ax.xaxis.set_major_locator(FixedLocator(xticks))
            ax.set_xticklabels(xticklabels)

        # If the index is an irregular time series, then by default
        # we rotate the tick labels. The exception is if there are
        # subplots which don't share their x-axes, in which we case
        # we don't rotate the ticklabels as by default the subplots
        # would be too close together.
        condition = (
            not self._use_dynamic_x()
            and (data.index._is_all_dates and self.use_index)
            and (not self.subplots or (self.subplots and self.sharex))
        )

        index_name = self._get_index_name()

        if condition:
            # irregular TS rotated 30 deg. by default
            # probably a better place to check / set this.
            if not self._rot_set:
                self.rot = 30
            format_date_labels(ax, rot=self.rot)

        if index_name is not None and self.use_index:
            ax.set_xlabel(index_name)
class AreaPlot(LinePlot):
    """
    Filled area plot; stacked by default.

    Missing values are filled with 0. Log-y scales are rejected.
    """

    @property
    def _kind(self) -> Literal["area"]:
        return "area"

    def __init__(self, data, **kwargs) -> None:
        kwargs.setdefault("stacked", True)
        data = data.fillna(value=0)
        LinePlot.__init__(self, data, **kwargs)

        if not self.stacked:
            # use smaller alpha to distinguish overlap
            self.kwds.setdefault("alpha", 0.5)

        if self.logy or self.loglog:
            raise ValueError("Log-y scales are not supported in area plot")

    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        x,
        y,
        style=None,
        column_num=None,
        stacking_id=None,
        is_errorbar: bool = False,
        **kwds,
    ):
        if column_num == 0:
            cls._initialize_stacker(ax, stacking_id, len(y))
        y_values = cls._get_stacked_values(ax, stacking_id, y, kwds["label"])

        # need to remove label, because subplots uses mpl legend as it is
        line_kwds = kwds.copy()
        line_kwds.pop("label")
        lines = MPLPlot._plot(ax, x, y_values, style=style, **line_kwds)

        # get data from the line to get coordinates for fill_between
        xdata, y_values = lines[0].get_data(orig=False)

        # unable to use ``_get_stacked_values`` here to get starting point
        if stacking_id is None:
            start = np.zeros(len(y))
        elif (y >= 0).all():
            start = ax._stacker_pos_prior[stacking_id]
        elif (y <= 0).all():
            start = ax._stacker_neg_prior[stacking_id]
        else:
            start = np.zeros(len(y))

        if "color" not in kwds:
            # fill with the colour matplotlib picked for the outline
            kwds["color"] = lines[0].get_color()

        rect = ax.fill_between(xdata, start, y_values, **kwds)
        cls._update_stacker(ax, stacking_id, y)

        # LinePlot expects list of artists
        res = [rect]
        return res

    def _args_adjust(self) -> None:
        pass

    def _post_plot_logic(self, ax: Axes, data) -> None:
        LinePlot._post_plot_logic(self, ax, data)

        # The grouper behind ``get_shared_y_axes()`` is shared across axes,
        # so iterating it reports groups belonging to *other* axes too. Ask
        # for this axes' siblings instead; the list always contains ``ax``.
        is_shared_y = len(ax.get_shared_y_axes().get_siblings(ax)) > 1
        # do not override the default axis behaviour in case of shared y axes
        if self.ylim is None and not is_shared_y:
            # anchor the filled area at zero when the data has a single sign
            if (data >= 0).all().all():
                ax.set_ylim(0, None)
            elif (data <= 0).all().all():
                ax.set_ylim(None, 0)
class BarPlot(MPLPlot):
    """
    Vertical bar plot: one group of bars per index entry.

    Bars are laid out in one of three ways:
    - side by side within each group (default);
    - stacked, with positive and negative values tracked separately;
    - one column per subplot.
    """

    @property
    def _kind(self) -> Literal["bar", "barh"]:
        return "bar"

    _default_rot = 90

    @property
    def orientation(self) -> PlottingOrientation:
        return "vertical"

    def __init__(self, data, **kwargs) -> None:
        # we have to treat a series differently than a
        # 1-column DataFrame w.r.t. color handling
        self._is_series = isinstance(data, ABCSeries)
        self.bar_width = kwargs.pop("width", 0.5)
        pos = kwargs.pop("position", 0.5)
        kwargs.setdefault("align", "center")
        # one tick per row of the data
        self.tick_pos = np.arange(len(data))

        self.bottom = kwargs.pop("bottom", 0)
        self.left = kwargs.pop("left", 0)

        self.log = kwargs.pop("log", False)
        MPLPlot.__init__(self, data, **kwargs)

        # tickoffset shifts bars relative to their tick; lim_offset pads the
        # axis limits computed in _post_plot_logic
        if self.stacked or self.subplots:
            self.tickoffset = self.bar_width * pos
            if kwargs["align"] == "edge":
                self.lim_offset = self.bar_width / 2
            else:
                self.lim_offset = 0
        else:
            if kwargs["align"] == "edge":
                # width of a single bar within a side-by-side group
                w = self.bar_width / self.nseries
                self.tickoffset = self.bar_width * (pos - 0.5) + w * 0.5
                self.lim_offset = w * 0.5
            else:
                self.tickoffset = self.bar_width * pos
                self.lim_offset = 0

        self.ax_pos = self.tick_pos - self.tickoffset

    def _args_adjust(self) -> None:
        # list-like baselines become arrays so they can be added element-wise
        if is_list_like(self.bottom):
            self.bottom = np.array(self.bottom)
        if is_list_like(self.left):
            self.left = np.array(self.left)

    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        x,
        y,
        w,
        start: int | npt.NDArray[np.intp] = 0,
        log: bool = False,
        **kwds,
    ):
        return ax.bar(x, y, w, bottom=start, log=log, **kwds)

    @property
    def _start_base(self):
        # baseline the bars grow from (``bottom`` for vertical bars)
        return self.bottom

    def _make_plot(self) -> None:
        colors = self._get_colors()
        ncolors = len(colors)

        # running totals for stacked bars (rebound, never mutated in place)
        pos_prior = neg_prior = np.zeros(len(self.data))
        K = self.nseries

        for i, (label, y) in enumerate(self._iter_data(fillna=0)):
            ax = self._get_ax(i)
            kwds = self.kwds.copy()
            if self._is_series:
                kwds["color"] = colors
            elif isinstance(colors, dict):
                kwds["color"] = colors[label]
            else:
                kwds["color"] = colors[i % ncolors]

            errors = self._get_errorbars(label=label, index=i)
            kwds = dict(kwds, **errors)

            label = pprint_thing(label)
            label = self._mark_right_label(label, index=i)

            if (("yerr" in kwds) or ("xerr" in kwds)) and (kwds.get("ecolor") is None):
                kwds["ecolor"] = mpl.rcParams["xtick.color"]

            # on a log scale, bars whose values are all >= 1 start at 1
            start = 0
            if self.log and (y >= 1).all():
                start = 1
            start = start + self._start_base

            if self.subplots:
                w = self.bar_width / 2
                rect = self._plot(
                    ax,
                    self.ax_pos + w,
                    y,
                    self.bar_width,
                    start=start,
                    label=label,
                    log=self.log,
                    **kwds,
                )
                ax.set_title(label)
            elif self.stacked:
                # values > 0 stack on the positive total, the rest (including
                # zeros) on the negative one.
                # NOTE(review): ``start`` is recomputed here, so the log-scale
                # ``start = 1`` adjustment above is not applied to stacked bars.
                mask = y > 0
                start = np.where(mask, pos_prior, neg_prior) + self._start_base
                w = self.bar_width / 2
                rect = self._plot(
                    ax,
                    self.ax_pos + w,
                    y,
                    self.bar_width,
                    start=start,
                    label=label,
                    log=self.log,
                    **kwds,
                )
                pos_prior = pos_prior + np.where(mask, y, 0)
                neg_prior = neg_prior + np.where(mask, 0, y)
            else:
                # side-by-side: column i occupies slot i of width w in its group
                w = self.bar_width / K
                rect = self._plot(
                    ax,
                    self.ax_pos + (i + 0.5) * w,
                    y,
                    w,
                    start=start,
                    label=label,
                    log=self.log,
                    **kwds,
                )
            self._append_legend_handles_labels(rect, label)

    def _post_plot_logic(self, ax: Axes, data) -> None:
        if self.use_index:
            str_index = [pprint_thing(key) for key in data.index]
        else:
            str_index = [pprint_thing(key) for key in range(data.shape[0])]

        # pad the axis limits a quarter unit beyond the outermost bars
        s_edge = self.ax_pos[0] - 0.25 + self.lim_offset
        e_edge = self.ax_pos[-1] + 0.25 + self.bar_width + self.lim_offset

        self._decorate_ticks(ax, self._get_index_name(), str_index, s_edge, e_edge)

    def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
        ax.set_xlim((start_edge, end_edge))

        # user-supplied xticks take precedence over one tick per row
        if self.xticks is not None:
            ax.set_xticks(np.array(self.xticks))
        else:
            ax.set_xticks(self.tick_pos)
            ax.set_xticklabels(ticklabels)

        if name is not None and self.use_index:
            ax.set_xlabel(name)
class BarhPlot(BarPlot):
    """
    Horizontal bar plot: index entries run along the y-axis.

    Bars grow rightwards from ``left``.
    """

    _default_rot = 0

    @property
    def _kind(self) -> Literal["barh"]:
        return "barh"

    @property
    def orientation(self) -> Literal["horizontal"]:
        return "horizontal"

    @property
    def _start_base(self):
        return self.left

    # error: Signature of "_plot" incompatible with supertype "MPLPlot"
    @classmethod
    def _plot(  # type: ignore[override]
        cls,
        ax: Axes,
        x,
        y,
        w,
        start: int | npt.NDArray[np.intp] = 0,
        log: bool = False,
        **kwds,
    ):
        return ax.barh(x, y, w, left=start, log=log, **kwds)

    def _get_custom_index_name(self):
        # the index is drawn along the y-axis, so ylabel overrides its name
        return self.ylabel

    def _decorate_ticks(self, ax: Axes, name, ticklabels, start_edge, end_edge) -> None:
        # horizontal bars: ticks, limits and index name live on the y-axis
        ax.set_ylim((start_edge, end_edge))
        ax.set_yticks(self.tick_pos)
        ax.set_yticklabels(ticklabels)
        if name is not None and self.use_index:
            ax.set_ylabel(name)
        ax.set_xlabel(self.xlabel)
class PiePlot(MPLPlot):
    """
    Pie chart per column; wedges are the rows of the data.

    Missing values count as 0. Negative values are rejected.
    """

    _layout_type = "horizontal"

    @property
    def _kind(self) -> Literal["pie"]:
        return "pie"

    def __init__(self, data, kind=None, **kwargs) -> None:
        data = data.fillna(value=0)
        if (data < 0).any().any():
            raise ValueError(f"{self._kind} plot doesn't allow negative values")
        MPLPlot.__init__(self, data, kind=kind, **kwargs)

    def _args_adjust(self) -> None:
        # pie charts have no grid and no log scales
        self.grid = False
        self.logy = False
        self.logx = False
        self.loglog = False

    def _validate_color_args(self) -> None:
        pass

    def _make_plot(self) -> None:
        # one colour per wedge (row), read from the "colors" keyword
        colors = self._get_colors(num_colors=len(self.data), color_kwds="colors")
        self.kwds.setdefault("colors", colors)

        for i, (label, y) in enumerate(self._iter_data()):
            ax = self._get_ax(i)
            if label is not None:
                label = pprint_thing(label)
                ax.set_ylabel(label)

            kwds = self.kwds.copy()

            idx = [pprint_thing(v) for v in self.data.index]
            labels = kwds.pop("labels", idx)
            # Wedge labels: blank out those of zero-valued wedges so they do
            # not overlap with nonzero wedges.
            if labels is None:
                blabels = None
            else:
                blabels = ["" if value == 0 else text for text, value in zip(labels, y)]

            results = ax.pie(y, labels=blabels, **kwds)

            # ``ax.pie`` returns a third sequence only when autopct is given
            if kwds.get("autopct", None) is not None:
                patches, texts, autotexts = results
            else:
                patches, texts = results
                autotexts = []

            if self.fontsize is not None:
                for text_artist in texts + autotexts:
                    text_artist.set_fontsize(self.fontsize)

            # legend entries keep the unblanked labels
            legend_labels = idx if labels is None else labels
            for wedge, legend_label in zip(patches, legend_labels):
                self._append_legend_handles_labels(wedge, legend_label)

    def _post_plot_logic(self, ax: Axes, data) -> None:
        pass
|