subplots.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. from __future__ import annotations
  2. from collections.abc import Generator
  3. import numpy as np
  4. import matplotlib as mpl
  5. import matplotlib.pyplot as plt
  6. from matplotlib.axes import Axes
  7. from matplotlib.figure import Figure
  8. from typing import TYPE_CHECKING
  9. if TYPE_CHECKING: # TODO move to seaborn._core.typing?
  10. from seaborn._core.plot import FacetSpec, PairSpec
  11. from matplotlib.figure import SubFigure
  12. class Subplots:
  13. """
  14. Interface for creating and using matplotlib subplots based on seaborn parameters.
  15. Parameters
  16. ----------
  17. subplot_spec : dict
  18. Keyword args for :meth:`matplotlib.figure.Figure.subplots`.
  19. facet_spec : dict
  20. Parameters that control subplot faceting.
  21. pair_spec : dict
  22. Parameters that control subplot pairing.
  23. data : PlotData
  24. Data used to define figure setup.
  25. """
  26. def __init__(
  27. self,
  28. subplot_spec: dict, # TODO define as TypedDict
  29. facet_spec: FacetSpec,
  30. pair_spec: PairSpec,
  31. ):
  32. self.subplot_spec = subplot_spec
  33. self._check_dimension_uniqueness(facet_spec, pair_spec)
  34. self._determine_grid_dimensions(facet_spec, pair_spec)
  35. self._handle_wrapping(facet_spec, pair_spec)
  36. self._determine_axis_sharing(pair_spec)
  37. def _check_dimension_uniqueness(
  38. self, facet_spec: FacetSpec, pair_spec: PairSpec
  39. ) -> None:
  40. """Reject specs that pair and facet on (or wrap to) same figure dimension."""
  41. err = None
  42. facet_vars = facet_spec.get("variables", {})
  43. if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars):
  44. err = "Cannot wrap facets when specifying both `col` and `row`."
  45. elif (
  46. pair_spec.get("wrap")
  47. and pair_spec.get("cross", True)
  48. and len(pair_spec.get("structure", {}).get("x", [])) > 1
  49. and len(pair_spec.get("structure", {}).get("y", [])) > 1
  50. ):
  51. err = "Cannot wrap subplots when pairing on both `x` and `y`."
  52. collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]}
  53. for pair_axis, (multi_dim, wrap_dim) in collisions.items():
  54. if pair_axis not in pair_spec.get("structure", {}):
  55. continue
  56. elif multi_dim[:3] in facet_vars:
  57. err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}``."
  58. elif wrap_dim[:3] in facet_vars and facet_spec.get("wrap"):
  59. err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}``."
  60. elif wrap_dim[:3] in facet_vars and pair_spec.get("wrap"):
  61. err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}."
  62. if err is not None:
  63. raise RuntimeError(err) # TODO what err class? Define PlotSpecError?
  64. def _determine_grid_dimensions(
  65. self, facet_spec: FacetSpec, pair_spec: PairSpec
  66. ) -> None:
  67. """Parse faceting and pairing information to define figure structure."""
  68. self.grid_dimensions: dict[str, list] = {}
  69. for dim, axis in zip(["col", "row"], ["x", "y"]):
  70. facet_vars = facet_spec.get("variables", {})
  71. if dim in facet_vars:
  72. self.grid_dimensions[dim] = facet_spec["structure"][dim]
  73. elif axis in pair_spec.get("structure", {}):
  74. self.grid_dimensions[dim] = [
  75. None for _ in pair_spec.get("structure", {})[axis]
  76. ]
  77. else:
  78. self.grid_dimensions[dim] = [None]
  79. self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim])
  80. if not pair_spec.get("cross", True):
  81. self.subplot_spec["nrows"] = 1
  82. self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"]
  83. def _handle_wrapping(
  84. self, facet_spec: FacetSpec, pair_spec: PairSpec
  85. ) -> None:
  86. """Update figure structure parameters based on facet/pair wrapping."""
  87. self.wrap = wrap = facet_spec.get("wrap") or pair_spec.get("wrap")
  88. if not wrap:
  89. return
  90. wrap_dim = "row" if self.subplot_spec["nrows"] > 1 else "col"
  91. flow_dim = {"row": "col", "col": "row"}[wrap_dim]
  92. n_subplots = self.subplot_spec[f"n{wrap_dim}s"]
  93. flow = int(np.ceil(n_subplots / wrap))
  94. if wrap < self.subplot_spec[f"n{wrap_dim}s"]:
  95. self.subplot_spec[f"n{wrap_dim}s"] = wrap
  96. self.subplot_spec[f"n{flow_dim}s"] = flow
  97. self.n_subplots = n_subplots
  98. self.wrap_dim = wrap_dim
  99. def _determine_axis_sharing(self, pair_spec: PairSpec) -> None:
  100. """Update subplot spec with default or specified axis sharing parameters."""
  101. axis_to_dim = {"x": "col", "y": "row"}
  102. key: str
  103. val: str | bool
  104. for axis in "xy":
  105. key = f"share{axis}"
  106. # Always use user-specified value, if present
  107. if key not in self.subplot_spec:
  108. if axis in pair_spec.get("structure", {}):
  109. # Paired axes are shared along one dimension by default
  110. if self.wrap is None and pair_spec.get("cross", True):
  111. val = axis_to_dim[axis]
  112. else:
  113. val = False
  114. else:
  115. # This will pick up faceted plots, as well as single subplot
  116. # figures, where the value doesn't really matter
  117. val = True
  118. self.subplot_spec[key] = val
  119. def init_figure(
  120. self,
  121. pair_spec: PairSpec,
  122. pyplot: bool = False,
  123. figure_kws: dict | None = None,
  124. target: Axes | Figure | SubFigure = None,
  125. ) -> Figure:
  126. """Initialize matplotlib objects and add seaborn-relevant metadata."""
  127. # TODO reduce need to pass pair_spec here?
  128. if figure_kws is None:
  129. figure_kws = {}
  130. if isinstance(target, mpl.axes.Axes):
  131. if max(self.subplot_spec["nrows"], self.subplot_spec["ncols"]) > 1:
  132. err = " ".join([
  133. "Cannot create multiple subplots after calling `Plot.on` with",
  134. f"a {mpl.axes.Axes} object.",
  135. ])
  136. try:
  137. err += f" You may want to use a {mpl.figure.SubFigure} instead."
  138. except AttributeError: # SubFigure added in mpl 3.4
  139. pass
  140. raise RuntimeError(err)
  141. self._subplot_list = [{
  142. "ax": target,
  143. "left": True,
  144. "right": True,
  145. "top": True,
  146. "bottom": True,
  147. "col": None,
  148. "row": None,
  149. "x": "x",
  150. "y": "y",
  151. }]
  152. self._figure = target.figure
  153. return self._figure
  154. elif (
  155. hasattr(mpl.figure, "SubFigure") # Added in mpl 3.4
  156. and isinstance(target, mpl.figure.SubFigure)
  157. ):
  158. figure = target.figure
  159. elif isinstance(target, mpl.figure.Figure):
  160. figure = target
  161. else:
  162. if pyplot:
  163. figure = plt.figure(**figure_kws)
  164. else:
  165. figure = mpl.figure.Figure(**figure_kws)
  166. target = figure
  167. self._figure = figure
  168. axs = target.subplots(**self.subplot_spec, squeeze=False)
  169. if self.wrap:
  170. # Remove unused Axes and flatten the rest into a (2D) vector
  171. axs_flat = axs.ravel({"col": "C", "row": "F"}[self.wrap_dim])
  172. axs, extra = np.split(axs_flat, [self.n_subplots])
  173. for ax in extra:
  174. ax.remove()
  175. if self.wrap_dim == "col":
  176. axs = axs[np.newaxis, :]
  177. else:
  178. axs = axs[:, np.newaxis]
  179. # Get i, j coordinates for each Axes object
  180. # Note that i, j are with respect to faceting/pairing,
  181. # not the subplot grid itself, (which only matters in the case of wrapping).
  182. iter_axs: np.ndenumerate | zip
  183. if not pair_spec.get("cross", True):
  184. indices = np.arange(self.n_subplots)
  185. iter_axs = zip(zip(indices, indices), axs.flat)
  186. else:
  187. iter_axs = np.ndenumerate(axs)
  188. self._subplot_list = []
  189. for (i, j), ax in iter_axs:
  190. info = {"ax": ax}
  191. nrows, ncols = self.subplot_spec["nrows"], self.subplot_spec["ncols"]
  192. if not self.wrap:
  193. info["left"] = j % ncols == 0
  194. info["right"] = (j + 1) % ncols == 0
  195. info["top"] = i == 0
  196. info["bottom"] = i == nrows - 1
  197. elif self.wrap_dim == "col":
  198. info["left"] = j % ncols == 0
  199. info["right"] = ((j + 1) % ncols == 0) or ((j + 1) == self.n_subplots)
  200. info["top"] = j < ncols
  201. info["bottom"] = j >= (self.n_subplots - ncols)
  202. elif self.wrap_dim == "row":
  203. info["left"] = i < nrows
  204. info["right"] = i >= self.n_subplots - nrows
  205. info["top"] = i % nrows == 0
  206. info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots)
  207. if not pair_spec.get("cross", True):
  208. info["top"] = j < ncols
  209. info["bottom"] = j >= self.n_subplots - ncols
  210. for dim in ["row", "col"]:
  211. idx = {"row": i, "col": j}[dim]
  212. info[dim] = self.grid_dimensions[dim][idx]
  213. for axis in "xy":
  214. idx = {"x": j, "y": i}[axis]
  215. if axis in pair_spec.get("structure", {}):
  216. key = f"{axis}{idx}"
  217. else:
  218. key = axis
  219. info[axis] = key
  220. self._subplot_list.append(info)
  221. return figure
  222. def __iter__(self) -> Generator[dict, None, None]: # TODO TypedDict?
  223. """Yield each subplot dictionary with Axes object and metadata."""
  224. yield from self._subplot_list
  225. def __len__(self) -> int:
  226. """Return the number of subplots in this figure."""
  227. return len(self._subplot_list)