plot.py 66 KB


  1. """The classes for specifying and compiling a declarative visualization."""
  2. from __future__ import annotations
  3. import io
  4. import os
  5. import re
  6. import inspect
  7. import itertools
  8. import textwrap
  9. from contextlib import contextmanager
  10. from collections import abc
  11. from collections.abc import Callable, Generator
  12. from typing import Any, List, Literal, Optional, cast
  13. from xml.etree import ElementTree
  14. from cycler import cycler
  15. import pandas as pd
  16. from pandas import DataFrame, Series, Index
  17. import matplotlib as mpl
  18. from matplotlib.axes import Axes
  19. from matplotlib.artist import Artist
  20. from matplotlib.figure import Figure
  21. import numpy as np
  22. from PIL import Image
  23. from seaborn._marks.base import Mark
  24. from seaborn._stats.base import Stat
  25. from seaborn._core.data import PlotData
  26. from seaborn._core.moves import Move
  27. from seaborn._core.scales import Scale, Nominal
  28. from seaborn._core.subplots import Subplots
  29. from seaborn._core.groupby import GroupBy
  30. from seaborn._core.properties import PROPERTIES, Property
  31. from seaborn._core.typing import (
  32. DataSource,
  33. VariableSpec,
  34. VariableSpecList,
  35. OrderSpec,
  36. Default,
  37. )
  38. from seaborn._core.exceptions import PlotSpecError
  39. from seaborn._core.rules import categorical_order
  40. from seaborn._compat import set_scale_obj, set_layout_engine
  41. from seaborn.rcmod import axes_style, plotting_context
  42. from seaborn.palettes import color_palette
  43. from seaborn.utils import _version_predates
  44. from typing import TYPE_CHECKING, TypedDict
  45. if TYPE_CHECKING:
  46. from matplotlib.figure import SubFigure
# Sentinel for "argument not passed", so that an explicit ``None`` can still be
# given as a meaningful value (e.g. ``engine=None`` in ``Plot.layout``).
default = Default()
# ---- Definitions for internal specs ---------------------------------------------- #


class Layer(TypedDict, total=False):
    """Specification of one plot layer, as assembled by :meth:`Plot.add`."""

    # The visual representation used to draw this layer
    mark: Mark  # TODO allow list?
    # Optional statistical transform (Plot.add allows at most one)
    stat: Stat | None  # TODO allow list?
    # Plot.add stores either None or a list of Move objects here
    move: Move | list[Move] | None
    # Not set by Plot.add; presumably filled in during compilation — confirm
    data: PlotData
    # Layer-specific data source given to Plot.add(data=...), or None
    source: DataSource
    # Layer-specific variable assignments given to Plot.add(**variables)
    vars: dict[str, VariableSpec]
    # "x" / "y" (Plot.add maps "v" -> "x" and "h" -> "y"); None when not given
    orient: str
    # False suppresses this layer's mark/mappings in the legend
    legend: bool
    # Label for the layer in the legend, independent of any mappings
    label: str | None
class FacetSpec(TypedDict, total=False):
    """Faceting specification, as assembled by :meth:`Plot.facet`."""

    # Faceting variables keyed by dimension name ("col" and/or "row")
    variables: dict[str, VariableSpec]
    # Explicit level order per dimension ("col" / "row"), when one was given
    structure: dict[str, list[str]]
    # Number of subplots before wrapping (single-dimension faceting), or None
    wrap: int | None
class PairSpec(TypedDict, total=False):
    """Variable-pairing specification, as assembled by :meth:`Plot.pair`."""

    # Paired variables keyed as "x0", "x1", ..., "y0", "y1", ...
    variables: dict[str, VariableSpec]
    # For each axis ("x" / "y") that was given, the ordered list of those keys
    structure: dict[str, list[str]]
    # True: cartesian product of x and y lists; False: zip them pairwise
    cross: bool
    # Number of subplots before wrapping (when pairing only one axis), or None
    wrap: int | None
  68. # --- Local helpers ---------------------------------------------------------------- #
  69. @contextmanager
  70. def theme_context(params: dict[str, Any]) -> Generator:
  71. """Temporarily modify specifc matplotlib rcParams."""
  72. orig_params = {k: mpl.rcParams[k] for k in params}
  73. color_codes = "bgrmyck"
  74. nice_colors = [*color_palette("deep6"), (.15, .15, .15)]
  75. orig_colors = [mpl.colors.colorConverter.colors[x] for x in color_codes]
  76. # TODO how to allow this to reflect the color cycle when relevant?
  77. try:
  78. mpl.rcParams.update(params)
  79. for (code, color) in zip(color_codes, nice_colors):
  80. mpl.colors.colorConverter.colors[code] = color
  81. yield
  82. finally:
  83. mpl.rcParams.update(orig_params)
  84. for (code, color) in zip(color_codes, orig_colors):
  85. mpl.colors.colorConverter.colors[code] = color
  86. def build_plot_signature(cls):
  87. """
  88. Decorator function for giving Plot a useful signature.
  89. Currently this mostly saves us some duplicated typing, but we would
  90. like eventually to have a way of registering new semantic properties,
  91. at which point dynamic signature generation would become more important.
  92. """
  93. sig = inspect.signature(cls)
  94. params = [
  95. inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL),
  96. inspect.Parameter("data", inspect.Parameter.KEYWORD_ONLY, default=None)
  97. ]
  98. params.extend([
  99. inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=None)
  100. for name in PROPERTIES
  101. ])
  102. new_sig = sig.replace(parameters=params)
  103. cls.__signature__ = new_sig
  104. known_properties = textwrap.fill(
  105. ", ".join([f"|{p}|" for p in PROPERTIES]),
  106. width=78, subsequent_indent=" " * 8,
  107. )
  108. if cls.__doc__ is not None: # support python -OO mode
  109. cls.__doc__ = cls.__doc__.format(known_properties=known_properties)
  110. return cls
  111. # ---- Plot configuration ---------------------------------------------------------- #
  112. class ThemeConfig(mpl.RcParams):
  113. """
  114. Configuration object for the Plot.theme, using matplotlib rc parameters.
  115. """
  116. THEME_GROUPS = [
  117. "axes", "figure", "font", "grid", "hatch", "legend", "lines",
  118. "mathtext", "markers", "patch", "savefig", "scatter",
  119. "xaxis", "xtick", "yaxis", "ytick",
  120. ]
  121. def __init__(self):
  122. super().__init__()
  123. self.reset()
  124. @property
  125. def _default(self) -> dict[str, Any]:
  126. return {
  127. **self._filter_params(mpl.rcParamsDefault),
  128. **axes_style("darkgrid"),
  129. **plotting_context("notebook"),
  130. "axes.prop_cycle": cycler("color", color_palette("deep")),
  131. }
  132. def reset(self) -> None:
  133. """Update the theme dictionary with seaborn's default values."""
  134. self.update(self._default)
  135. def update(self, other: dict[str, Any] | None = None, /, **kwds):
  136. """Update the theme with a dictionary or keyword arguments of rc parameters."""
  137. if other is not None:
  138. theme = self._filter_params(other)
  139. else:
  140. theme = {}
  141. theme.update(kwds)
  142. super().update(theme)
  143. def _filter_params(self, params: dict[str, Any]) -> dict[str, Any]:
  144. """Restruct to thematic rc params."""
  145. return {
  146. k: v for k, v in params.items()
  147. if any(k.startswith(p) for p in self.THEME_GROUPS)
  148. }
  149. def _html_table(self, params: dict[str, Any]) -> list[str]:
  150. lines = ["<table>"]
  151. for k, v in params.items():
  152. row = f"<tr><td>{k}:</td><td style='text-align:left'>{v!r}</td></tr>"
  153. lines.append(row)
  154. lines.append("</table>")
  155. return lines
  156. def _repr_html_(self) -> str:
  157. repr = [
  158. "<div style='height: 300px'>",
  159. "<div style='border-style: inset; border-width: 2px'>",
  160. *self._html_table(self),
  161. "</div>",
  162. "</div>",
  163. ]
  164. return "\n".join(repr)
class DisplayConfig(TypedDict):
    """Configuration for IPython's rich display hooks."""

    # Image format produced for notebook display (checked by Plot._repr_png_ / _repr_svg_)
    format: Literal["png", "svg"]
    # Relative scaling of the embedded image
    scaling: float
    # When True, double the DPI while preserving the displayed size
    hidpi: bool
  170. class PlotConfig:
  171. """Configuration for default behavior / appearance of class:`Plot` instances."""
  172. def __init__(self):
  173. self._theme = ThemeConfig()
  174. self._display = {"format": "png", "scaling": .85, "hidpi": True}
  175. @property
  176. def theme(self) -> dict[str, Any]:
  177. """
  178. Dictionary of base theme parameters for :class:`Plot`.
  179. Keys and values correspond to matplotlib rc params, as documented here:
  180. https://matplotlib.org/stable/tutorials/introductory/customizing.html
  181. """
  182. return self._theme
  183. @property
  184. def display(self) -> DisplayConfig:
  185. """
  186. Dictionary of parameters for rich display in Jupyter notebook.
  187. Valid parameters:
  188. - format ("png" or "svg"): Image format to produce
  189. - scaling (float): Relative scaling of embedded image
  190. - hidpi (bool): When True, double the DPI while preserving the size
  191. """
  192. return self._display
  193. # ---- The main interface for declarative plotting --------------------------------- #
  194. @build_plot_signature
  195. class Plot:
  196. """
  197. An interface for declaratively specifying statistical graphics.
  198. Plots are constructed by initializing this class and adding one or more
  199. layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally,
  200. faceting variables or variable pairings may be defined to divide the space
  201. into multiple subplots. The mappings from data values to visual properties
  202. can be parametrized using scales, although the plot will try to infer good
  203. defaults when scales are not explicitly defined.
  204. The constructor accepts a data source (a :class:`pandas.DataFrame` or
  205. dictionary with columnar values) and variable assignments. Variables can be
  206. passed as keys to the data source or directly as data vectors. If multiple
  207. data-containing objects are provided, they will be index-aligned.
  208. The data source and variables defined in the constructor will be used for
  209. all layers in the plot, unless overridden or disabled when adding a layer.
  210. The following variables can be defined in the constructor:
  211. {known_properties}
  212. The `data`, `x`, and `y` variables can be passed as positional arguments or
  213. using keywords. Whether the first positional argument is interpreted as a
  214. data source or `x` variable depends on its type.
  215. The methods of this class return a copy of the instance; use chaining to
  216. build up a plot through multiple calls. Methods can be called in any order.
  217. Most methods only add information to the plot spec; no actual processing
  218. happens until the plot is shown or saved. It is also possible to compile
  219. the plot without rendering it to access the lower-level representation.
  220. """
  221. config = PlotConfig()
  222. _data: PlotData
  223. _layers: list[Layer]
  224. _scales: dict[str, Scale]
  225. _shares: dict[str, bool | str]
  226. _limits: dict[str, tuple[Any, Any]]
  227. _labels: dict[str, str | Callable[[str], str]]
  228. _theme: dict[str, Any]
  229. _facet_spec: FacetSpec
  230. _pair_spec: PairSpec
  231. _figure_spec: dict[str, Any]
  232. _subplot_spec: dict[str, Any]
  233. _layout_spec: dict[str, Any]
  234. def __init__(
  235. self,
  236. *args: DataSource | VariableSpec,
  237. data: DataSource = None,
  238. **variables: VariableSpec,
  239. ):
  240. if args:
  241. data, variables = self._resolve_positionals(args, data, variables)
  242. unknown = [x for x in variables if x not in PROPERTIES]
  243. if unknown:
  244. err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}"
  245. raise TypeError(err)
  246. self._data = PlotData(data, variables)
  247. self._layers = []
  248. self._scales = {}
  249. self._shares = {}
  250. self._limits = {}
  251. self._labels = {}
  252. self._theme = {}
  253. self._facet_spec = {}
  254. self._pair_spec = {}
  255. self._figure_spec = {}
  256. self._subplot_spec = {}
  257. self._layout_spec = {}
  258. self._target = None
  259. def _resolve_positionals(
  260. self,
  261. args: tuple[DataSource | VariableSpec, ...],
  262. data: DataSource,
  263. variables: dict[str, VariableSpec],
  264. ) -> tuple[DataSource, dict[str, VariableSpec]]:
  265. """Handle positional arguments, which may contain data / x / y."""
  266. if len(args) > 3:
  267. err = "Plot() accepts no more than 3 positional arguments (data, x, y)."
  268. raise TypeError(err)
  269. if (
  270. isinstance(args[0], (abc.Mapping, pd.DataFrame))
  271. or hasattr(args[0], "__dataframe__")
  272. ):
  273. if data is not None:
  274. raise TypeError("`data` given by both name and position.")
  275. data, args = args[0], args[1:]
  276. if len(args) == 2:
  277. x, y = args
  278. elif len(args) == 1:
  279. x, y = *args, None
  280. else:
  281. x = y = None
  282. for name, var in zip("yx", (y, x)):
  283. if var is not None:
  284. if name in variables:
  285. raise TypeError(f"`{name}` given by both name and position.")
  286. # Keep coordinates at the front of the variables dict
  287. # Cast type because we know this isn't a DataSource at this point
  288. variables = {name: cast(VariableSpec, var), **variables}
  289. return data, variables
  290. def __add__(self, other):
  291. if isinstance(other, Mark) or isinstance(other, Stat):
  292. raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?")
  293. other_type = other.__class__.__name__
  294. raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}")
  295. def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
  296. if Plot.config.display["format"] != "png":
  297. return None
  298. return self.plot()._repr_png_()
  299. def _repr_svg_(self) -> str | None:
  300. if Plot.config.display["format"] != "svg":
  301. return None
  302. return self.plot()._repr_svg_()
  303. def _clone(self) -> Plot:
  304. """Generate a new object with the same information as the current spec."""
  305. new = Plot()
  306. # TODO any way to enforce that data does not get mutated?
  307. new._data = self._data
  308. new._layers.extend(self._layers)
  309. new._scales.update(self._scales)
  310. new._shares.update(self._shares)
  311. new._limits.update(self._limits)
  312. new._labels.update(self._labels)
  313. new._theme.update(self._theme)
  314. new._facet_spec.update(self._facet_spec)
  315. new._pair_spec.update(self._pair_spec)
  316. new._figure_spec.update(self._figure_spec)
  317. new._subplot_spec.update(self._subplot_spec)
  318. new._layout_spec.update(self._layout_spec)
  319. new._target = self._target
  320. return new
  321. def _theme_with_defaults(self) -> dict[str, Any]:
  322. theme = self.config.theme.copy()
  323. theme.update(self._theme)
  324. return theme
  325. @property
  326. def _variables(self) -> list[str]:
  327. variables = (
  328. list(self._data.frame)
  329. + list(self._pair_spec.get("variables", []))
  330. + list(self._facet_spec.get("variables", []))
  331. )
  332. for layer in self._layers:
  333. variables.extend(v for v in layer["vars"] if v not in variables)
  334. # Coerce to str in return to appease mypy; we know these will only
  335. # ever be strings but I don't think we can type a DataFrame that way yet
  336. return [str(v) for v in variables]
  337. def on(self, target: Axes | SubFigure | Figure) -> Plot:
  338. """
  339. Provide existing Matplotlib figure or axes for drawing the plot.
  340. When using this method, you will also need to explicitly call a method that
  341. triggers compilation, such as :meth:`Plot.show` or :meth:`Plot.save`. If you
  342. want to postprocess using matplotlib, you'd need to call :meth:`Plot.plot`
  343. first to compile the plot without rendering it.
  344. Parameters
  345. ----------
  346. target : Axes, SubFigure, or Figure
  347. Matplotlib object to use. Passing :class:`matplotlib.axes.Axes` will add
  348. artists without otherwise modifying the figure. Otherwise, subplots will be
  349. created within the space of the given :class:`matplotlib.figure.Figure` or
  350. :class:`matplotlib.figure.SubFigure`.
  351. Examples
  352. --------
  353. .. include:: ../docstrings/objects.Plot.on.rst
  354. """
  355. accepted_types: tuple # Allow tuple of various length
  356. if hasattr(mpl.figure, "SubFigure"): # Added in mpl 3.4
  357. accepted_types = (
  358. mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure
  359. )
  360. accepted_types_str = (
  361. f"{mpl.axes.Axes}, {mpl.figure.SubFigure}, or {mpl.figure.Figure}"
  362. )
  363. else:
  364. accepted_types = mpl.axes.Axes, mpl.figure.Figure
  365. accepted_types_str = f"{mpl.axes.Axes} or {mpl.figure.Figure}"
  366. if not isinstance(target, accepted_types):
  367. err = (
  368. f"The `Plot.on` target must be an instance of {accepted_types_str}. "
  369. f"You passed an instance of {target.__class__} instead."
  370. )
  371. raise TypeError(err)
  372. new = self._clone()
  373. new._target = target
  374. return new
  375. def add(
  376. self,
  377. mark: Mark,
  378. *transforms: Stat | Move,
  379. orient: str | None = None,
  380. legend: bool = True,
  381. label: str | None = None,
  382. data: DataSource = None,
  383. **variables: VariableSpec,
  384. ) -> Plot:
  385. """
  386. Specify a layer of the visualization in terms of mark and data transform(s).
  387. This is the main method for specifying how the data should be visualized.
  388. It can be called multiple times with different arguments to define
  389. a plot with multiple layers.
  390. Parameters
  391. ----------
  392. mark : :class:`Mark`
  393. The visual representation of the data to use in this layer.
  394. transforms : :class:`Stat` or :class:`Move`
  395. Objects representing transforms to be applied before plotting the data.
  396. Currently, at most one :class:`Stat` can be used, and it
  397. must be passed first. This constraint will be relaxed in the future.
  398. orient : "x", "y", "v", or "h"
  399. The orientation of the mark, which also affects how transforms are computed.
  400. Typically corresponds to the axis that defines groups for aggregation.
  401. The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y",
  402. but may be more intuitive with some marks. When not provided, an
  403. orientation will be inferred from characteristics of the data and scales.
  404. legend : bool
  405. Option to suppress the mark/mappings for this layer from the legend.
  406. label : str
  407. A label to use for the layer in the legend, independent of any mappings.
  408. data : DataFrame or dict
  409. Data source to override the global source provided in the constructor.
  410. variables : data vectors or identifiers
  411. Additional layer-specific variables, including variables that will be
  412. passed directly to the transforms without scaling.
  413. Examples
  414. --------
  415. .. include:: ../docstrings/objects.Plot.add.rst
  416. """
  417. if not isinstance(mark, Mark):
  418. msg = f"mark must be a Mark instance, not {type(mark)!r}."
  419. raise TypeError(msg)
  420. # TODO This API for transforms was a late decision, and previously Plot.add
  421. # accepted 0 or 1 Stat instances and 0, 1, or a list of Move instances.
  422. # It will take some work to refactor the internals so that Stat and Move are
  423. # treated identically, and until then well need to "unpack" the transforms
  424. # here and enforce limitations on the order / types.
  425. stat: Optional[Stat]
  426. move: Optional[List[Move]]
  427. error = False
  428. if not transforms:
  429. stat, move = None, None
  430. elif isinstance(transforms[0], Stat):
  431. stat = transforms[0]
  432. move = [m for m in transforms[1:] if isinstance(m, Move)]
  433. error = len(move) != len(transforms) - 1
  434. else:
  435. stat = None
  436. move = [m for m in transforms if isinstance(m, Move)]
  437. error = len(move) != len(transforms)
  438. if error:
  439. msg = " ".join([
  440. "Transforms must have at most one Stat type (in the first position),",
  441. "and all others must be a Move type. Given transform type(s):",
  442. ", ".join(str(type(t).__name__) for t in transforms) + "."
  443. ])
  444. raise TypeError(msg)
  445. new = self._clone()
  446. new._layers.append({
  447. "mark": mark,
  448. "stat": stat,
  449. "move": move,
  450. # TODO it doesn't work to supply scalars to variables, but it should
  451. "vars": variables,
  452. "source": data,
  453. "legend": legend,
  454. "label": label,
  455. "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore
  456. })
  457. return new
  458. def pair(
  459. self,
  460. x: VariableSpecList = None,
  461. y: VariableSpecList = None,
  462. wrap: int | None = None,
  463. cross: bool = True,
  464. ) -> Plot:
  465. """
  466. Produce subplots by pairing multiple `x` and/or `y` variables.
  467. Parameters
  468. ----------
  469. x, y : sequence(s) of data vectors or identifiers
  470. Variables that will define the grid of subplots.
  471. wrap : int
  472. When using only `x` or `y`, "wrap" subplots across a two-dimensional grid
  473. with this many columns (when using `x`) or rows (when using `y`).
  474. cross : bool
  475. When False, zip the `x` and `y` lists such that the first subplot gets the
  476. first pair, the second gets the second pair, etc. Otherwise, create a
  477. two-dimensional grid from the cartesian product of the lists.
  478. Examples
  479. --------
  480. .. include:: ../docstrings/objects.Plot.pair.rst
  481. """
  482. # TODO Add transpose= arg, which would then draw pair(y=[...]) across rows
  483. # This may also be possible by setting `wrap=1`, but is that too unobvious?
  484. # TODO PairGrid features not currently implemented: diagonals, corner
  485. pair_spec: PairSpec = {}
  486. axes = {"x": [] if x is None else x, "y": [] if y is None else y}
  487. for axis, arg in axes.items():
  488. if isinstance(arg, (str, int)):
  489. err = f"You must pass a sequence of variable keys to `{axis}`"
  490. raise TypeError(err)
  491. pair_spec["variables"] = {}
  492. pair_spec["structure"] = {}
  493. for axis in "xy":
  494. keys = []
  495. for i, col in enumerate(axes[axis]):
  496. key = f"{axis}{i}"
  497. keys.append(key)
  498. pair_spec["variables"][key] = col
  499. if keys:
  500. pair_spec["structure"][axis] = keys
  501. if not cross and len(axes["x"]) != len(axes["y"]):
  502. err = "Lengths of the `x` and `y` lists must match with cross=False"
  503. raise ValueError(err)
  504. pair_spec["cross"] = cross
  505. pair_spec["wrap"] = wrap
  506. new = self._clone()
  507. new._pair_spec.update(pair_spec)
  508. return new
  509. def facet(
  510. self,
  511. col: VariableSpec = None,
  512. row: VariableSpec = None,
  513. order: OrderSpec | dict[str, OrderSpec] = None,
  514. wrap: int | None = None,
  515. ) -> Plot:
  516. """
  517. Produce subplots with conditional subsets of the data.
  518. Parameters
  519. ----------
  520. col, row : data vectors or identifiers
  521. Variables used to define subsets along the columns and/or rows of the grid.
  522. Can be references to the global data source passed in the constructor.
  523. order : list of strings, or dict with dimensional keys
  524. Define the order of the faceting variables.
  525. wrap : int
  526. When using only `col` or `row`, wrap subplots across a two-dimensional
  527. grid with this many subplots on the faceting dimension.
  528. Examples
  529. --------
  530. .. include:: ../docstrings/objects.Plot.facet.rst
  531. """
  532. variables: dict[str, VariableSpec] = {}
  533. if col is not None:
  534. variables["col"] = col
  535. if row is not None:
  536. variables["row"] = row
  537. structure = {}
  538. if isinstance(order, dict):
  539. for dim in ["col", "row"]:
  540. dim_order = order.get(dim)
  541. if dim_order is not None:
  542. structure[dim] = list(dim_order)
  543. elif order is not None:
  544. if col is not None and row is not None:
  545. err = " ".join([
  546. "When faceting on both col= and row=, passing `order` as a list"
  547. "is ambiguous. Use a dict with 'col' and/or 'row' keys instead."
  548. ])
  549. raise RuntimeError(err)
  550. elif col is not None:
  551. structure["col"] = list(order)
  552. elif row is not None:
  553. structure["row"] = list(order)
  554. spec: FacetSpec = {
  555. "variables": variables,
  556. "structure": structure,
  557. "wrap": wrap,
  558. }
  559. new = self._clone()
  560. new._facet_spec.update(spec)
  561. return new
  562. # TODO def twin()?
  563. def scale(self, **scales: Scale) -> Plot:
  564. """
  565. Specify mappings from data units to visual properties.
  566. Keywords correspond to variables defined in the plot, including coordinate
  567. variables (`x`, `y`) and semantic variables (`color`, `pointsize`, etc.).
  568. A number of "magic" arguments are accepted, including:
  569. - The name of a transform (e.g., `"log"`, `"sqrt"`)
  570. - The name of a palette (e.g., `"viridis"`, `"muted"`)
  571. - A tuple of values, defining the output range (e.g. `(1, 5)`)
  572. - A dict, implying a :class:`Nominal` scale (e.g. `{"a": .2, "b": .5}`)
  573. - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`)
  574. For more explicit control, pass a scale spec object such as :class:`Continuous`
  575. or :class:`Nominal`. Or pass `None` to use an "identity" scale, which treats
  576. data values as literally encoding visual properties.
  577. Examples
  578. --------
  579. .. include:: ../docstrings/objects.Plot.scale.rst
  580. """
  581. new = self._clone()
  582. new._scales.update(scales)
  583. return new
  584. def share(self, **shares: bool | str) -> Plot:
  585. """
  586. Control sharing of axis limits and ticks across subplots.
  587. Keywords correspond to variables defined in the plot, and values can be
  588. boolean (to share across all subplots), or one of "row" or "col" (to share
  589. more selectively across one dimension of a grid).
  590. Behavior for non-coordinate variables is currently undefined.
  591. Examples
  592. --------
  593. .. include:: ../docstrings/objects.Plot.share.rst
  594. """
  595. new = self._clone()
  596. new._shares.update(shares)
  597. return new
  598. def limit(self, **limits: tuple[Any, Any]) -> Plot:
  599. """
  600. Control the range of visible data.
  601. Keywords correspond to variables defined in the plot, and values are a
  602. `(min, max)` tuple (where either can be `None` to leave unset).
  603. Limits apply only to the axis; data outside the visible range are
  604. still used for any stat transforms and added to the plot.
  605. Behavior for non-coordinate variables is currently undefined.
  606. Examples
  607. --------
  608. .. include:: ../docstrings/objects.Plot.limit.rst
  609. """
  610. new = self._clone()
  611. new._limits.update(limits)
  612. return new
  613. def label(
  614. self, *,
  615. title: str | None = None,
  616. legend: str | None = None,
  617. **variables: str | Callable[[str], str]
  618. ) -> Plot:
  619. """
  620. Control the labels and titles for axes, legends, and subplots.
  621. Additional keywords correspond to variables defined in the plot.
  622. Values can be one of the following types:
  623. - string (used literally; pass "" to clear the default label)
  624. - function (called on the default label)
  625. For coordinate variables, the value sets the axis label.
  626. For semantic variables, the value sets the legend title.
  627. For faceting variables, `title=` modifies the subplot-specific label,
  628. while `col=` and/or `row=` add a label for the faceting variable.
  629. When using a single subplot, `title=` sets its title.
  630. The `legend=` parameter sets the title for the "layer" legend
  631. (i.e., when using `label` in :meth:`Plot.add`).
  632. Examples
  633. --------
  634. .. include:: ../docstrings/objects.Plot.label.rst
  635. """
  636. new = self._clone()
  637. if title is not None:
  638. new._labels["title"] = title
  639. if legend is not None:
  640. new._labels["legend"] = legend
  641. new._labels.update(variables)
  642. return new
  643. def layout(
  644. self,
  645. *,
  646. size: tuple[float, float] | Default = default,
  647. engine: str | None | Default = default,
  648. ) -> Plot:
  649. """
  650. Control the figure size and layout.
  651. .. note::
  652. Default figure sizes and the API for specifying the figure size are subject
  653. to change in future "experimental" releases of the objects API. The default
  654. layout engine may also change.
  655. Parameters
  656. ----------
  657. size : (width, height)
  658. Size of the resulting figure, in inches. Size is inclusive of legend when
  659. using pyplot, but not otherwise.
  660. engine : {{"tight", "constrained", None}}
  661. Name of method for automatically adjusting the layout to remove overlap.
  662. The default depends on whether :meth:`Plot.on` is used.
  663. Examples
  664. --------
  665. .. include:: ../docstrings/objects.Plot.layout.rst
  666. """
  667. # TODO add an "auto" mode for figsize that roughly scales with the rcParams
  668. # figsize (so that works), but expands to prevent subplots from being squished
  669. # Also should we have height=, aspect=, exclusive with figsize? Or working
  670. # with figsize when only one is defined?
  671. new = self._clone()
  672. if size is not default:
  673. new._figure_spec["figsize"] = size
  674. if engine is not default:
  675. new._layout_spec["engine"] = engine
  676. return new
  677. # TODO def legend (ugh)
  678. def theme(self, *args: dict[str, Any]) -> Plot:
  679. """
  680. Control the appearance of elements in the plot.
  681. .. note::
  682. The API for customizing plot appearance is not yet finalized.
  683. Currently, the only valid argument is a dict of matplotlib rc parameters.
  684. (This dict must be passed as a positional argument.)
  685. It is likely that this method will be enhanced in future releases.
  686. Matplotlib rc parameters are documented on the following page:
  687. https://matplotlib.org/stable/tutorials/introductory/customizing.html
  688. Examples
  689. --------
  690. .. include:: ../docstrings/objects.Plot.theme.rst
  691. """
  692. new = self._clone()
  693. # We can skip this whole block on Python 3.8+ with positional-only syntax
  694. nargs = len(args)
  695. if nargs != 1:
  696. err = f"theme() takes 1 positional argument, but {nargs} were given"
  697. raise TypeError(err)
  698. rc = mpl.RcParams(args[0])
  699. new._theme.update(rc)
  700. return new
  701. def save(self, loc, **kwargs) -> Plot:
  702. """
  703. Compile the plot and write it to a buffer or file on disk.
  704. Parameters
  705. ----------
  706. loc : str, path, or buffer
  707. Location on disk to save the figure, or a buffer to write into.
  708. kwargs
  709. Other keyword arguments are passed through to
  710. :meth:`matplotlib.figure.Figure.savefig`.
  711. """
  712. # TODO expose important keyword arguments in our signature?
  713. with theme_context(self._theme_with_defaults()):
  714. self._plot().save(loc, **kwargs)
  715. return self
  716. def show(self, **kwargs) -> None:
  717. """
  718. Compile the plot and display it by hooking into pyplot.
  719. Calling this method is not necessary to render a plot in notebook context,
  720. but it may be in other environments (e.g., in a terminal). After compiling the
  721. plot, it calls :func:`matplotlib.pyplot.show` (passing any keyword parameters).
  722. Unlike other :class:`Plot` methods, there is no return value. This should be
  723. the last method you call when specifying a plot.
  724. """
  725. # TODO make pyplot configurable at the class level, and when not using,
  726. # import IPython.display and call on self to populate cell output?
  727. # Keep an eye on whether matplotlib implements "attaching" an existing
  728. # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
  729. self.plot(pyplot=True).show(**kwargs)
  730. def plot(self, pyplot: bool = False) -> Plotter:
  731. """
  732. Compile the plot spec and return the Plotter object.
  733. """
  734. with theme_context(self._theme_with_defaults()):
  735. return self._plot(pyplot)
  736. def _plot(self, pyplot: bool = False) -> Plotter:
  737. # TODO if we have _target object, pyplot should be determined by whether it
  738. # is hooked into the pyplot state machine (how do we check?)
  739. plotter = Plotter(pyplot=pyplot, theme=self._theme_with_defaults())
  740. # Process the variable assignments and initialize the figure
  741. common, layers = plotter._extract_data(self)
  742. plotter._setup_figure(self, common, layers)
  743. # Process the scale spec for coordinate variables and transform their data
  744. coord_vars = [v for v in self._variables if re.match(r"^x|y", v)]
  745. plotter._setup_scales(self, common, layers, coord_vars)
  746. # Apply statistical transform(s)
  747. plotter._compute_stats(self, layers)
  748. # Process scale spec for semantic variables and coordinates computed by stat
  749. plotter._setup_scales(self, common, layers)
  750. # TODO Remove these after updating other methods
  751. # ---- Maybe have debug= param that attaches these when True?
  752. plotter._data = common
  753. plotter._layers = layers
  754. # Process the data for each layer and add matplotlib artists
  755. for layer in layers:
  756. plotter._plot_layer(self, layer)
  757. # Add various figure decorations
  758. plotter._make_legend(self)
  759. plotter._finalize_figure(self)
  760. return plotter
  761. # ---- The plot compilation engine ---------------------------------------------- #
  762. class Plotter:
  763. """
  764. Engine for compiling a :class:`Plot` spec into a Matplotlib figure.
  765. This class is not intended to be instantiated directly by users.
  766. """
  767. # TODO decide if we ever want these (Plot.plot(debug=True))?
  768. _data: PlotData
  769. _layers: list[Layer]
  770. _figure: Figure
  771. def __init__(self, pyplot: bool, theme: dict[str, Any]):
  772. self._pyplot = pyplot
  773. self._theme = theme
  774. self._legend_contents: list[tuple[
  775. tuple[str, str | int], list[Artist], list[str],
  776. ]] = []
  777. self._scales: dict[str, Scale] = {}
  778. def save(self, loc, **kwargs) -> Plotter: # TODO type args
  779. kwargs.setdefault("dpi", 96)
  780. try:
  781. loc = os.path.expanduser(loc)
  782. except TypeError:
  783. # loc may be a buffer in which case that would not work
  784. pass
  785. self._figure.savefig(loc, **kwargs)
  786. return self
  787. def show(self, **kwargs) -> None:
  788. """
  789. Display the plot by hooking into pyplot.
  790. This method calls :func:`matplotlib.pyplot.show` with any keyword parameters.
  791. """
  792. # TODO if we did not create the Plotter with pyplot, is it possible to do this?
  793. # If not we should clearly raise.
  794. import matplotlib.pyplot as plt
  795. with theme_context(self._theme):
  796. plt.show(**kwargs)
  797. # TODO API for accessing the underlying matplotlib objects
  798. # TODO what else is useful in the public API for this class?
  799. def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None:
  800. # TODO use matplotlib backend directly instead of going through savefig?
  801. # TODO perhaps have self.show() flip a switch to disable this, so that
  802. # user does not end up with two versions of the figure in the output
  803. # TODO use bbox_inches="tight" like the inline backend?
  804. # pro: better results, con: (sometimes) confusing results
  805. # Better solution would be to default (with option to change)
  806. # to using constrained/tight layout.
  807. if Plot.config.display["format"] != "png":
  808. return None
  809. buffer = io.BytesIO()
  810. factor = 2 if Plot.config.display["hidpi"] else 1
  811. scaling = Plot.config.display["scaling"] / factor
  812. dpi = 96 * factor # TODO put dpi in Plot.config?
  813. with theme_context(self._theme): # TODO _theme_with_defaults?
  814. self._figure.savefig(buffer, dpi=dpi, format="png", bbox_inches="tight")
  815. data = buffer.getvalue()
  816. w, h = Image.open(buffer).size
  817. metadata = {"width": w * scaling, "height": h * scaling}
  818. return data, metadata
  819. def _repr_svg_(self) -> str | None:
  820. if Plot.config.display["format"] != "svg":
  821. return None
  822. # TODO DPI for rasterized artists?
  823. scaling = Plot.config.display["scaling"]
  824. buffer = io.StringIO()
  825. with theme_context(self._theme): # TODO _theme_with_defaults?
  826. self._figure.savefig(buffer, format="svg", bbox_inches="tight")
  827. root = ElementTree.fromstring(buffer.getvalue())
  828. w = scaling * float(root.attrib["width"][:-2])
  829. h = scaling * float(root.attrib["height"][:-2])
  830. root.attrib.update(width=f"{w}pt", height=f"{h}pt", viewbox=f"0 0 {w} {h}")
  831. ElementTree.ElementTree(root).write(out := io.BytesIO())
  832. return out.getvalue().decode()
  833. def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:
  834. common_data = (
  835. p._data
  836. .join(None, p._facet_spec.get("variables"))
  837. .join(None, p._pair_spec.get("variables"))
  838. )
  839. layers: list[Layer] = []
  840. for layer in p._layers:
  841. spec = layer.copy()
  842. spec["data"] = common_data.join(layer.get("source"), layer.get("vars"))
  843. layers.append(spec)
  844. return common_data, layers
  845. def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str:
  846. if re.match(r"[xy]\d+", var):
  847. key = var if var in p._labels else var[0]
  848. else:
  849. key = var
  850. label: str
  851. if key in p._labels:
  852. manual_label = p._labels[key]
  853. if callable(manual_label) and auto_label is not None:
  854. label = manual_label(auto_label)
  855. else:
  856. label = cast(str, manual_label)
  857. elif auto_label is None:
  858. label = ""
  859. else:
  860. label = auto_label
  861. return label
  862. def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
  863. # --- Parsing the faceting/pairing parameterization to specify figure grid
  864. subplot_spec = p._subplot_spec.copy()
  865. facet_spec = p._facet_spec.copy()
  866. pair_spec = p._pair_spec.copy()
  867. for axis in "xy":
  868. if axis in p._shares:
  869. subplot_spec[f"share{axis}"] = p._shares[axis]
  870. for dim in ["col", "row"]:
  871. if dim in common.frame and dim not in facet_spec["structure"]:
  872. order = categorical_order(common.frame[dim])
  873. facet_spec["structure"][dim] = order
  874. self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec)
  875. # --- Figure initialization
  876. self._figure = subplots.init_figure(
  877. pair_spec, self._pyplot, p._figure_spec, p._target,
  878. )
  879. # --- Figure annotation
  880. for sub in subplots:
  881. ax = sub["ax"]
  882. for axis in "xy":
  883. axis_key = sub[axis]
  884. # ~~ Axis labels
  885. # TODO Should we make it possible to use only one x/y label for
  886. # all rows/columns in a faceted plot? Maybe using sub{axis}label,
  887. # although the alignments of the labels from that method leaves
  888. # something to be desired (in terms of how it defines 'centered').
  889. names = [
  890. common.names.get(axis_key),
  891. *(layer["data"].names.get(axis_key) for layer in layers)
  892. ]
  893. auto_label = next((name for name in names if name is not None), None)
  894. label = self._resolve_label(p, axis_key, auto_label)
  895. ax.set(**{f"{axis}label": label})
  896. # ~~ Decoration visibility
  897. # TODO there should be some override (in Plot.layout?) so that
  898. # axis / tick labels can be shown on interior shared axes if desired
  899. axis_obj = getattr(ax, f"{axis}axis")
  900. visible_side = {"x": "bottom", "y": "left"}.get(axis)
  901. show_axis_label = (
  902. sub[visible_side]
  903. or not p._pair_spec.get("cross", True)
  904. or (
  905. axis in p._pair_spec.get("structure", {})
  906. and bool(p._pair_spec.get("wrap"))
  907. )
  908. )
  909. axis_obj.get_label().set_visible(show_axis_label)
  910. show_tick_labels = (
  911. show_axis_label
  912. or subplot_spec.get(f"share{axis}") not in (
  913. True, "all", {"x": "col", "y": "row"}[axis]
  914. )
  915. )
  916. for group in ("major", "minor"):
  917. for t in getattr(axis_obj, f"get_{group}ticklabels")():
  918. t.set_visible(show_tick_labels)
  919. # TODO we want right-side titles for row facets in most cases?
  920. # Let's have what we currently call "margin titles" but properly using the
  921. # ax.set_title interface (see my gist)
  922. title_parts = []
  923. for dim in ["col", "row"]:
  924. if sub[dim] is not None:
  925. val = self._resolve_label(p, "title", f"{sub[dim]}")
  926. if dim in p._labels:
  927. key = self._resolve_label(p, dim, common.names.get(dim))
  928. val = f"{key} {val}"
  929. title_parts.append(val)
  930. has_col = sub["col"] is not None
  931. has_row = sub["row"] is not None
  932. show_title = (
  933. has_col and has_row
  934. or (has_col or has_row) and p._facet_spec.get("wrap")
  935. or (has_col and sub["top"])
  936. # TODO or has_row and sub["right"] and <right titles>
  937. or has_row # TODO and not <right titles>
  938. )
  939. if title_parts:
  940. title = " | ".join(title_parts)
  941. title_text = ax.set_title(title)
  942. title_text.set_visible(show_title)
  943. elif not (has_col or has_row):
  944. title = self._resolve_label(p, "title", None)
  945. title_text = ax.set_title(title)
  946. def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
  947. grouping_vars = [v for v in PROPERTIES if v not in "xy"]
  948. grouping_vars += ["col", "row", "group"]
  949. pair_vars = spec._pair_spec.get("structure", {})
  950. for layer in layers:
  951. data = layer["data"]
  952. mark = layer["mark"]
  953. stat = layer["stat"]
  954. if stat is None:
  955. continue
  956. iter_axes = itertools.product(*[
  957. pair_vars.get(axis, [axis]) for axis in "xy"
  958. ])
  959. old = data.frame
  960. if pair_vars:
  961. data.frames = {}
  962. data.frame = data.frame.iloc[:0] # TODO to simplify typing
  963. for coord_vars in iter_axes:
  964. pairings = "xy", coord_vars
  965. df = old.copy()
  966. scales = self._scales.copy()
  967. for axis, var in zip(*pairings):
  968. if axis != var:
  969. df = df.rename(columns={var: axis})
  970. drop_cols = [x for x in df if re.match(rf"{axis}\d+", str(x))]
  971. df = df.drop(drop_cols, axis=1)
  972. scales[axis] = scales[var]
  973. orient = layer["orient"] or mark._infer_orient(scales)
  974. if stat.group_by_orient:
  975. grouper = [orient, *grouping_vars]
  976. else:
  977. grouper = grouping_vars
  978. groupby = GroupBy(grouper)
  979. res = stat(df, groupby, orient, scales)
  980. if pair_vars:
  981. data.frames[coord_vars] = res
  982. else:
  983. data.frame = res
  984. def _get_scale(
  985. self, p: Plot, var: str, prop: Property, values: Series
  986. ) -> Scale:
  987. if re.match(r"[xy]\d+", var):
  988. key = var if var in p._scales else var[0]
  989. else:
  990. key = var
  991. if key in p._scales:
  992. arg = p._scales[key]
  993. if arg is None or isinstance(arg, Scale):
  994. scale = arg
  995. else:
  996. scale = prop.infer_scale(arg, values)
  997. else:
  998. scale = prop.default_scale(values)
  999. return scale
  1000. def _get_subplot_data(self, df, var, view, share_state):
  1001. if share_state in [True, "all"]:
  1002. # The all-shared case is easiest, every subplot sees all the data
  1003. seed_values = df[var]
  1004. else:
  1005. # Otherwise, we need to setup separate scales for different subplots
  1006. if share_state in [False, "none"]:
  1007. # Fully independent axes are also easy: use each subplot's data
  1008. idx = self._get_subplot_index(df, view)
  1009. elif share_state in df:
  1010. # Sharing within row/col is more complicated
  1011. use_rows = df[share_state] == view[share_state]
  1012. idx = df.index[use_rows]
  1013. else:
  1014. # This configuration doesn't make much sense, but it's fine
  1015. idx = df.index
  1016. seed_values = df.loc[idx, var]
  1017. return seed_values
  1018. def _setup_scales(
  1019. self,
  1020. p: Plot,
  1021. common: PlotData,
  1022. layers: list[Layer],
  1023. variables: list[str] | None = None,
  1024. ) -> None:
  1025. if variables is None:
  1026. # Add variables that have data but not a scale, which happens
  1027. # because this method can be called multiple time, to handle
  1028. # variables added during the Stat transform.
  1029. variables = []
  1030. for layer in layers:
  1031. variables.extend(layer["data"].frame.columns)
  1032. for df in layer["data"].frames.values():
  1033. variables.extend(str(v) for v in df if v not in variables)
  1034. variables = [v for v in variables if v not in self._scales]
  1035. for var in variables:
  1036. # Determine whether this is a coordinate variable
  1037. # (i.e., x/y, paired x/y, or derivative such as xmax)
  1038. m = re.match(r"^(?P<coord>(?P<axis>x|y)\d*).*", var)
  1039. if m is None:
  1040. coord = axis = None
  1041. else:
  1042. coord = m["coord"]
  1043. axis = m["axis"]
  1044. # Get keys that handle things like x0, xmax, properly where relevant
  1045. prop_key = var if axis is None else axis
  1046. scale_key = var if coord is None else coord
  1047. if prop_key not in PROPERTIES:
  1048. continue
  1049. # Concatenate layers, using only the relevant coordinate and faceting vars,
  1050. # This is unnecessarily wasteful, as layer data will often be redundant.
  1051. # But figuring out the minimal amount we need is more complicated.
  1052. cols = [var, "col", "row"]
  1053. parts = [common.frame.filter(cols)]
  1054. for layer in layers:
  1055. parts.append(layer["data"].frame.filter(cols))
  1056. for df in layer["data"].frames.values():
  1057. parts.append(df.filter(cols))
  1058. var_df = pd.concat(parts, ignore_index=True)
  1059. prop = PROPERTIES[prop_key]
  1060. scale = self._get_scale(p, scale_key, prop, var_df[var])
  1061. if scale_key not in p._variables:
  1062. # TODO this implies that the variable was added by the stat
  1063. # It allows downstream orientation inference to work properly.
  1064. # But it feels rather hacky, so ideally revisit.
  1065. scale._priority = 0 # type: ignore
  1066. if axis is None:
  1067. # We could think about having a broader concept of (un)shared properties
  1068. # In general, not something you want to do (different scales in facets)
  1069. # But could make sense e.g. with paired plots. Build later.
  1070. share_state = None
  1071. subplots = []
  1072. else:
  1073. share_state = self._subplots.subplot_spec[f"share{axis}"]
  1074. subplots = [view for view in self._subplots if view[axis] == coord]
  1075. # Shared categorical axes are broken on matplotlib<3.4.0.
  1076. # https://github.com/matplotlib/matplotlib/pull/18308
  1077. # This only affects us when sharing *paired* axes. This is a novel/niche
  1078. # behavior, so we will raise rather than hack together a workaround.
  1079. if axis is not None and _version_predates(mpl, "3.4"):
  1080. paired_axis = axis in p._pair_spec.get("structure", {})
  1081. cat_scale = isinstance(scale, Nominal)
  1082. ok_dim = {"x": "col", "y": "row"}[axis]
  1083. shared_axes = share_state not in [False, "none", ok_dim]
  1084. if paired_axis and cat_scale and shared_axes:
  1085. err = "Sharing paired categorical axes requires matplotlib>=3.4.0"
  1086. raise RuntimeError(err)
  1087. if scale is None:
  1088. self._scales[var] = Scale._identity()
  1089. else:
  1090. try:
  1091. self._scales[var] = scale._setup(var_df[var], prop)
  1092. except Exception as err:
  1093. raise PlotSpecError._during("Scale setup", var) from err
  1094. if axis is None or (var != coord and coord in p._variables):
  1095. # Everything below here applies only to coordinate variables
  1096. continue
  1097. # Set up an empty series to receive the transformed values.
  1098. # We need this to handle piecemeal transforms of categories -> floats.
  1099. transformed_data = []
  1100. for layer in layers:
  1101. index = layer["data"].frame.index
  1102. empty_series = pd.Series(dtype=float, index=index, name=var)
  1103. transformed_data.append(empty_series)
  1104. for view in subplots:
  1105. axis_obj = getattr(view["ax"], f"{axis}axis")
  1106. seed_values = self._get_subplot_data(var_df, var, view, share_state)
  1107. view_scale = scale._setup(seed_values, prop, axis=axis_obj)
  1108. set_scale_obj(view["ax"], axis, view_scale._matplotlib_scale)
  1109. for layer, new_series in zip(layers, transformed_data):
  1110. layer_df = layer["data"].frame
  1111. if var not in layer_df:
  1112. continue
  1113. idx = self._get_subplot_index(layer_df, view)
  1114. try:
  1115. new_series.loc[idx] = view_scale(layer_df.loc[idx, var])
  1116. except Exception as err:
  1117. spec_error = PlotSpecError._during("Scaling operation", var)
  1118. raise spec_error from err
  1119. # Now the transformed data series are complete, update the layer data
  1120. for layer, new_series in zip(layers, transformed_data):
  1121. layer_df = layer["data"].frame
  1122. if var in layer_df:
  1123. layer_df[var] = pd.to_numeric(new_series)
  1124. def _plot_layer(self, p: Plot, layer: Layer) -> None:
  1125. data = layer["data"]
  1126. mark = layer["mark"]
  1127. move = layer["move"]
  1128. default_grouping_vars = ["col", "row", "group"] # TODO where best to define?
  1129. grouping_properties = [v for v in PROPERTIES if v[0] not in "xy"]
  1130. pair_variables = p._pair_spec.get("structure", {})
  1131. for subplots, df, scales in self._generate_pairings(data, pair_variables):
  1132. orient = layer["orient"] or mark._infer_orient(scales)
  1133. def get_order(var):
  1134. # Ignore order for x/y: they have been scaled to numeric indices,
  1135. # so any original order is no longer valid. Default ordering rules
  1136. # sorted unique numbers will correctly reconstruct intended order
  1137. # TODO This is tricky, make sure we add some tests for this
  1138. if var not in "xy" and var in scales:
  1139. return getattr(scales[var], "order", None)
  1140. if orient in df:
  1141. width = pd.Series(index=df.index, dtype=float)
  1142. for view in subplots:
  1143. view_idx = self._get_subplot_data(
  1144. df, orient, view, p._shares.get(orient)
  1145. ).index
  1146. view_df = df.loc[view_idx]
  1147. if "width" in mark._mappable_props:
  1148. view_width = mark._resolve(view_df, "width", None)
  1149. elif "width" in df:
  1150. view_width = view_df["width"]
  1151. else:
  1152. view_width = 0.8 # TODO what default?
  1153. spacing = scales[orient]._spacing(view_df.loc[view_idx, orient])
  1154. width.loc[view_idx] = view_width * spacing
  1155. df["width"] = width
  1156. if "baseline" in mark._mappable_props:
  1157. # TODO what marks should have this?
  1158. # If we can set baseline with, e.g., Bar(), then the
  1159. # "other" (e.g. y for x oriented bars) parameterization
  1160. # is somewhat ambiguous.
  1161. baseline = mark._resolve(df, "baseline", None)
  1162. else:
  1163. # TODO unlike width, we might not want to add baseline to data
  1164. # if the mark doesn't use it. Practically, there is a concern about
  1165. # Mark abstraction like Area / Ribbon
  1166. baseline = 0 if "baseline" not in df else df["baseline"]
  1167. df["baseline"] = baseline
  1168. if move is not None:
  1169. moves = move if isinstance(move, list) else [move]
  1170. for move_step in moves:
  1171. move_by = getattr(move_step, "by", None)
  1172. if move_by is None:
  1173. move_by = grouping_properties
  1174. move_groupers = [*move_by, *default_grouping_vars]
  1175. if move_step.group_by_orient:
  1176. move_groupers.insert(0, orient)
  1177. order = {var: get_order(var) for var in move_groupers}
  1178. groupby = GroupBy(order)
  1179. df = move_step(df, groupby, orient, scales)
  1180. df = self._unscale_coords(subplots, df, orient)
  1181. grouping_vars = mark._grouping_props + default_grouping_vars
  1182. split_generator = self._setup_split_generator(grouping_vars, df, subplots)
  1183. mark._plot(split_generator, scales, orient)
  1184. # TODO is this the right place for this?
  1185. for view in self._subplots:
  1186. view["ax"].autoscale_view()
  1187. if layer["legend"]:
  1188. self._update_legend_contents(p, mark, data, scales, layer["label"])
  1189. def _unscale_coords(
  1190. self, subplots: list[dict], df: DataFrame, orient: str,
  1191. ) -> DataFrame:
  1192. # TODO do we still have numbers in the variable name at this point?
  1193. coord_cols = [c for c in df if re.match(r"^[xy]\D*$", str(c))]
  1194. out_df = (
  1195. df
  1196. .drop(coord_cols, axis=1)
  1197. .reindex(df.columns, axis=1) # So unscaled columns retain their place
  1198. .copy(deep=False)
  1199. )
  1200. for view in subplots:
  1201. view_df = self._filter_subplot_data(df, view)
  1202. axes_df = view_df[coord_cols]
  1203. for var, values in axes_df.items():
  1204. axis = getattr(view["ax"], f"{str(var)[0]}axis")
  1205. # TODO see https://github.com/matplotlib/matplotlib/issues/22713
  1206. transform = axis.get_transform().inverted().transform
  1207. inverted = transform(values)
  1208. out_df.loc[values.index, str(var)] = inverted
  1209. return out_df
  1210. def _generate_pairings(
  1211. self, data: PlotData, pair_variables: dict,
  1212. ) -> Generator[
  1213. tuple[list[dict], DataFrame, dict[str, Scale]], None, None
  1214. ]:
  1215. # TODO retype return with subplot_spec or similar
  1216. iter_axes = itertools.product(*[
  1217. pair_variables.get(axis, [axis]) for axis in "xy"
  1218. ])
  1219. for x, y in iter_axes:
  1220. subplots = []
  1221. for view in self._subplots:
  1222. if (view["x"] == x) and (view["y"] == y):
  1223. subplots.append(view)
  1224. if data.frame.empty and data.frames:
  1225. out_df = data.frames[(x, y)].copy()
  1226. elif not pair_variables:
  1227. out_df = data.frame.copy()
  1228. else:
  1229. if data.frame.empty and data.frames:
  1230. out_df = data.frames[(x, y)].copy()
  1231. else:
  1232. out_df = data.frame.copy()
  1233. scales = self._scales.copy()
  1234. if x in out_df:
  1235. scales["x"] = self._scales[x]
  1236. if y in out_df:
  1237. scales["y"] = self._scales[y]
  1238. for axis, var in zip("xy", (x, y)):
  1239. if axis != var:
  1240. out_df = out_df.rename(columns={var: axis})
  1241. cols = [col for col in out_df if re.match(rf"{axis}\d+", str(col))]
  1242. out_df = out_df.drop(cols, axis=1)
  1243. yield subplots, out_df, scales
  1244. def _get_subplot_index(self, df: DataFrame, subplot: dict) -> Index:
  1245. dims = df.columns.intersection(["col", "row"])
  1246. if dims.empty:
  1247. return df.index
  1248. keep_rows = pd.Series(True, df.index, dtype=bool)
  1249. for dim in dims:
  1250. keep_rows &= df[dim] == subplot[dim]
  1251. return df.index[keep_rows]
  1252. def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame:
  1253. # TODO note redundancies with preceding function ... needs refactoring
  1254. dims = df.columns.intersection(["col", "row"])
  1255. if dims.empty:
  1256. return df
  1257. keep_rows = pd.Series(True, df.index, dtype=bool)
  1258. for dim in dims:
  1259. keep_rows &= df[dim] == subplot[dim]
  1260. return df[keep_rows]
  1261. def _setup_split_generator(
  1262. self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]],
  1263. ) -> Callable[[], Generator]:
  1264. grouping_keys = []
  1265. grouping_vars = [
  1266. v for v in grouping_vars if v in df and v not in ["col", "row"]
  1267. ]
  1268. for var in grouping_vars:
  1269. order = getattr(self._scales[var], "order", None)
  1270. if order is None:
  1271. order = categorical_order(df[var])
  1272. grouping_keys.append(order)
  1273. def split_generator(keep_na=False) -> Generator:
  1274. for view in subplots:
  1275. axes_df = self._filter_subplot_data(df, view)
  1276. axes_df_inf_as_nan = axes_df.copy()
  1277. axes_df_inf_as_nan = axes_df_inf_as_nan.mask(
  1278. axes_df_inf_as_nan.isin([np.inf, -np.inf]), np.nan
  1279. )
  1280. if keep_na:
  1281. # The simpler thing to do would be x.dropna().reindex(x.index).
  1282. # But that doesn't work with the way that the subset iteration
  1283. # is written below, which assumes data for grouping vars.
  1284. # Matplotlib (usually?) masks nan data, so this should "work".
  1285. # Downstream code can also drop these rows, at some speed cost.
  1286. present = axes_df_inf_as_nan.notna().all(axis=1)
  1287. nulled = {}
  1288. for axis in "xy":
  1289. if axis in axes_df:
  1290. nulled[axis] = axes_df[axis].where(present)
  1291. axes_df = axes_df_inf_as_nan.assign(**nulled)
  1292. else:
  1293. axes_df = axes_df_inf_as_nan.dropna()
  1294. subplot_keys = {}
  1295. for dim in ["col", "row"]:
  1296. if view[dim] is not None:
  1297. subplot_keys[dim] = view[dim]
  1298. if not grouping_vars or not any(grouping_keys):
  1299. if not axes_df.empty:
  1300. yield subplot_keys, axes_df.copy(), view["ax"]
  1301. continue
  1302. grouped_df = axes_df.groupby(
  1303. grouping_vars, sort=False, as_index=False, observed=False,
  1304. )
  1305. for key in itertools.product(*grouping_keys):
  1306. # Pandas fails with singleton tuple inputs
  1307. pd_key = key[0] if len(key) == 1 else key
  1308. try:
  1309. df_subset = grouped_df.get_group(pd_key)
  1310. except KeyError:
  1311. # TODO (from initial work on categorical plots refactor)
  1312. # We are adding this to allow backwards compatability
  1313. # with the empty artists that old categorical plots would
  1314. # add (before 0.12), which we may decide to break, in which
  1315. # case this option could be removed
  1316. df_subset = axes_df.loc[[]]
  1317. if df_subset.empty:
  1318. continue
  1319. sub_vars = dict(zip(grouping_vars, key))
  1320. sub_vars.update(subplot_keys)
  1321. # TODO need copy(deep=...) policy (here, above, anywhere else?)
  1322. yield sub_vars, df_subset.copy(), view["ax"]
  1323. return split_generator
  1324. def _update_legend_contents(
  1325. self,
  1326. p: Plot,
  1327. mark: Mark,
  1328. data: PlotData,
  1329. scales: dict[str, Scale],
  1330. layer_label: str | None,
  1331. ) -> None:
  1332. """Add legend artists / labels for one layer in the plot."""
  1333. if data.frame.empty and data.frames:
  1334. legend_vars: list[str] = []
  1335. for frame in data.frames.values():
  1336. frame_vars = frame.columns.intersection(list(scales))
  1337. legend_vars.extend(v for v in frame_vars if v not in legend_vars)
  1338. else:
  1339. legend_vars = list(data.frame.columns.intersection(list(scales)))
  1340. # First handle layer legends, which occupy a single entry in legend_contents.
  1341. if layer_label is not None:
  1342. legend_title = str(p._labels.get("legend", ""))
  1343. layer_key = (legend_title, -1)
  1344. artist = mark._legend_artist([], None, {})
  1345. if artist is not None:
  1346. for content in self._legend_contents:
  1347. if content[0] == layer_key:
  1348. content[1].append(artist)
  1349. content[2].append(layer_label)
  1350. break
  1351. else:
  1352. self._legend_contents.append((layer_key, [artist], [layer_label]))
  1353. # Then handle the scale legends
  1354. # First pass: Identify the values that will be shown for each variable
  1355. schema: list[tuple[
  1356. tuple[str, str | int], list[str], tuple[list[Any], list[str]]
  1357. ]] = []
  1358. schema = []
  1359. for var in legend_vars:
  1360. var_legend = scales[var]._legend
  1361. if var_legend is not None:
  1362. values, labels = var_legend
  1363. for (_, part_id), part_vars, _ in schema:
  1364. if data.ids[var] == part_id:
  1365. # Allow multiple plot semantics to represent same data variable
  1366. part_vars.append(var)
  1367. break
  1368. else:
  1369. title = self._resolve_label(p, var, data.names[var])
  1370. entry = (title, data.ids[var]), [var], (values, labels)
  1371. schema.append(entry)
  1372. # Second pass, generate an artist corresponding to each value
  1373. contents: list[tuple[tuple[str, str | int], Any, list[str]]] = []
  1374. for key, variables, (values, labels) in schema:
  1375. artists = []
  1376. for val in values:
  1377. artist = mark._legend_artist(variables, val, scales)
  1378. if artist is not None:
  1379. artists.append(artist)
  1380. if artists:
  1381. contents.append((key, artists, labels))
  1382. self._legend_contents.extend(contents)
  1383. def _make_legend(self, p: Plot) -> None:
  1384. """Create the legend artist(s) and add onto the figure."""
  1385. # Combine artists representing same information across layers
  1386. # Input list has an entry for each distinct variable in each layer
  1387. # Output dict has an entry for each distinct variable
  1388. merged_contents: dict[
  1389. tuple[str, str | int], tuple[list[tuple[Artist, ...]], list[str]],
  1390. ] = {}
  1391. for key, new_artists, labels in self._legend_contents:
  1392. # Key is (name, id); we need the id to resolve variable uniqueness,
  1393. # but will need the name in the next step to title the legend
  1394. if key not in merged_contents:
  1395. # Matplotlib accepts a tuple of artists and will overlay them
  1396. new_artist_tuples = [tuple([a]) for a in new_artists]
  1397. merged_contents[key] = new_artist_tuples, labels
  1398. else:
  1399. existing_artists = merged_contents[key][0]
  1400. for i, new_artist in enumerate(new_artists):
  1401. existing_artists[i] += tuple([new_artist])
  1402. # When using pyplot, an "external" legend won't be shown, so this
  1403. # keeps it inside the axes (though still attached to the figure)
  1404. # This is necessary because matplotlib layout engines currently don't
  1405. # support figure legends — ideally this will change.
  1406. loc = "center right" if self._pyplot else "center left"
  1407. base_legend = None
  1408. for (name, _), (handles, labels) in merged_contents.items():
  1409. legend = mpl.legend.Legend(
  1410. self._figure,
  1411. handles, # type: ignore # matplotlib/issues/26639
  1412. labels,
  1413. title=name,
  1414. loc=loc,
  1415. bbox_to_anchor=(.98, .55),
  1416. )
  1417. if base_legend:
  1418. # Matplotlib has no public API for this so it is a bit of a hack.
  1419. # Ideally we'd define our own legend class with more flexibility,
  1420. # but that is a lot of work!
  1421. base_legend_box = base_legend.get_children()[0]
  1422. this_legend_box = legend.get_children()[0]
  1423. base_legend_box.get_children().extend(this_legend_box.get_children())
  1424. else:
  1425. base_legend = legend
  1426. self._figure.legends.append(legend)
  1427. def _finalize_figure(self, p: Plot) -> None:
  1428. for sub in self._subplots:
  1429. ax = sub["ax"]
  1430. for axis in "xy":
  1431. axis_key = sub[axis]
  1432. axis_obj = getattr(ax, f"{axis}axis")
  1433. # Axis limits
  1434. if axis_key in p._limits or axis in p._limits:
  1435. convert_units = getattr(ax, f"{axis}axis").convert_units
  1436. a, b = p._limits.get(axis_key) or p._limits[axis]
  1437. lo = a if a is None else convert_units(a)
  1438. hi = b if b is None else convert_units(b)
  1439. if isinstance(a, str):
  1440. lo = cast(float, lo) - 0.5
  1441. if isinstance(b, str):
  1442. hi = cast(float, hi) + 0.5
  1443. ax.set(**{f"{axis}lim": (lo, hi)})
  1444. if axis_key in self._scales: # TODO when would it not be?
  1445. self._scales[axis_key]._finalize(p, axis_obj)
  1446. if (engine := p._layout_spec.get("engine", default)) is not default:
  1447. # None is a valid arg for Figure.set_layout_engine, hence `default`
  1448. set_layout_engine(self._figure, engine)
  1449. elif p._target is None:
  1450. # Don't modify the layout engine if the user supplied their own
  1451. # matplotlib figure and didn't specify an engine through Plot
  1452. # TODO switch default to "constrained"?
  1453. # TODO either way, make configurable
  1454. set_layout_engine(self._figure, "tight")