"""The classes for specifying and compiling a declarative visualization."""
from __future__ import annotations
import io
import os
import re
import inspect
import itertools
import textwrap
from contextlib import contextmanager
from collections import abc
from collections.abc import Callable, Generator
from typing import Any, List, Literal, Optional, cast
from xml.etree import ElementTree
from cycler import cycler
import pandas as pd
from pandas import DataFrame, Series, Index
import matplotlib as mpl
from matplotlib.axes import Axes
from matplotlib.artist import Artist
from matplotlib.figure import Figure
import numpy as np
from PIL import Image
from seaborn._marks.base import Mark
from seaborn._stats.base import Stat
from seaborn._core.data import PlotData
from seaborn._core.moves import Move
from seaborn._core.scales import Scale, Nominal
from seaborn._core.subplots import Subplots
from seaborn._core.groupby import GroupBy
from seaborn._core.properties import PROPERTIES, Property
from seaborn._core.typing import (
DataSource,
VariableSpec,
VariableSpecList,
OrderSpec,
Default,
)
from seaborn._core.exceptions import PlotSpecError
from seaborn._core.rules import categorical_order
from seaborn._compat import set_scale_obj, set_layout_engine
from seaborn.rcmod import axes_style, plotting_context
from seaborn.palettes import color_palette
from seaborn.utils import _version_predates
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from matplotlib.figure import SubFigure
default = Default()
# ---- Definitions for internal specs ---------------------------------------------- #
class Layer(TypedDict, total=False):
    """Specification of one plot layer; `total=False` makes every key optional."""

    mark: Mark  # TODO allow list?
    stat: Stat | None  # TODO allow list?
    move: Move | list[Move] | None
    data: PlotData  # Assigned in Plotter._extract_data, not by Plot.add
    source: DataSource  # Layer-specific `data=` argument given to Plot.add
    vars: dict[str, VariableSpec]  # Layer-specific variable assignments
    orient: str  # Plot.add translates "v" / "h" to "x" / "y" before storing
    legend: bool
    label: str | None
class FacetSpec(TypedDict, total=False):
    """Specification of faceting, as assembled by Plot.facet."""

    variables: dict[str, VariableSpec]  # Keyed by "col" and/or "row"
    structure: dict[str, list[str]]  # Explicit level order per faceting dimension
    wrap: int | None
class PairSpec(TypedDict, total=False):
    """Specification of variable pairing, as assembled by Plot.pair."""

    variables: dict[str, VariableSpec]  # Keyed by generated names ("x0", "y1", ...)
    structure: dict[str, list[str]]  # Generated names grouped by axis ("x" / "y")
    cross: bool
    wrap: int | None
# --- Local helpers ---------------------------------------------------------------- #
@contextmanager
def theme_context(params: dict[str, Any]) -> Generator:
    """
    Temporarily modify specific matplotlib rcParams.

    While the context is active, the given rc parameters are applied and the
    single-letter color codes are remapped to the "deep6" palette (plus a dark
    gray for "k"). Both are restored on exit, even if the body raises.
    """
    saved_rc = {name: mpl.rcParams[name] for name in params}

    # TODO how to allow this to reflect the color cycle when relevant?
    color_table = mpl.colors.colorConverter.colors
    short_codes = "bgrmyck"
    themed_colors = [*color_palette("deep6"), (.15, .15, .15)]
    saved_colors = [color_table[code] for code in short_codes]

    try:
        mpl.rcParams.update(params)
        # Assign item by item (rather than with a bulk dict update) so that any
        # per-item hooks on the color mapping still run
        for code, color in zip(short_codes, themed_colors):
            color_table[code] = color
        yield
    finally:
        mpl.rcParams.update(saved_rc)
        for code, color in zip(short_codes, saved_colors):
            color_table[code] = color
def build_plot_signature(cls):
    """
    Decorator function for giving Plot a useful signature.

    Currently this mostly saves us some duplicated typing, but we would
    like eventually to have a way of registering new semantic properties,
    at which point dynamic signature generation would become more important.

    The class signature becomes `(*args, data=None, <property>=None, ...)` with
    one keyword-only parameter per entry in PROPERTIES, and the
    `{known_properties}` placeholder in the class docstring is filled in.

    """
    keyword_only = inspect.Parameter.KEYWORD_ONLY
    parameters = [
        inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
        inspect.Parameter("data", keyword_only, default=None),
        *(
            inspect.Parameter(prop, keyword_only, default=None)
            for prop in PROPERTIES
        ),
    ]
    cls.__signature__ = inspect.signature(cls).replace(parameters=parameters)

    if cls.__doc__ is not None:  # support python -OO mode
        listing = ", ".join(f"|{prop}|" for prop in PROPERTIES)
        wrapped = textwrap.fill(listing, width=78, subsequent_indent=" " * 8)
        cls.__doc__ = cls.__doc__.format(known_properties=wrapped)

    return cls
# ---- Plot configuration ---------------------------------------------------------- #
class ThemeConfig(mpl.RcParams):
    """
    Configuration object for the Plot.theme, using matplotlib rc parameters.
    """
    # Prefixes of the rc parameter names that are considered part of the theme
    THEME_GROUPS = [
        "axes", "figure", "font", "grid", "hatch", "legend", "lines",
        "mathtext", "markers", "patch", "savefig", "scatter",
        "xaxis", "xtick", "yaxis", "ytick",
    ]

    def __init__(self):
        super().__init__()
        self.reset()

    @property
    def _default(self) -> dict[str, Any]:
        """Default theme; later entries take precedence over earlier ones."""
        return {
            **self._filter_params(mpl.rcParamsDefault),
            **axes_style("darkgrid"),
            **plotting_context("notebook"),
            "axes.prop_cycle": cycler("color", color_palette("deep")),
        }

    def reset(self) -> None:
        """Update the theme dictionary with seaborn's default values."""
        self.update(self._default)

    def update(self, other: dict[str, Any] | None = None, /, **kwds):
        """Update the theme with a dictionary or keyword arguments of rc parameters."""
        # Only the positional dictionary is filtered down to thematic
        # parameters; keyword arguments are passed through as given
        if other is not None:
            theme = self._filter_params(other)
        else:
            theme = {}
        theme.update(kwds)
        super().update(theme)

    def _filter_params(self, params: dict[str, Any]) -> dict[str, Any]:
        """Restrict to thematic rc params."""
        prefixes = tuple(self.THEME_GROUPS)
        return {k: v for k, v in params.items() if k.startswith(prefixes)}

    def _html_table(self, params: dict[str, Any]) -> list[str]:
        """Return the lines of an HTML table with one row per parameter."""
        from html import escape

        lines = ["<table>"]
        for k, v in params.items():
            # Escape the text: reprs such as enum values contain "<...>", which
            # would otherwise be parsed as markup and disappear from the output
            key, val = escape(str(k)), escape(repr(v))
            row = f"<tr><td>{key}:</td><td style='text-align:left'>{val}</td></tr>"
            lines.append(row)
        lines.append("</table>")
        return lines

    def _repr_html_(self) -> str:
        """Rich display hook: render the theme as an HTML table."""
        html_lines = [
            "<div style='height: 300px'>",
            "<div style='border-style: inset; border-width: 2px'>",
            *self._html_table(self),
            "</div>",
            "</div>",
        ]
        return "\n".join(html_lines)
class DisplayConfig(TypedDict):
    """Configuration for IPython's rich display hooks."""

    format: Literal["png", "svg"]  # Which of the _repr_png_ / _repr_svg_ hooks is active
    scaling: float  # Multiplier on the reported image width and height
    hidpi: bool  # When True, the PNG hook renders at twice the base DPI
class PlotConfig:
    """Configuration for default behavior / appearance of :class:`Plot` instances."""

    def __init__(self):
        self._theme = ThemeConfig()
        self._display: DisplayConfig = {"format": "png", "scaling": .85, "hidpi": True}

    @property
    def theme(self) -> dict[str, Any]:
        """
        Dictionary of base theme parameters for :class:`Plot`.

        Keys and values correspond to matplotlib rc params, as documented here:
        https://matplotlib.org/stable/tutorials/introductory/customizing.html

        """
        return self._theme

    @property
    def display(self) -> DisplayConfig:
        """
        Dictionary of parameters for rich display in Jupyter notebook.

        Valid parameters:

        - format ("png" or "svg"): Image format to produce
        - scaling (float): Relative scaling of embedded image
        - hidpi (bool): When True, double the DPI while preserving the size

        """
        return self._display
# ---- The main interface for declarative plotting --------------------------------- #
@build_plot_signature
class Plot:
"""
An interface for declaratively specifying statistical graphics.
Plots are constructed by initializing this class and adding one or more
layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally,
faceting variables or variable pairings may be defined to divide the space
into multiple subplots. The mappings from data values to visual properties
can be parametrized using scales, although the plot will try to infer good
defaults when scales are not explicitly defined.
The constructor accepts a data source (a :class:`pandas.DataFrame` or
dictionary with columnar values) and variable assignments. Variables can be
passed as keys to the data source or directly as data vectors. If multiple
data-containing objects are provided, they will be index-aligned.
The data source and variables defined in the constructor will be used for
all layers in the plot, unless overridden or disabled when adding a layer.
The following variables can be defined in the constructor:
{known_properties}
The `data`, `x`, and `y` variables can be passed as positional arguments or
using keywords. Whether the first positional argument is interpreted as a
data source or `x` variable depends on its type.
The methods of this class return a copy of the instance; use chaining to
build up a plot through multiple calls. Methods can be called in any order.
Most methods only add information to the plot spec; no actual processing
happens until the plot is shown or saved. It is also possible to compile
the plot without rendering it to access the lower-level representation.
"""
config = PlotConfig()
_data: PlotData
_layers: list[Layer]
_scales: dict[str, Scale]
_shares: dict[str, bool | str]
_limits: dict[str, tuple[Any, Any]]
_labels: dict[str, str | Callable[[str], str]]
_theme: dict[str, Any]
_facet_spec: FacetSpec
_pair_spec: PairSpec
_figure_spec: dict[str, Any]
_subplot_spec: dict[str, Any]
_layout_spec: dict[str, Any]
def __init__(
    self,
    *args: DataSource | VariableSpec,
    data: DataSource = None,
    **variables: VariableSpec,
):
    """
    Initialize an empty plot spec from a data source and variable assignments.

    Positional arguments are resolved into `data` / `x` / `y`; keyword
    variables must be names of known properties, otherwise a TypeError is raised.
    """
    if args:
        data, variables = self._resolve_positionals(args, data, variables)

    unknown = [name for name in variables if name not in PROPERTIES]
    if unknown:
        raise TypeError(
            f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}"
        )

    # Global data / variables, shared by every layer unless overridden
    self._data = PlotData(data, variables)

    # Everything else starts out empty and is filled in by the spec methods
    self._layers = []

    self._scales = {}
    self._shares = {}
    self._limits = {}
    self._labels = {}
    self._theme = {}

    self._facet_spec = {}
    self._pair_spec = {}

    self._figure_spec = {}
    self._subplot_spec = {}
    self._layout_spec = {}

    self._target = None
def _resolve_positionals(
self,
args: tuple[DataSource | VariableSpec, ...],
data: DataSource,
variables: dict[str, VariableSpec],
) -> tuple[DataSource, dict[str, VariableSpec]]:
"""Handle positional arguments, which may contain data / x / y."""
if len(args) > 3:
err = "Plot() accepts no more than 3 positional arguments (data, x, y)."
raise TypeError(err)
if (
isinstance(args[0], (abc.Mapping, pd.DataFrame))
or hasattr(args[0], "__dataframe__")
):
if data is not None:
raise TypeError("`data` given by both name and position.")
data, args = args[0], args[1:]
if len(args) == 2:
x, y = args
elif len(args) == 1:
x, y = *args, None
else:
x = y = None
for name, var in zip("yx", (y, x)):
if var is not None:
if name in variables:
raise TypeError(f"`{name}` given by both name and position.")
# Keep coordinates at the front of the variables dict
# Cast type because we know this isn't a DataSource at this point
variables = {name: cast(VariableSpec, var), **variables}
return data, variables
def __add__(self, other):
    """
    Reject `plot + other` with an informative TypeError.

    Layers are added with :meth:`Plot.add`, not with the `+` operator; adding
    a Mark or Stat gets a dedicated hint pointing there.
    """
    if isinstance(other, (Mark, Stat)):
        raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?")

    other_type = other.__class__.__name__
    raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}'")
def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
    """Rich display hook: compile the plot and return PNG data, if enabled."""
    if Plot.config.display["format"] == "png":
        return self.plot()._repr_png_()
    # Another format is configured; skip compilation entirely
    return None
def _repr_svg_(self) -> str | None:
    """Rich display hook: compile the plot and return SVG markup, if enabled."""
    if Plot.config.display["format"] == "svg":
        return self.plot()._repr_svg_()
    # Another format is configured; skip compilation entirely
    return None
def _clone(self) -> Plot:
    """Generate a new object with the same information as the current spec."""
    clone = Plot()

    # TODO any way to enforce that data does not get mutated?
    clone._data = self._data
    clone._target = self._target

    # The containers themselves are copied (shallowly), so adding to the
    # clone's spec does not modify the spec of the original
    clone._layers.extend(self._layers)
    dict_specs = (
        "_scales", "_shares", "_limits", "_labels", "_theme",
        "_facet_spec", "_pair_spec",
        "_figure_spec", "_subplot_spec", "_layout_spec",
    )
    for attr in dict_specs:
        getattr(clone, attr).update(getattr(self, attr))

    return clone
def _theme_with_defaults(self) -> dict[str, Any]:
theme = self.config.theme.copy()
theme.update(self._theme)
return theme
@property
def _variables(self) -> list[str]:
variables = (
list(self._data.frame)
+ list(self._pair_spec.get("variables", []))
+ list(self._facet_spec.get("variables", []))
)
for layer in self._layers:
variables.extend(v for v in layer["vars"] if v not in variables)
# Coerce to str in return to appease mypy; we know these will only
# ever be strings but I don't think we can type a DataFrame that way yet
return [str(v) for v in variables]
def on(self, target: Axes | SubFigure | Figure) -> Plot:
    """
    Provide existing Matplotlib figure or axes for drawing the plot.

    When using this method, you will also need to explicitly call a method that
    triggers compilation, such as :meth:`Plot.show` or :meth:`Plot.save`. If you
    want to postprocess using matplotlib, you'd need to call :meth:`Plot.plot`
    first to compile the plot without rendering it.

    Parameters
    ----------
    target : Axes, SubFigure, or Figure
        Matplotlib object to use. Passing :class:`matplotlib.axes.Axes` will add
        artists without otherwise modifying the figure. Otherwise, subplots will be
        created within the space of the given :class:`matplotlib.figure.Figure` or
        :class:`matplotlib.figure.SubFigure`.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.on.rst

    """
    accepted_types: tuple  # Allow tuple of various length
    if hasattr(mpl.figure, "SubFigure"):  # Added in mpl 3.4
        accepted_types = mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure
    else:
        accepted_types = mpl.axes.Axes, mpl.figure.Figure

    if isinstance(target, accepted_types):
        new = self._clone()
        new._target = target
        return new

    # Spell out the accepted classes: "A or B" / "A, B, or C"
    *leading, final = map(str, accepted_types)
    joiner = ", or " if len(leading) > 1 else " or "
    accepted_types_str = ", ".join(leading) + joiner + final
    raise TypeError(
        f"The `Plot.on` target must be an instance of {accepted_types_str}. "
        f"You passed an instance of {target.__class__} instead."
    )
def add(
    self,
    mark: Mark,
    *transforms: Stat | Move,
    orient: str | None = None,
    legend: bool = True,
    label: str | None = None,
    data: DataSource = None,
    **variables: VariableSpec,
) -> Plot:
    """
    Specify a layer of the visualization in terms of mark and data transform(s).

    This is the main method for specifying how the data should be visualized.
    It can be called multiple times with different arguments to define
    a plot with multiple layers.

    Parameters
    ----------
    mark : :class:`Mark`
        The visual representation of the data to use in this layer.
    transforms : :class:`Stat` or :class:`Move`
        Objects representing transforms to be applied before plotting the data.
        Currently, at most one :class:`Stat` can be used, and it
        must be passed first. This constraint will be relaxed in the future.
    orient : "x", "y", "v", or "h"
        The orientation of the mark, which also affects how transforms are computed.
        Typically corresponds to the axis that defines groups for aggregation.
        The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y",
        but may be more intuitive with some marks. When not provided, an
        orientation will be inferred from characteristics of the data and scales.
    legend : bool
        Option to suppress the mark/mappings for this layer from the legend.
    label : str
        A label to use for the layer in the legend, independent of any mappings.
    data : DataFrame or dict
        Data source to override the global source provided in the constructor.
    variables : data vectors or identifiers
        Additional layer-specific variables, including variables that will be
        passed directly to the transforms without scaling.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.add.rst

    """
    if not isinstance(mark, Mark):
        raise TypeError(f"mark must be a Mark instance, not {type(mark)!r}.")

    # TODO This API for transforms was a late decision, and previously Plot.add
    # accepted 0 or 1 Stat instances and 0, 1, or a list of Move instances.
    # It will take some work to refactor the internals so that Stat and Move are
    # treated identically, and until then we'll need to "unpack" the transforms
    # here and enforce limitations on the order / types.
    stat: Optional[Stat] = None
    move: Optional[List[Move]] = None
    if transforms:
        # A Stat is only accepted in the leading position
        if isinstance(transforms[0], Stat):
            stat, candidates = transforms[0], transforms[1:]
        else:
            candidates = transforms
        move = [t for t in candidates if isinstance(t, Move)]

        # Anything filtered out was not a Move, so the spec is invalid
        if len(move) != len(candidates):
            msg = " ".join([
                "Transforms must have at most one Stat type (in the first position),",
                "and all others must be a Move type. Given transform type(s):",
                ", ".join(str(type(t).__name__) for t in transforms) + "."
            ])
            raise TypeError(msg)

    layer = {
        "mark": mark,
        "stat": stat,
        "move": move,
        # TODO it doesn't work to supply scalars to variables, but it should
        "vars": variables,
        "source": data,
        "legend": legend,
        "label": label,
        "orient": {"v": "x", "h": "y"}.get(orient, orient),  # type: ignore
    }

    new = self._clone()
    new._layers.append(layer)
    return new
def pair(
    self,
    x: VariableSpecList = None,
    y: VariableSpecList = None,
    wrap: int | None = None,
    cross: bool = True,
) -> Plot:
    """
    Produce subplots by pairing multiple `x` and/or `y` variables.

    Parameters
    ----------
    x, y : sequence(s) of data vectors or identifiers
        Variables that will define the grid of subplots.
    wrap : int
        When using only `x` or `y`, "wrap" subplots across a two-dimensional grid
        with this many columns (when using `x`) or rows (when using `y`).
    cross : bool
        When False, zip the `x` and `y` lists such that the first subplot gets the
        first pair, the second gets the second pair, etc. Otherwise, create a
        two-dimensional grid from the cartesian product of the lists.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.pair.rst

    """
    # TODO Add transpose= arg, which would then draw pair(y=[...]) across rows
    # This may also be possible by setting `wrap=1`, but is that too unobvious?

    # TODO PairGrid features not currently implemented: diagonals, corner

    requested = {"x": [] if x is None else x, "y": [] if y is None else y}
    for axis, arg in requested.items():
        if isinstance(arg, (str, int)):
            raise TypeError(f"You must pass a sequence of variable keys to `{axis}`")

    # Each paired variable gets a generated name: x0, x1, ..., y0, y1, ...
    paired_vars = {}
    structure = {}
    for axis, cols in requested.items():
        for i, col in enumerate(cols):
            key = f"{axis}{i}"
            paired_vars[key] = col
            structure.setdefault(axis, []).append(key)

    if not cross and len(requested["x"]) != len(requested["y"]):
        raise ValueError("Lengths of the `x` and `y` lists must match with cross=False")

    pair_spec: PairSpec = {
        "variables": paired_vars,
        "structure": structure,
        "cross": cross,
        "wrap": wrap,
    }

    new = self._clone()
    new._pair_spec.update(pair_spec)
    return new
def facet(
    self,
    col: VariableSpec = None,
    row: VariableSpec = None,
    order: OrderSpec | dict[str, OrderSpec] = None,
    wrap: int | None = None,
) -> Plot:
    """
    Produce subplots with conditional subsets of the data.

    Parameters
    ----------
    col, row : data vectors or identifiers
        Variables used to define subsets along the columns and/or rows of the grid.
        Can be references to the global data source passed in the constructor.
    order : list of strings, or dict with dimensional keys
        Define the order of the faceting variables.
    wrap : int
        When using only `col` or `row`, wrap subplots across a two-dimensional
        grid with this many subplots on the faceting dimension.

    Raises
    ------
    RuntimeError
        When `order` is a list but both `col` and `row` are defined.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.facet.rst

    """
    variables: dict[str, VariableSpec] = {}
    if col is not None:
        variables["col"] = col
    if row is not None:
        variables["row"] = row

    structure = {}
    if isinstance(order, dict):
        for dim in ["col", "row"]:
            dim_order = order.get(dim)
            if dim_order is not None:
                structure[dim] = list(dim_order)
    elif order is not None:
        # A bare list can only be attributed to a single faceting dimension
        if col is not None and row is not None:
            err = " ".join([
                "When faceting on both col= and row=, passing `order` as a list",
                "is ambiguous. Use a dict with 'col' and/or 'row' keys instead.",
            ])
            raise RuntimeError(err)
        elif col is not None:
            structure["col"] = list(order)
        elif row is not None:
            structure["row"] = list(order)

    spec: FacetSpec = {
        "variables": variables,
        "structure": structure,
        "wrap": wrap,
    }

    new = self._clone()
    new._facet_spec.update(spec)

    return new
# TODO def twin()?
def scale(self, **scales: Scale) -> Plot:
    """
    Specify mappings from data units to visual properties.

    Keywords correspond to variables defined in the plot, including coordinate
    variables (`x`, `y`) and semantic variables (`color`, `pointsize`, etc.).

    A number of "magic" arguments are accepted, including:
        - The name of a transform (e.g., `"log"`, `"sqrt"`)
        - The name of a palette (e.g., `"viridis"`, `"muted"`)
        - A tuple of values, defining the output range (e.g. `(1, 5)`)
        - A dict, implying a :class:`Nominal` scale (e.g. `{"a": .2, "b": .5}`)
        - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`)

    For more explicit control, pass a scale spec object such as :class:`Continuous`
    or :class:`Nominal`. Or pass `None` to use an "identity" scale, which treats
    data values as literally encoding visual properties.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.scale.rst

    """
    updated = self._clone()
    # Later calls override earlier scales for the same variable
    for var, spec in scales.items():
        updated._scales[var] = spec
    return updated
def share(self, **shares: bool | str) -> Plot:
    """
    Control sharing of axis limits and ticks across subplots.

    Keywords correspond to variables defined in the plot, and values can be
    boolean (to share across all subplots), or one of "row" or "col" (to share
    more selectively across one dimension of a grid).

    Behavior for non-coordinate variables is currently undefined.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.share.rst

    """
    updated = self._clone()
    # Later calls override earlier settings for the same variable
    for var, sharing in shares.items():
        updated._shares[var] = sharing
    return updated
def limit(self, **limits: tuple[Any, Any]) -> Plot:
    """
    Control the range of visible data.

    Keywords correspond to variables defined in the plot, and values are a
    `(min, max)` tuple (where either can be `None` to leave unset).

    Limits apply only to the axis; data outside the visible range are
    still used for any stat transforms and added to the plot.

    Behavior for non-coordinate variables is currently undefined.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.limit.rst

    """
    updated = self._clone()
    # Later calls override earlier limits for the same variable
    for var, bounds in limits.items():
        updated._limits[var] = bounds
    return updated
def label(
    self, *,
    title: str | None = None,
    legend: str | None = None,
    **variables: str | Callable[[str], str]
) -> Plot:
    """
    Control the labels and titles for axes, legends, and subplots.

    Additional keywords correspond to variables defined in the plot.
    Values can be one of the following types:

    - string (used literally; pass "" to clear the default label)
    - function (called on the default label)

    For coordinate variables, the value sets the axis label.
    For semantic variables, the value sets the legend title.
    For faceting variables, `title=` modifies the subplot-specific label,
    while `col=` and/or `row=` add a label for the faceting variable.

    When using a single subplot, `title=` sets its title.

    The `legend=` parameter sets the title for the "layer" legend
    (i.e., when using `label` in :meth:`Plot.add`).

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.label.rst

    """
    # Only explicitly provided (non-None) titles are recorded
    named = {
        key: text
        for key, text in (("title", title), ("legend", legend))
        if text is not None
    }

    new = self._clone()
    new._labels.update(named)
    new._labels.update(variables)
    return new
def layout(
    self,
    *,
    size: tuple[float, float] | Default = default,
    engine: str | None | Default = default,
) -> Plot:
    """
    Control the figure size and layout.

    .. note::

        Default figure sizes and the API for specifying the figure size are subject
        to change in future "experimental" releases of the objects API. The default
        layout engine may also change.

    Parameters
    ----------
    size : (width, height)
        Size of the resulting figure, in inches. Size is inclusive of legend when
        using pyplot, but not otherwise.
    engine : {"tight", "constrained", None}
        Name of method for automatically adjusting the layout to remove overlap.
        The default depends on whether :meth:`Plot.on` is used.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.layout.rst

    """
    # TODO add an "auto" mode for figsize that roughly scales with the rcParams
    # figsize (so that works), but expands to prevent subplots from being squished
    # Also should we have height=, aspect=, exclusive with figsize? Or working
    # with figsize when only one is defined?

    new = self._clone()

    # The `default` sentinel distinguishes "not given" from an explicit None
    given = (
        (new._figure_spec, "figsize", size),
        (new._layout_spec, "engine", engine),
    )
    for spec, key, value in given:
        if value is not default:
            spec[key] = value

    return new
# TODO def legend (ugh)
def theme(self, *args: dict[str, Any]) -> Plot:
    """
    Control the appearance of elements in the plot.

    .. note::

        The API for customizing plot appearance is not yet finalized.
        Currently, the only valid argument is a dict of matplotlib rc parameters.
        (This dict must be passed as a positional argument.)

        It is likely that this method will be enhanced in future releases.

    Matplotlib rc parameters are documented on the following page:
    https://matplotlib.org/stable/tutorials/introductory/customizing.html

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.theme.rst

    """
    new = self._clone()

    # We can skip this whole block on Python 3.8+ with positional-only syntax
    if len(args) != 1:
        raise TypeError(
            f"theme() takes 1 positional argument, but {len(args)} were given"
        )

    (params,) = args
    # Passing through RcParams validates the parameter names and values
    new._theme.update(mpl.RcParams(params))
    return new
def save(self, loc, **kwargs) -> Plot:
    """
    Compile the plot and write it to a buffer or file on disk.

    Parameters
    ----------
    loc : str, path, or buffer
        Location on disk to save the figure, or a buffer to write into.
    kwargs
        Other keyword arguments are passed through to
        :meth:`matplotlib.figure.Figure.savefig`.

    """
    # TODO expose important keyword arguments in our signature?
    theme = self._theme_with_defaults()
    with theme_context(theme):
        # Both compilation and writing happen while the theme is active
        plotter = self._plot()
        plotter.save(loc, **kwargs)
    return self
def show(self, **kwargs) -> None:
    """
    Compile the plot and display it by hooking into pyplot.

    Calling this method is not necessary to render a plot in notebook context,
    but it may be in other environments (e.g., in a terminal). After compiling the
    plot, it calls :func:`matplotlib.pyplot.show` (passing any keyword parameters).

    Unlike other :class:`Plot` methods, there is no return value. This should be
    the last method you call when specifying a plot.

    """
    # TODO make pyplot configurable at the class level, and when not using,
    # import IPython.display and call on self to populate cell output?

    # Keep an eye on whether matplotlib implements "attaching" an existing
    # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024

    plotter = self.plot(pyplot=True)
    plotter.show(**kwargs)
def plot(self, pyplot: bool = False) -> Plotter:
    """
    Compile the plot spec and return the Plotter object.
    """
    theme = self._theme_with_defaults()
    # Compile with the plot's theme applied to the matplotlib rc parameters
    with theme_context(theme):
        plotter = self._plot(pyplot)
    return plotter
def _plot(self, pyplot: bool = False) -> Plotter:
    """
    Run the full compilation pipeline and return the resulting Plotter.

    Unlike :meth:`Plot.plot`, this does not enter the theme context itself;
    the visible callers (`plot` and `save`) wrap the call in `theme_context`.
    The steps below are order-dependent: scales are set up in two passes,
    once for coordinate variables before the stats are computed, and once
    more afterwards.
    """
    # TODO if we have _target object, pyplot should be determined by whether it
    # is hooked into the pyplot state machine (how do we check?)

    plotter = Plotter(pyplot=pyplot, theme=self._theme_with_defaults())

    # Process the variable assignments and initialize the figure
    common, layers = plotter._extract_data(self)
    plotter._setup_figure(self, common, layers)

    # Process the scale spec for coordinate variables and transform their data
    # (`re.match` anchors at the start, so this selects names beginning with
    # "x" or "y")
    coord_vars = [v for v in self._variables if re.match(r"^x|y", v)]
    plotter._setup_scales(self, common, layers, coord_vars)

    # Apply statistical transform(s)
    plotter._compute_stats(self, layers)

    # Process scale spec for semantic variables and coordinates computed by stat
    plotter._setup_scales(self, common, layers)

    # TODO Remove these after updating other methods
    # ---- Maybe have debug= param that attaches these when True?
    plotter._data = common
    plotter._layers = layers

    # Process the data for each layer and add matplotlib artists
    for layer in layers:
        plotter._plot_layer(self, layer)

    # Add various figure decorations
    plotter._make_legend(self)
    plotter._finalize_figure(self)

    return plotter
# ---- The plot compilation engine ---------------------------------------------- #
class Plotter:
"""
Engine for compiling a :class:`Plot` spec into a Matplotlib figure.
This class is not intended to be instantiated directly by users.
"""
# TODO decide if we ever want these (Plot.plot(debug=True))?
_data: PlotData
_layers: list[Layer]
_figure: Figure
def __init__(self, pyplot: bool, theme: dict[str, Any]):
self._pyplot = pyplot
self._theme = theme
self._legend_contents: list[tuple[
tuple[str, str | int], list[Artist], list[str],
]] = []
self._scales: dict[str, Scale] = {}
def save(self, loc, **kwargs) -> Plotter:  # TODO type args
    """
    Write the compiled figure to `loc` and return this Plotter.

    Keyword arguments are passed to the figure's `savefig`; `dpi` defaults
    to 96 when not given. A leading `~` in a path-like `loc` is expanded.
    """
    if "dpi" not in kwargs:
        kwargs["dpi"] = 96

    try:
        destination = os.path.expanduser(loc)
    except TypeError:
        # loc may be a buffer in which case that would not work
        destination = loc

    self._figure.savefig(destination, **kwargs)
    return self
def show(self, **kwargs) -> None:
    """
    Display the plot by hooking into pyplot.

    This method calls :func:`matplotlib.pyplot.show` with any keyword parameters.

    """
    # TODO if we did not create the Plotter with pyplot, is it possible to do this?
    # If not we should clearly raise.
    from matplotlib import pyplot

    with theme_context(self._theme):
        pyplot.show(**kwargs)
# TODO API for accessing the underlying matplotlib objects
# TODO what else is useful in the public API for this class?
def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
    """
    Rich display hook: return PNG bytes plus width/height display metadata.

    Returns None when the configured display format is not "png".
    """
    # TODO use matplotlib backend directly instead of going through savefig?

    # TODO perhaps have self.show() flip a switch to disable this, so that
    # user does not end up with two versions of the figure in the output

    # TODO use bbox_inches="tight" like the inline backend?
    # pro: better results,  con: (sometimes) confusing results
    # Better solution would be to default (with option to change)
    # to using constrained/tight layout.

    display = Plot.config.display
    if display["format"] != "png":
        return None

    # With hidpi, render at double the DPI but report half the scaling, so
    # the displayed size stays the same
    factor = 2 if display["hidpi"] else 1
    scaling = display["scaling"] / factor
    dpi = 96 * factor  # TODO put dpi in Plot.config?

    buffer = io.BytesIO()
    with theme_context(self._theme):  # TODO _theme_with_defaults?
        self._figure.savefig(buffer, dpi=dpi, format="png", bbox_inches="tight")

    # Read the pixel size back from the rendered image
    width, height = Image.open(buffer).size
    metadata = {"width": width * scaling, "height": height * scaling}
    return buffer.getvalue(), metadata
def _repr_svg_(self) -> str | None:
    """
    Rich display hook: return SVG markup with the configured scaling applied.

    Returns None when the configured display format is not "svg".
    """
    display = Plot.config.display
    if display["format"] != "svg":
        return None

    # TODO DPI for rasterized artists?

    scaling = display["scaling"]

    text_buffer = io.StringIO()
    with theme_context(self._theme):  # TODO _theme_with_defaults?
        self._figure.savefig(text_buffer, format="svg", bbox_inches="tight")

    # Rescale the root element's size; the [:-2] strips a two-character
    # unit suffix from the width / height attributes
    root = ElementTree.fromstring(text_buffer.getvalue())
    w = scaling * float(root.attrib["width"][:-2])
    h = scaling * float(root.attrib["height"][:-2])
    # NOTE(review): the SVG attribute is spelled `viewBox`; this lowercase key
    # is a different attribute in XML. Confirm whether that is intentional.
    root.attrib.update(width=f"{w}pt", height=f"{h}pt", viewbox=f"0 0 {w} {h}")

    out = io.BytesIO()
    ElementTree.ElementTree(root).write(out)
    return out.getvalue().decode()
def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:
common_data = (
p._data
.join(None, p._facet_spec.get("variables"))
.join(None, p._pair_spec.get("variables"))
)
layers: list[Layer] = []
for layer in p._layers:
spec = layer.copy()
spec["data"] = common_data.join(layer.get("source"), layer.get("vars"))
layers.append(spec)
return common_data, layers
def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str:
if re.match(r"[xy]\d+", var):
key = var if var in p._labels else var[0]
else:
key = var
label: str
if key in p._labels:
manual_label = p._labels[key]
if callable(manual_label) and auto_label is not None:
label = manual_label(auto_label)
else:
label = cast(str, manual_label)
elif auto_label is None:
label = ""
else:
label = auto_label
return label
def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
    """Create the subplot grid and figure, then label and decorate each axes.

    The sharing, faceting, and pairing specs on ``p`` are combined into a
    `Subplots` object (stored on ``self._subplots``), which initializes the
    figure (stored on ``self._figure``). Every subplot then gets its axis
    labels, axis-label / tick-label visibility, and title set.

    Parameters
    ----------
    p
        Plot spec supplying the subplot, facet, pair, figure, sharing, and
        label settings, plus the optional target to draw into.
    common
        Data shared across layers; its frame supplies the facet level order
        and its names supply default labels.
    layers
        Layer specs; the names on each layer's data are fallbacks for labels.
    """
    # --- Parsing the faceting/pairing parameterization to specify figure grid
    subplot_spec = p._subplot_spec.copy()
    facet_spec = p._facet_spec.copy()
    pair_spec = p._pair_spec.copy()
    for axis in "xy":
        if axis in p._shares:
            subplot_spec[f"share{axis}"] = p._shares[axis]
    for dim in ["col", "row"]:
        if dim in common.frame and dim not in facet_spec["structure"]:
            order = categorical_order(common.frame[dim])
            # NOTE(review): the spec copies above look shallow, so this
            # assignment presumably also updates the "structure" mapping
            # held by p._facet_spec; confirm that is intended.
            facet_spec["structure"][dim] = order
    self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec)
    # --- Figure initialization
    self._figure = subplots.init_figure(
        pair_spec, self._pyplot, p._figure_spec, p._target,
    )
    # --- Figure annotation
    for sub in subplots:
        ax = sub["ax"]
        for axis in "xy":
            axis_key = sub[axis]
            # ~~ Axis labels
            # TODO Should we make it possible to use only one x/y label for
            # all rows/columns in a faceted plot? Maybe using sub{axis}label,
            # although the alignments of the labels from that method leaves
            # something to be desired (in terms of how it defines 'centered').
            # The automatic label is the first non-None name for this variable,
            # looking at the common data first and then each layer in order.
            names = [
                common.names.get(axis_key),
                *(layer["data"].names.get(axis_key) for layer in layers)
            ]
            auto_label = next((name for name in names if name is not None), None)
            label = self._resolve_label(p, axis_key, auto_label)
            ax.set(**{f"{axis}label": label})
            # ~~ Decoration visibility
            # TODO there should be some override (in Plot.layout?) so that
            # axis / tick labels can be shown on interior shared axes if desired
            axis_obj = getattr(ax, f"{axis}axis")
            visible_side = {"x": "bottom", "y": "left"}.get(axis)
            # The axis label is shown on subplots flagged for the visible side,
            # whenever pairing is not crossed, or when this axis is paired
            # and the pairing is wrapped.
            show_axis_label = (
                sub[visible_side]
                or not p._pair_spec.get("cross", True)
                or (
                    axis in p._pair_spec.get("structure", {})
                    and bool(p._pair_spec.get("wrap"))
                )
            )
            axis_obj.get_label().set_visible(show_axis_label)
            # Tick labels follow the axis label, and are also shown when the
            # axis is not shared across the whole grid or along its own
            # dimension (columns for x, rows for y).
            show_tick_labels = (
                show_axis_label
                or subplot_spec.get(f"share{axis}") not in (
                    True, "all", {"x": "col", "y": "row"}[axis]
                )
            )
            for group in ("major", "minor"):
                for t in getattr(axis_obj, f"get_{group}ticklabels")():
                    t.set_visible(show_tick_labels)
        # TODO we want right-side titles for row facets in most cases?
        # Let's have what we currently call "margin titles" but properly using the
        # ax.set_title interface (see my gist)
        title_parts = []
        for dim in ["col", "row"]:
            if sub[dim] is not None:
                val = self._resolve_label(p, "title", f"{sub[dim]}")
                if dim in p._labels:
                    # Prefix the level with the facet variable's label
                    key = self._resolve_label(p, dim, common.names.get(dim))
                    val = f"{key} {val}"
                title_parts.append(val)
        has_col = sub["col"] is not None
        has_row = sub["row"] is not None
        show_title = (
            has_col and has_row
            or (has_col or has_row) and p._facet_spec.get("wrap")
            or (has_col and sub["top"])
            # TODO or has_row and sub["right"] and right-side titles are in use
            or has_row  # TODO and not using right-side titles
        )
        if title_parts:
            title = " | ".join(title_parts)
            title_text = ax.set_title(title)
            title_text.set_visible(show_title)
        elif not (has_col or has_row):
            # Unfaceted subplot: use only a manually specified title, if any
            title = self._resolve_label(p, "title", None)
            title_text = ax.set_title(title)
def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
    """Run each layer's stat transform and store the result on the layer data.

    Layers without a stat are left alone. For unpaired plots the transformed
    frame replaces ``data.frame``. For paired plots ``data.frame`` is emptied
    and one transformed frame per (x, y) pairing is stored in ``data.frames``,
    keyed by the pair of variable names.
    """
    grouping_vars = [
        *(v for v in PROPERTIES if v not in "xy"), "col", "row", "group",
    ]
    pair_vars = spec._pair_spec.get("structure", {})

    for layer in layers:

        data, mark, stat = layer["data"], layer["mark"], layer["stat"]
        if stat is None:
            continue

        untransformed = data.frame
        if pair_vars:
            data.frames = {}
            data.frame = data.frame.iloc[:0]  # TODO to simplify typing

        x_vars = pair_vars.get("x", ["x"])
        y_vars = pair_vars.get("y", ["y"])
        for coord_vars in itertools.product(x_vars, y_vars):

            df = untransformed.copy()
            scales = self._scales.copy()

            for axis, var in zip("xy", coord_vars):
                if axis == var:
                    continue
                # Present the paired variable under the plain axis name and
                # remove the other paired columns for that axis.
                df = df.rename(columns={var: axis})
                paired_cols = [c for c in df if re.match(rf"{axis}\d+", str(c))]
                df = df.drop(paired_cols, axis=1)
                scales[axis] = scales[var]

            orient = layer["orient"] or mark._infer_orient(scales)
            grouper = (
                [orient, *grouping_vars] if stat.group_by_orient else grouping_vars
            )
            result = stat(df, GroupBy(grouper), orient, scales)

            if pair_vars:
                data.frames[coord_vars] = result
            else:
                data.frame = result
def _get_scale(
self, p: Plot, var: str, prop: Property, values: Series
) -> Scale:
if re.match(r"[xy]\d+", var):
key = var if var in p._scales else var[0]
else:
key = var
if key in p._scales:
arg = p._scales[key]
if arg is None or isinstance(arg, Scale):
scale = arg
else:
scale = prop.infer_scale(arg, values)
else:
scale = prop.default_scale(values)
return scale
def _get_subplot_data(self, df, var, view, share_state):
if share_state in [True, "all"]:
# The all-shared case is easiest, every subplot sees all the data
seed_values = df[var]
else:
# Otherwise, we need to setup separate scales for different subplots
if share_state in [False, "none"]:
# Fully independent axes are also easy: use each subplot's data
idx = self._get_subplot_index(df, view)
elif share_state in df:
# Sharing within row/col is more complicated
use_rows = df[share_state] == view[share_state]
idx = df.index[use_rows]
else:
# This configuration doesn't make much sense, but it's fine
idx = df.index
seed_values = df.loc[idx, var]
return seed_values
def _setup_scales(
self,
p: Plot,
common: PlotData,
layers: list[Layer],
variables: list[str] | None = None,
) -> None:
if variables is None:
# Add variables that have data but not a scale, which happens
# because this method can be called multiple time, to handle
# variables added during the Stat transform.
variables = []
for layer in layers:
variables.extend(layer["data"].frame.columns)
for df in layer["data"].frames.values():
variables.extend(str(v) for v in df if v not in variables)
variables = [v for v in variables if v not in self._scales]
for var in variables:
# Determine whether this is a coordinate variable
# (i.e., x/y, paired x/y, or derivative such as xmax)
m = re.match(r"^(?P(?Px|y)\d*).*", var)
if m is None:
coord = axis = None
else:
coord = m["coord"]
axis = m["axis"]
# Get keys that handle things like x0, xmax, properly where relevant
prop_key = var if axis is None else axis
scale_key = var if coord is None else coord
if prop_key not in PROPERTIES:
continue
# Concatenate layers, using only the relevant coordinate and faceting vars,
# This is unnecessarily wasteful, as layer data will often be redundant.
# But figuring out the minimal amount we need is more complicated.
cols = [var, "col", "row"]
parts = [common.frame.filter(cols)]
for layer in layers:
parts.append(layer["data"].frame.filter(cols))
for df in layer["data"].frames.values():
parts.append(df.filter(cols))
var_df = pd.concat(parts, ignore_index=True)
prop = PROPERTIES[prop_key]
scale = self._get_scale(p, scale_key, prop, var_df[var])
if scale_key not in p._variables:
# TODO this implies that the variable was added by the stat
# It allows downstream orientation inference to work properly.
# But it feels rather hacky, so ideally revisit.
scale._priority = 0 # type: ignore
if axis is None:
# We could think about having a broader concept of (un)shared properties
# In general, not something you want to do (different scales in facets)
# But could make sense e.g. with paired plots. Build later.
share_state = None
subplots = []
else:
share_state = self._subplots.subplot_spec[f"share{axis}"]
subplots = [view for view in self._subplots if view[axis] == coord]
# Shared categorical axes are broken on matplotlib<3.4.0.
# https://github.com/matplotlib/matplotlib/pull/18308
# This only affects us when sharing *paired* axes. This is a novel/niche
# behavior, so we will raise rather than hack together a workaround.
if axis is not None and _version_predates(mpl, "3.4"):
paired_axis = axis in p._pair_spec.get("structure", {})
cat_scale = isinstance(scale, Nominal)
ok_dim = {"x": "col", "y": "row"}[axis]
shared_axes = share_state not in [False, "none", ok_dim]
if paired_axis and cat_scale and shared_axes:
err = "Sharing paired categorical axes requires matplotlib>=3.4.0"
raise RuntimeError(err)
if scale is None:
self._scales[var] = Scale._identity()
else:
try:
self._scales[var] = scale._setup(var_df[var], prop)
except Exception as err:
raise PlotSpecError._during("Scale setup", var) from err
if axis is None or (var != coord and coord in p._variables):
# Everything below here applies only to coordinate variables
continue
# Set up an empty series to receive the transformed values.
# We need this to handle piecemeal transforms of categories -> floats.
transformed_data = []
for layer in layers:
index = layer["data"].frame.index
empty_series = pd.Series(dtype=float, index=index, name=var)
transformed_data.append(empty_series)
for view in subplots:
axis_obj = getattr(view["ax"], f"{axis}axis")
seed_values = self._get_subplot_data(var_df, var, view, share_state)
view_scale = scale._setup(seed_values, prop, axis=axis_obj)
set_scale_obj(view["ax"], axis, view_scale._matplotlib_scale)
for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var not in layer_df:
continue
idx = self._get_subplot_index(layer_df, view)
try:
new_series.loc[idx] = view_scale(layer_df.loc[idx, var])
except Exception as err:
spec_error = PlotSpecError._during("Scaling operation", var)
raise spec_error from err
# Now the transformed data series are complete, update the layer data
for layer, new_series in zip(layers, transformed_data):
layer_df = layer["data"].frame
if var in layer_df:
layer_df[var] = pd.to_numeric(new_series)
def _plot_layer(self, p: Plot, layer: Layer) -> None:
    """Draw one layer: prepare its data for each pairing and call the mark.

    For every (x, y) pairing this adds ``width`` and ``baseline`` columns,
    applies the layer's move transform(s), converts coordinates back from
    scaled units, and passes a split generator to ``mark._plot``. Afterwards
    every subplot is autoscaled and, if the layer requests it, the legend
    contents are updated.
    """
    data = layer["data"]
    mark = layer["mark"]
    move = layer["move"]
    default_grouping_vars = ["col", "row", "group"]  # TODO where best to define?
    grouping_properties = [v for v in PROPERTIES if v[0] not in "xy"]
    pair_variables = p._pair_spec.get("structure", {})
    for subplots, df, scales in self._generate_pairings(data, pair_variables):
        orient = layer["orient"] or mark._infer_orient(scales)

        def get_order(var):
            # Ignore order for x/y: they have been scaled to numeric indices,
            # so any original order is no longer valid. Default ordering rules
            # sorted unique numbers will correctly reconstruct intended order
            # TODO This is tricky, make sure we add some tests for this
            # (Falls through to an implicit None for x/y and unscaled variables.)
            if var not in "xy" and var in scales:
                return getattr(scales[var], "order", None)

        if orient in df:
            # Compute the width per subplot, as the resolved width multiplied
            # by the spacing that the orient scale reports for that
            # subplot's values.
            width = pd.Series(index=df.index, dtype=float)
            for view in subplots:
                view_idx = self._get_subplot_data(
                    df, orient, view, p._shares.get(orient)
                ).index
                view_df = df.loc[view_idx]
                if "width" in mark._mappable_props:
                    view_width = mark._resolve(view_df, "width", None)
                elif "width" in df:
                    view_width = view_df["width"]
                else:
                    view_width = 0.8  # TODO what default?
                spacing = scales[orient]._spacing(view_df.loc[view_idx, orient])
                width.loc[view_idx] = view_width * spacing
            df["width"] = width
        if "baseline" in mark._mappable_props:
            # TODO what marks should have this?
            # If we can set baseline with, e.g., Bar(), then the
            # "other" (e.g. y for x oriented bars) parameterization
            # is somewhat ambiguous.
            baseline = mark._resolve(df, "baseline", None)
        else:
            # TODO unlike width, we might not want to add baseline to data
            # if the mark doesn't use it. Practically, there is a concern about
            # Mark abstraction like Area / Ribbon
            baseline = 0 if "baseline" not in df else df["baseline"]
        df["baseline"] = baseline
        if move is not None:
            # A single move is treated as a one-step list; steps run in order,
            # each receiving the output of the previous one.
            moves = move if isinstance(move, list) else [move]
            for move_step in moves:
                move_by = getattr(move_step, "by", None)
                if move_by is None:
                    move_by = grouping_properties
                move_groupers = [*move_by, *default_grouping_vars]
                if move_step.group_by_orient:
                    move_groupers.insert(0, orient)
                order = {var: get_order(var) for var in move_groupers}
                groupby = GroupBy(order)
                df = move_step(df, groupby, orient, scales)
        df = self._unscale_coords(subplots, df, orient)
        grouping_vars = mark._grouping_props + default_grouping_vars
        split_generator = self._setup_split_generator(grouping_vars, df, subplots)
        mark._plot(split_generator, scales, orient)
    # TODO is this the right place for this?
    for view in self._subplots:
        view["ax"].autoscale_view()
    if layer["legend"]:
        # Note: `scales` here is the mapping from the last pairing of the loop
        self._update_legend_contents(p, mark, data, scales, layer["label"])
def _unscale_coords(
self, subplots: list[dict], df: DataFrame, orient: str,
) -> DataFrame:
# TODO do we still have numbers in the variable name at this point?
coord_cols = [c for c in df if re.match(r"^[xy]\D*$", str(c))]
out_df = (
df
.drop(coord_cols, axis=1)
.reindex(df.columns, axis=1) # So unscaled columns retain their place
.copy(deep=False)
)
for view in subplots:
view_df = self._filter_subplot_data(df, view)
axes_df = view_df[coord_cols]
for var, values in axes_df.items():
axis = getattr(view["ax"], f"{str(var)[0]}axis")
# TODO see https://github.com/matplotlib/matplotlib/issues/22713
transform = axis.get_transform().inverted().transform
inverted = transform(values)
out_df.loc[values.index, str(var)] = inverted
return out_df
def _generate_pairings(
self, data: PlotData, pair_variables: dict,
) -> Generator[
tuple[list[dict], DataFrame, dict[str, Scale]], None, None
]:
# TODO retype return with subplot_spec or similar
iter_axes = itertools.product(*[
pair_variables.get(axis, [axis]) for axis in "xy"
])
for x, y in iter_axes:
subplots = []
for view in self._subplots:
if (view["x"] == x) and (view["y"] == y):
subplots.append(view)
if data.frame.empty and data.frames:
out_df = data.frames[(x, y)].copy()
elif not pair_variables:
out_df = data.frame.copy()
else:
if data.frame.empty and data.frames:
out_df = data.frames[(x, y)].copy()
else:
out_df = data.frame.copy()
scales = self._scales.copy()
if x in out_df:
scales["x"] = self._scales[x]
if y in out_df:
scales["y"] = self._scales[y]
for axis, var in zip("xy", (x, y)):
if axis != var:
out_df = out_df.rename(columns={var: axis})
cols = [col for col in out_df if re.match(rf"{axis}\d+", str(col))]
out_df = out_df.drop(cols, axis=1)
yield subplots, out_df, scales
def _get_subplot_index(self, df: DataFrame, subplot: dict) -> Index:
dims = df.columns.intersection(["col", "row"])
if dims.empty:
return df.index
keep_rows = pd.Series(True, df.index, dtype=bool)
for dim in dims:
keep_rows &= df[dim] == subplot[dim]
return df.index[keep_rows]
def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame:
# TODO note redundancies with preceding function ... needs refactoring
dims = df.columns.intersection(["col", "row"])
if dims.empty:
return df
keep_rows = pd.Series(True, df.index, dtype=bool)
for dim in dims:
keep_rows &= df[dim] == subplot[dim]
return df[keep_rows]
def _setup_split_generator(
    self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]],
) -> Callable[[], Generator]:
    """Return a generator function that splits ``df`` by subplot and group.

    The returned function yields ``(keys, data, ax)`` tuples: ``keys`` maps
    grouping variables and facet dimensions to the values for that subset,
    ``data`` is a copy of the matching rows, and ``ax`` is the subplot's axes.
    Empty subsets are skipped.

    Parameters
    ----------
    grouping_vars
        Candidate grouping variables; only those present in ``df`` (and other
        than ``col`` / ``row``) are used.
    df
        Data for the layer, filtered per subplot before grouping.
    subplots
        Subplot specs to iterate over.
    """
    grouping_keys = []
    grouping_vars = [
        v for v in grouping_vars if v in df and v not in ["col", "row"]
    ]
    for var in grouping_vars:
        # Prefer the order recorded on the scale; otherwise derive it from data
        order = getattr(self._scales[var], "order", None)
        if order is None:
            order = categorical_order(df[var])
        grouping_keys.append(order)

    def split_generator(keep_na=False) -> Generator:
        for view in subplots:
            axes_df = self._filter_subplot_data(df, view)
            # Treat infinite values as missing for the null handling below
            axes_df_inf_as_nan = axes_df.copy()
            axes_df_inf_as_nan = axes_df_inf_as_nan.mask(
                axes_df_inf_as_nan.isin([np.inf, -np.inf]), np.nan
            )
            if keep_na:
                # The simpler thing to do would be x.dropna().reindex(x.index).
                # But that doesn't work with the way that the subset iteration
                # is written below, which assumes data for grouping vars.
                # Matplotlib (usually?) masks nan data, so this should "work".
                # Downstream code can also drop these rows, at some speed cost.
                present = axes_df_inf_as_nan.notna().all(axis=1)
                nulled = {}
                for axis in "xy":
                    if axis in axes_df:
                        nulled[axis] = axes_df[axis].where(present)
                axes_df = axes_df_inf_as_nan.assign(**nulled)
            else:
                axes_df = axes_df_inf_as_nan.dropna()
            subplot_keys = {}
            for dim in ["col", "row"]:
                if view[dim] is not None:
                    subplot_keys[dim] = view[dim]
            if not grouping_vars or not any(grouping_keys):
                # Nothing to group by: yield the subplot's data as one subset
                if not axes_df.empty:
                    yield subplot_keys, axes_df.copy(), view["ax"]
                continue
            grouped_df = axes_df.groupby(
                grouping_vars, sort=False, as_index=False, observed=False,
            )
            # Iterate over every combination of grouping levels, in order
            for key in itertools.product(*grouping_keys):
                # Pandas fails with singleton tuple inputs
                pd_key = key[0] if len(key) == 1 else key
                try:
                    df_subset = grouped_df.get_group(pd_key)
                except KeyError:
                    # TODO (from initial work on categorical plots refactor)
                    # We are adding this to allow backwards compatibility
                    # with the empty artists that old categorical plots would
                    # add (before 0.12), which we may decide to break, in which
                    # case this option could be removed
                    df_subset = axes_df.loc[[]]
                if df_subset.empty:
                    continue
                sub_vars = dict(zip(grouping_vars, key))
                sub_vars.update(subplot_keys)
                # TODO need copy(deep=...) policy (here, above, anywhere else?)
                yield sub_vars, df_subset.copy(), view["ax"]

    return split_generator
def _update_legend_contents(
self,
p: Plot,
mark: Mark,
data: PlotData,
scales: dict[str, Scale],
layer_label: str | None,
) -> None:
"""Add legend artists / labels for one layer in the plot."""
if data.frame.empty and data.frames:
legend_vars: list[str] = []
for frame in data.frames.values():
frame_vars = frame.columns.intersection(list(scales))
legend_vars.extend(v for v in frame_vars if v not in legend_vars)
else:
legend_vars = list(data.frame.columns.intersection(list(scales)))
# First handle layer legends, which occupy a single entry in legend_contents.
if layer_label is not None:
legend_title = str(p._labels.get("legend", ""))
layer_key = (legend_title, -1)
artist = mark._legend_artist([], None, {})
if artist is not None:
for content in self._legend_contents:
if content[0] == layer_key:
content[1].append(artist)
content[2].append(layer_label)
break
else:
self._legend_contents.append((layer_key, [artist], [layer_label]))
# Then handle the scale legends
# First pass: Identify the values that will be shown for each variable
schema: list[tuple[
tuple[str, str | int], list[str], tuple[list[Any], list[str]]
]] = []
schema = []
for var in legend_vars:
var_legend = scales[var]._legend
if var_legend is not None:
values, labels = var_legend
for (_, part_id), part_vars, _ in schema:
if data.ids[var] == part_id:
# Allow multiple plot semantics to represent same data variable
part_vars.append(var)
break
else:
title = self._resolve_label(p, var, data.names[var])
entry = (title, data.ids[var]), [var], (values, labels)
schema.append(entry)
# Second pass, generate an artist corresponding to each value
contents: list[tuple[tuple[str, str | int], Any, list[str]]] = []
for key, variables, (values, labels) in schema:
artists = []
for val in values:
artist = mark._legend_artist(variables, val, scales)
if artist is not None:
artists.append(artist)
if artists:
contents.append((key, artists, labels))
self._legend_contents.extend(contents)
def _make_legend(self, p: Plot) -> None:
    """Create the legend artist(s) and add onto the figure.

    Entries in ``self._legend_contents`` with the same key are merged so that
    their artists are overlaid, one legend is built per merged entry, and the
    contents of every later legend are appended into the first one, which is
    the only legend attached to the figure.
    """
    # Combine artists representing same information across layers
    # Input list has an entry for each distinct variable in each layer
    # Output dict has an entry for each distinct variable
    merged_contents: dict[
        tuple[str, str | int], tuple[list[tuple[Artist, ...]], list[str]],
    ] = {}
    for key, new_artists, labels in self._legend_contents:
        # Key is (name, id); we need the id to resolve variable uniqueness,
        # but will need the name in the next step to title the legend
        if key in merged_contents:
            existing_artists = merged_contents[key][0]
            for i, new_artist in enumerate(new_artists):
                existing_artists[i] += (new_artist,)
        else:
            # Matplotlib accepts a tuple of artists and will overlay them
            merged_contents[key] = [(a,) for a in new_artists], labels

    # When using pyplot, an "external" legend won't be shown, so this
    # keeps it inside the axes (though still attached to the figure)
    # This is necessary because matplotlib layout engines currently don't
    # support figure legends — ideally this will change.
    loc = "center right" if self._pyplot else "center left"

    base_legend = None
    for (name, _), (handles, labels) in merged_contents.items():

        legend = mpl.legend.Legend(
            self._figure,
            handles,  # type: ignore  # matplotlib/issues/26639
            labels,
            title=name,
            loc=loc,
            bbox_to_anchor=(.98, .55),
        )

        if not base_legend:
            # The first legend becomes the container shown on the figure
            base_legend = legend
            self._figure.legends.append(legend)
            continue

        # Matplotlib has no public API for this so it is a bit of a hack.
        # Ideally we'd define our own legend class with more flexibility,
        # but that is a lot of work!
        base_legend_box = base_legend.get_children()[0]
        this_legend_box = legend.get_children()[0]
        base_legend_box.get_children().extend(this_legend_box.get_children())
def _finalize_figure(self, p: Plot) -> None:
    """Apply axis limits, finalize scales, and set the figure's layout engine.

    Limits from ``p._limits`` are looked up by the subplot's own variable
    first and by the plain axis name otherwise. Each limit is converted with
    the axis' unit converter, and limits given as strings are padded outward
    by 0.5 after conversion.
    """
    for sub in self._subplots:
        ax = sub["ax"]
        for axis in "xy":
            axis_key = sub[axis]
            axis_obj = getattr(ax, f"{axis}axis")

            # Axis limits
            if axis_key in p._limits or axis in p._limits:
                convert_units = axis_obj.convert_units
                a, b = p._limits.get(axis_key) or p._limits[axis]
                lo = None if a is None else convert_units(a)
                hi = None if b is None else convert_units(b)
                if isinstance(a, str):
                    lo = cast(float, lo) - 0.5
                if isinstance(b, str):
                    hi = cast(float, hi) + 0.5
                ax.set(**{f"{axis}lim": (lo, hi)})

            if axis_key in self._scales:  # TODO when would it not be?
                self._scales[axis_key]._finalize(p, axis_obj)

    engine = p._layout_spec.get("engine", default)
    if engine is not default:
        # None is a valid arg for Figure.set_layout_engine, hence `default`
        set_layout_engine(self._figure, engine)
    elif p._target is None:
        # Don't modify the layout engine if the user supplied their own
        # matplotlib figure and didn't specify an engine through Plot
        # TODO switch default to "constrained"?
        # TODO either way, make configurable
        set_layout_engine(self._figure, "tight")