axisgrid.py 86 KB


  1. from __future__ import annotations
  2. from itertools import product
  3. from inspect import signature
  4. import warnings
  5. from textwrap import dedent
  6. import numpy as np
  7. import pandas as pd
  8. import matplotlib as mpl
  9. import matplotlib.pyplot as plt
  10. from ._base import VectorPlotter, variable_type, categorical_order
  11. from ._core.data import handle_data_source
  12. from ._compat import share_axis, get_legend_handles
  13. from . import utils
  14. from .utils import (
  15. adjust_legend_subtitles,
  16. set_hls_values,
  17. _check_argument,
  18. _draw_figure,
  19. _disable_autolayout
  20. )
  21. from .palettes import color_palette, blend_palette
  22. from ._docstrings import (
  23. DocstringComponents,
  24. _core_docs,
  25. )
  26. __all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]
  27. _param_docs = DocstringComponents.from_nested_components(
  28. core=_core_docs["params"],
  29. )
  30. class _BaseGrid:
  31. """Base class for grids of subplots."""
  32. def set(self, **kwargs):
  33. """Set attributes on each subplot Axes."""
  34. for ax in self.axes.flat:
  35. if ax is not None: # Handle removed axes
  36. ax.set(**kwargs)
  37. return self
  38. @property
  39. def fig(self):
  40. """DEPRECATED: prefer the `figure` property."""
  41. # Grid.figure is preferred because it matches the Axes attribute name.
  42. # But as the maintanace burden on having this property is minimal,
  43. # let's be slow about formally deprecating it. For now just note its deprecation
  44. # in the docstring; add a warning in version 0.13, and eventually remove it.
  45. return self._figure
  46. @property
  47. def figure(self):
  48. """Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
  49. return self._figure
  50. def apply(self, func, *args, **kwargs):
  51. """
  52. Pass the grid to a user-supplied function and return self.
  53. The `func` must accept an object of this type for its first
  54. positional argument. Additional arguments are passed through.
  55. The return value of `func` is ignored; this method returns self.
  56. See the `pipe` method if you want the return value.
  57. Added in v0.12.0.
  58. """
  59. func(self, *args, **kwargs)
  60. return self
  61. def pipe(self, func, *args, **kwargs):
  62. """
  63. Pass the grid to a user-supplied function and return its value.
  64. The `func` must accept an object of this type for its first
  65. positional argument. Additional arguments are passed through.
  66. The return value of `func` becomes the return value of this method.
  67. See the `apply` method if you want to return self instead.
  68. Added in v0.12.0.
  69. """
  70. return func(self, *args, **kwargs)
  71. def savefig(self, *args, **kwargs):
  72. """
  73. Save an image of the plot.
  74. This wraps :meth:`matplotlib.figure.Figure.savefig`, using bbox_inches="tight"
  75. by default. Parameters are passed through to the matplotlib function.
  76. """
  77. kwargs = kwargs.copy()
  78. kwargs.setdefault("bbox_inches", "tight")
  79. self.figure.savefig(*args, **kwargs)
  80. class Grid(_BaseGrid):
  81. """A grid that can have multiple subplots and an external legend."""
  82. _margin_titles = False
  83. _legend_out = True
  84. def __init__(self):
  85. self._tight_layout_rect = [0, 0, 1, 1]
  86. self._tight_layout_pad = None
  87. # This attribute is set externally and is a hack to handle newer functions that
  88. # don't add proxy artists onto the Axes. We need an overall cleaner approach.
  89. self._extract_legend_handles = False
  90. def tight_layout(self, *args, **kwargs):
  91. """Call fig.tight_layout within rect that exclude the legend."""
  92. kwargs = kwargs.copy()
  93. kwargs.setdefault("rect", self._tight_layout_rect)
  94. if self._tight_layout_pad is not None:
  95. kwargs.setdefault("pad", self._tight_layout_pad)
  96. self._figure.tight_layout(*args, **kwargs)
  97. return self
  98. def add_legend(self, legend_data=None, title=None, label_order=None,
  99. adjust_subtitles=False, **kwargs):
  100. """Draw a legend, maybe placing it outside axes and resizing the figure.
  101. Parameters
  102. ----------
  103. legend_data : dict
  104. Dictionary mapping label names (or two-element tuples where the
  105. second element is a label name) to matplotlib artist handles. The
  106. default reads from ``self._legend_data``.
  107. title : string
  108. Title for the legend. The default reads from ``self._hue_var``.
  109. label_order : list of labels
  110. The order that the legend entries should appear in. The default
  111. reads from ``self.hue_names``.
  112. adjust_subtitles : bool
  113. If True, modify entries with invisible artists to left-align
  114. the labels and set the font size to that of a title.
  115. kwargs : key, value pairings
  116. Other keyword arguments are passed to the underlying legend methods
  117. on the Figure or Axes object.
  118. Returns
  119. -------
  120. self : Grid instance
  121. Returns self for easy chaining.
  122. """
  123. # Find the data for the legend
  124. if legend_data is None:
  125. legend_data = self._legend_data
  126. if label_order is None:
  127. if self.hue_names is None:
  128. label_order = list(legend_data.keys())
  129. else:
  130. label_order = list(map(utils.to_utf8, self.hue_names))
  131. blank_handle = mpl.patches.Patch(alpha=0, linewidth=0)
  132. handles = [legend_data.get(lab, blank_handle) for lab in label_order]
  133. title = self._hue_var if title is None else title
  134. title_size = mpl.rcParams["legend.title_fontsize"]
  135. # Unpack nested labels from a hierarchical legend
  136. labels = []
  137. for entry in label_order:
  138. if isinstance(entry, tuple):
  139. _, label = entry
  140. else:
  141. label = entry
  142. labels.append(label)
  143. # Set default legend kwargs
  144. kwargs.setdefault("scatterpoints", 1)
  145. if self._legend_out:
  146. kwargs.setdefault("frameon", False)
  147. kwargs.setdefault("loc", "center right")
  148. # Draw a full-figure legend outside the grid
  149. figlegend = self._figure.legend(handles, labels, **kwargs)
  150. self._legend = figlegend
  151. figlegend.set_title(title, prop={"size": title_size})
  152. if adjust_subtitles:
  153. adjust_legend_subtitles(figlegend)
  154. # Draw the plot to set the bounding boxes correctly
  155. _draw_figure(self._figure)
  156. # Calculate and set the new width of the figure so the legend fits
  157. legend_width = figlegend.get_window_extent().width / self._figure.dpi
  158. fig_width, fig_height = self._figure.get_size_inches()
  159. self._figure.set_size_inches(fig_width + legend_width, fig_height)
  160. # Draw the plot again to get the new transformations
  161. _draw_figure(self._figure)
  162. # Now calculate how much space we need on the right side
  163. legend_width = figlegend.get_window_extent().width / self._figure.dpi
  164. space_needed = legend_width / (fig_width + legend_width)
  165. margin = .04 if self._margin_titles else .01
  166. self._space_needed = margin + space_needed
  167. right = 1 - self._space_needed
  168. # Place the subplot axes to give space for the legend
  169. self._figure.subplots_adjust(right=right)
  170. self._tight_layout_rect[2] = right
  171. else:
  172. # Draw a legend in the first axis
  173. ax = self.axes.flat[0]
  174. kwargs.setdefault("loc", "best")
  175. leg = ax.legend(handles, labels, **kwargs)
  176. leg.set_title(title, prop={"size": title_size})
  177. self._legend = leg
  178. if adjust_subtitles:
  179. adjust_legend_subtitles(leg)
  180. return self
  181. def _update_legend_data(self, ax):
  182. """Extract the legend data from an axes object and save it."""
  183. data = {}
  184. # Get data directly from the legend, which is necessary
  185. # for newer functions that don't add labeled proxy artists
  186. if ax.legend_ is not None and self._extract_legend_handles:
  187. handles = get_legend_handles(ax.legend_)
  188. labels = [t.get_text() for t in ax.legend_.texts]
  189. data.update({label: handle for handle, label in zip(handles, labels)})
  190. handles, labels = ax.get_legend_handles_labels()
  191. data.update({label: handle for handle, label in zip(handles, labels)})
  192. self._legend_data.update(data)
  193. # Now clear the legend
  194. ax.legend_ = None
  195. def _get_palette(self, data, hue, hue_order, palette):
  196. """Get a list of colors for the hue variable."""
  197. if hue is None:
  198. palette = color_palette(n_colors=1)
  199. else:
  200. hue_names = categorical_order(data[hue], hue_order)
  201. n_colors = len(hue_names)
  202. # By default use either the current color palette or HUSL
  203. if palette is None:
  204. current_palette = utils.get_color_cycle()
  205. if n_colors > len(current_palette):
  206. colors = color_palette("husl", n_colors)
  207. else:
  208. colors = color_palette(n_colors=n_colors)
  209. # Allow for palette to map from hue variable names
  210. elif isinstance(palette, dict):
  211. color_names = [palette[h] for h in hue_names]
  212. colors = color_palette(color_names, n_colors)
  213. # Otherwise act as if we just got a list of colors
  214. else:
  215. colors = color_palette(palette, n_colors)
  216. palette = color_palette(colors, n_colors)
  217. return palette
  218. @property
  219. def legend(self):
  220. """The :class:`matplotlib.legend.Legend` object, if present."""
  221. try:
  222. return self._legend
  223. except AttributeError:
  224. return None
  225. def tick_params(self, axis='both', **kwargs):
  226. """Modify the ticks, tick labels, and gridlines.
  227. Parameters
  228. ----------
  229. axis : {'x', 'y', 'both'}
  230. The axis on which to apply the formatting.
  231. kwargs : keyword arguments
  232. Additional keyword arguments to pass to
  233. :meth:`matplotlib.axes.Axes.tick_params`.
  234. Returns
  235. -------
  236. self : Grid instance
  237. Returns self for easy chaining.
  238. """
  239. for ax in self.figure.axes:
  240. ax.tick_params(axis=axis, **kwargs)
  241. return self
  242. _facet_docs = dict(
  243. data=dedent("""\
  244. data : DataFrame
  245. Tidy ("long-form") dataframe where each column is a variable and each
  246. row is an observation.\
  247. """),
  248. rowcol=dedent("""\
  249. row, col : vectors or keys in ``data``
  250. Variables that define subsets to plot on different facets.\
  251. """),
  252. rowcol_order=dedent("""\
  253. {row,col}_order : vector of strings
  254. Specify the order in which levels of the ``row`` and/or ``col`` variables
  255. appear in the grid of subplots.\
  256. """),
  257. col_wrap=dedent("""\
  258. col_wrap : int
  259. "Wrap" the column variable at this width, so that the column facets
  260. span multiple rows. Incompatible with a ``row`` facet.\
  261. """),
  262. share_xy=dedent("""\
  263. share{x,y} : bool, 'col', or 'row' optional
  264. If true, the facets will share y axes across columns and/or x axes
  265. across rows.\
  266. """),
  267. height=dedent("""\
  268. height : scalar
  269. Height (in inches) of each facet. See also: ``aspect``.\
  270. """),
  271. aspect=dedent("""\
  272. aspect : scalar
  273. Aspect ratio of each facet, so that ``aspect * height`` gives the width
  274. of each facet in inches.\
  275. """),
  276. palette=dedent("""\
  277. palette : palette name, list, or dict
  278. Colors to use for the different levels of the ``hue`` variable. Should
  279. be something that can be interpreted by :func:`color_palette`, or a
  280. dictionary mapping hue levels to matplotlib colors.\
  281. """),
  282. legend_out=dedent("""\
  283. legend_out : bool
  284. If ``True``, the figure size will be extended, and the legend will be
  285. drawn outside the plot on the center right.\
  286. """),
  287. margin_titles=dedent("""\
  288. margin_titles : bool
  289. If ``True``, the titles for the row variable are drawn to the right of
  290. the last column. This option is experimental and may not work in all
  291. cases.\
  292. """),
  293. facet_kws=dedent("""\
  294. facet_kws : dict
  295. Additional parameters passed to :class:`FacetGrid`.
  296. """),
  297. )
  298. class FacetGrid(Grid):
  299. """Multi-plot grid for plotting conditional relationships."""
  300. def __init__(
  301. self, data, *,
  302. row=None, col=None, hue=None, col_wrap=None,
  303. sharex=True, sharey=True, height=3, aspect=1, palette=None,
  304. row_order=None, col_order=None, hue_order=None, hue_kws=None,
  305. dropna=False, legend_out=True, despine=True,
  306. margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
  307. gridspec_kws=None,
  308. ):
  309. super().__init__()
  310. data = handle_data_source(data)
  311. # Determine the hue facet layer information
  312. hue_var = hue
  313. if hue is None:
  314. hue_names = None
  315. else:
  316. hue_names = categorical_order(data[hue], hue_order)
  317. colors = self._get_palette(data, hue, hue_order, palette)
  318. # Set up the lists of names for the row and column facet variables
  319. if row is None:
  320. row_names = []
  321. else:
  322. row_names = categorical_order(data[row], row_order)
  323. if col is None:
  324. col_names = []
  325. else:
  326. col_names = categorical_order(data[col], col_order)
  327. # Additional dict of kwarg -> list of values for mapping the hue var
  328. hue_kws = hue_kws if hue_kws is not None else {}
  329. # Make a boolean mask that is True anywhere there is an NA
  330. # value in one of the faceting variables, but only if dropna is True
  331. none_na = np.zeros(len(data), bool)
  332. if dropna:
  333. row_na = none_na if row is None else data[row].isnull()
  334. col_na = none_na if col is None else data[col].isnull()
  335. hue_na = none_na if hue is None else data[hue].isnull()
  336. not_na = ~(row_na | col_na | hue_na)
  337. else:
  338. not_na = ~none_na
  339. # Compute the grid shape
  340. ncol = 1 if col is None else len(col_names)
  341. nrow = 1 if row is None else len(row_names)
  342. self._n_facets = ncol * nrow
  343. self._col_wrap = col_wrap
  344. if col_wrap is not None:
  345. if row is not None:
  346. err = "Cannot use `row` and `col_wrap` together."
  347. raise ValueError(err)
  348. ncol = col_wrap
  349. nrow = int(np.ceil(len(col_names) / col_wrap))
  350. self._ncol = ncol
  351. self._nrow = nrow
  352. # Calculate the base figure size
  353. # This can get stretched later by a legend
  354. # TODO this doesn't account for axis labels
  355. figsize = (ncol * height * aspect, nrow * height)
  356. # Validate some inputs
  357. if col_wrap is not None:
  358. margin_titles = False
  359. # Build the subplot keyword dictionary
  360. subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
  361. gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
  362. if xlim is not None:
  363. subplot_kws["xlim"] = xlim
  364. if ylim is not None:
  365. subplot_kws["ylim"] = ylim
  366. # --- Initialize the subplot grid
  367. with _disable_autolayout():
  368. fig = plt.figure(figsize=figsize)
  369. if col_wrap is None:
  370. kwargs = dict(squeeze=False,
  371. sharex=sharex, sharey=sharey,
  372. subplot_kw=subplot_kws,
  373. gridspec_kw=gridspec_kws)
  374. axes = fig.subplots(nrow, ncol, **kwargs)
  375. if col is None and row is None:
  376. axes_dict = {}
  377. elif col is None:
  378. axes_dict = dict(zip(row_names, axes.flat))
  379. elif row is None:
  380. axes_dict = dict(zip(col_names, axes.flat))
  381. else:
  382. facet_product = product(row_names, col_names)
  383. axes_dict = dict(zip(facet_product, axes.flat))
  384. else:
  385. # If wrapping the col variable we need to make the grid ourselves
  386. if gridspec_kws:
  387. warnings.warn("`gridspec_kws` ignored when using `col_wrap`")
  388. n_axes = len(col_names)
  389. axes = np.empty(n_axes, object)
  390. axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
  391. if sharex:
  392. subplot_kws["sharex"] = axes[0]
  393. if sharey:
  394. subplot_kws["sharey"] = axes[0]
  395. for i in range(1, n_axes):
  396. axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)
  397. axes_dict = dict(zip(col_names, axes))
  398. # --- Set up the class attributes
  399. # Attributes that are part of the public API but accessed through
  400. # a property so that Sphinx adds them to the auto class doc
  401. self._figure = fig
  402. self._axes = axes
  403. self._axes_dict = axes_dict
  404. self._legend = None
  405. # Public attributes that aren't explicitly documented
  406. # (It's not obvious that having them be public was a good idea)
  407. self.data = data
  408. self.row_names = row_names
  409. self.col_names = col_names
  410. self.hue_names = hue_names
  411. self.hue_kws = hue_kws
  412. # Next the private variables
  413. self._nrow = nrow
  414. self._row_var = row
  415. self._ncol = ncol
  416. self._col_var = col
  417. self._margin_titles = margin_titles
  418. self._margin_titles_texts = []
  419. self._col_wrap = col_wrap
  420. self._hue_var = hue_var
  421. self._colors = colors
  422. self._legend_out = legend_out
  423. self._legend_data = {}
  424. self._x_var = None
  425. self._y_var = None
  426. self._sharex = sharex
  427. self._sharey = sharey
  428. self._dropna = dropna
  429. self._not_na = not_na
  430. # --- Make the axes look good
  431. self.set_titles()
  432. self.tight_layout()
  433. if despine:
  434. self.despine()
  435. if sharex in [True, 'col']:
  436. for ax in self._not_bottom_axes:
  437. for label in ax.get_xticklabels():
  438. label.set_visible(False)
  439. ax.xaxis.offsetText.set_visible(False)
  440. ax.xaxis.label.set_visible(False)
  441. if sharey in [True, 'row']:
  442. for ax in self._not_left_axes:
  443. for label in ax.get_yticklabels():
  444. label.set_visible(False)
  445. ax.yaxis.offsetText.set_visible(False)
  446. ax.yaxis.label.set_visible(False)
  447. __init__.__doc__ = dedent("""\
  448. Initialize the matplotlib figure and FacetGrid object.
  449. This class maps a dataset onto multiple axes arrayed in a grid of rows
  450. and columns that correspond to *levels* of variables in the dataset.
  451. The plots it produces are often called "lattice", "trellis", or
  452. "small-multiple" graphics.
  453. It can also represent levels of a third variable with the ``hue``
  454. parameter, which plots different subsets of data in different colors.
  455. This uses color to resolve elements on a third dimension, but only
  456. draws subsets on top of each other and will not tailor the ``hue``
  457. parameter for the specific visualization the way that axes-level
  458. functions that accept ``hue`` will.
  459. The basic workflow is to initialize the :class:`FacetGrid` object with
  460. the dataset and the variables that are used to structure the grid. Then
  461. one or more plotting functions can be applied to each subset by calling
  462. :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. Finally, the
  463. plot can be tweaked with other methods to do things like change the
  464. axis labels, use different ticks, or add a legend. See the detailed
  465. code examples below for more information.
  466. .. warning::
  467. When using seaborn functions that infer semantic mappings from a
  468. dataset, care must be taken to synchronize those mappings across
  469. facets (e.g., by defining the ``hue`` mapping with a palette dict or
  470. setting the data type of the variables to ``category``). In most cases,
  471. it will be better to use a figure-level function (e.g. :func:`relplot`
  472. or :func:`catplot`) than to use :class:`FacetGrid` directly.
  473. See the :ref:`tutorial <grid_tutorial>` for more information.
  474. Parameters
  475. ----------
  476. {data}
  477. row, col, hue : strings
  478. Variables that define subsets of the data, which will be drawn on
  479. separate facets in the grid. See the ``{{var}}_order`` parameters to
  480. control the order of levels of this variable.
  481. {col_wrap}
  482. {share_xy}
  483. {height}
  484. {aspect}
  485. {palette}
  486. {{row,col,hue}}_order : lists
  487. Order for the levels of the faceting variables. By default, this
  488. will be the order that the levels appear in ``data`` or, if the
  489. variables are pandas categoricals, the category order.
  490. hue_kws : dictionary of param -> list of values mapping
  491. Other keyword arguments to insert into the plotting call to let
  492. other plot attributes vary across levels of the hue variable (e.g.
  493. the markers in a scatterplot).
  494. {legend_out}
  495. despine : boolean
  496. Remove the top and right spines from the plots.
  497. {margin_titles}
  498. {{x, y}}lim: tuples
  499. Limits for each of the axes on each facet (only relevant when
  500. share{{x, y}} is True).
  501. subplot_kws : dict
  502. Dictionary of keyword arguments passed to matplotlib subplot(s)
  503. methods.
  504. gridspec_kws : dict
  505. Dictionary of keyword arguments passed to
  506. :class:`matplotlib.gridspec.GridSpec`
  507. (via :meth:`matplotlib.figure.Figure.subplots`).
  508. Ignored if ``col_wrap`` is not ``None``.
  509. See Also
  510. --------
  511. PairGrid : Subplot grid for plotting pairwise relationships
  512. relplot : Combine a relational plot and a :class:`FacetGrid`
  513. displot : Combine a distribution plot and a :class:`FacetGrid`
  514. catplot : Combine a categorical plot and a :class:`FacetGrid`
  515. lmplot : Combine a regression plot and a :class:`FacetGrid`
  516. Examples
  517. --------
  518. .. note::
  519. These examples use seaborn functions to demonstrate some of the
  520. advanced features of the class, but in most cases you will want
  521. to use figue-level functions (e.g. :func:`displot`, :func:`relplot`)
  522. to make the plots shown here.
  523. .. include:: ../docstrings/FacetGrid.rst
  524. """).format(**_facet_docs)
  525. def facet_data(self):
  526. """Generator for name indices and data subsets for each facet.
  527. Yields
  528. ------
  529. (i, j, k), data_ijk : tuple of ints, DataFrame
  530. The ints provide an index into the {row, col, hue}_names attribute,
  531. and the dataframe contains a subset of the full data corresponding
  532. to each facet. The generator yields subsets that correspond with
  533. the self.axes.flat iterator, or self.axes[i, j] when `col_wrap`
  534. is None.
  535. """
  536. data = self.data
  537. # Construct masks for the row variable
  538. if self.row_names:
  539. row_masks = [data[self._row_var] == n for n in self.row_names]
  540. else:
  541. row_masks = [np.repeat(True, len(self.data))]
  542. # Construct masks for the column variable
  543. if self.col_names:
  544. col_masks = [data[self._col_var] == n for n in self.col_names]
  545. else:
  546. col_masks = [np.repeat(True, len(self.data))]
  547. # Construct masks for the hue variable
  548. if self.hue_names:
  549. hue_masks = [data[self._hue_var] == n for n in self.hue_names]
  550. else:
  551. hue_masks = [np.repeat(True, len(self.data))]
  552. # Here is the main generator loop
  553. for (i, row), (j, col), (k, hue) in product(enumerate(row_masks),
  554. enumerate(col_masks),
  555. enumerate(hue_masks)):
  556. data_ijk = data[row & col & hue & self._not_na]
  557. yield (i, j, k), data_ijk
  558. def map(self, func, *args, **kwargs):
  559. """Apply a plotting function to each facet's subset of the data.
  560. Parameters
  561. ----------
  562. func : callable
  563. A plotting function that takes data and keyword arguments. It
  564. must plot to the currently active matplotlib Axes and take a
  565. `color` keyword argument. If faceting on the `hue` dimension,
  566. it must also take a `label` keyword argument.
  567. args : strings
  568. Column names in self.data that identify variables with data to
  569. plot. The data for each variable is passed to `func` in the
  570. order the variables are specified in the call.
  571. kwargs : keyword arguments
  572. All keyword arguments are passed to the plotting function.
  573. Returns
  574. -------
  575. self : object
  576. Returns self.
  577. """
  578. # If color was a keyword argument, grab it here
  579. kw_color = kwargs.pop("color", None)
  580. # How we use the function depends on where it comes from
  581. func_module = str(getattr(func, "__module__", ""))
  582. # Check for categorical plots without order information
  583. if func_module == "seaborn.categorical":
  584. if "order" not in kwargs:
  585. warning = ("Using the {} function without specifying "
  586. "`order` is likely to produce an incorrect "
  587. "plot.".format(func.__name__))
  588. warnings.warn(warning)
  589. if len(args) == 3 and "hue_order" not in kwargs:
  590. warning = ("Using the {} function without specifying "
  591. "`hue_order` is likely to produce an incorrect "
  592. "plot.".format(func.__name__))
  593. warnings.warn(warning)
  594. # Iterate over the data subsets
  595. for (row_i, col_j, hue_k), data_ijk in self.facet_data():
  596. # If this subset is null, move on
  597. if not data_ijk.values.size:
  598. continue
  599. # Get the current axis
  600. modify_state = not func_module.startswith("seaborn")
  601. ax = self.facet_axis(row_i, col_j, modify_state)
  602. # Decide what color to plot with
  603. kwargs["color"] = self._facet_color(hue_k, kw_color)
  604. # Insert the other hue aesthetics if appropriate
  605. for kw, val_list in self.hue_kws.items():
  606. kwargs[kw] = val_list[hue_k]
  607. # Insert a label in the keyword arguments for the legend
  608. if self._hue_var is not None:
  609. kwargs["label"] = utils.to_utf8(self.hue_names[hue_k])
  610. # Get the actual data we are going to plot with
  611. plot_data = data_ijk[list(args)]
  612. if self._dropna:
  613. plot_data = plot_data.dropna()
  614. plot_args = [v for k, v in plot_data.items()]
  615. # Some matplotlib functions don't handle pandas objects correctly
  616. if func_module.startswith("matplotlib"):
  617. plot_args = [v.values for v in plot_args]
  618. # Draw the plot
  619. self._facet_plot(func, ax, plot_args, kwargs)
  620. # Finalize the annotations and layout
  621. self._finalize_grid(args[:2])
  622. return self
  623. def map_dataframe(self, func, *args, **kwargs):
  624. """Like ``.map`` but passes args as strings and inserts data in kwargs.
  625. This method is suitable for plotting with functions that accept a
  626. long-form DataFrame as a `data` keyword argument and access the
  627. data in that DataFrame using string variable names.
  628. Parameters
  629. ----------
  630. func : callable
  631. A plotting function that takes data and keyword arguments. Unlike
  632. the `map` method, a function used here must "understand" Pandas
  633. objects. It also must plot to the currently active matplotlib Axes
  634. and take a `color` keyword argument. If faceting on the `hue`
  635. dimension, it must also take a `label` keyword argument.
  636. args : strings
  637. Column names in self.data that identify variables with data to
  638. plot. The data for each variable is passed to `func` in the
  639. order the variables are specified in the call.
  640. kwargs : keyword arguments
  641. All keyword arguments are passed to the plotting function.
  642. Returns
  643. -------
  644. self : object
  645. Returns self.
  646. """
  647. # If color was a keyword argument, grab it here
  648. kw_color = kwargs.pop("color", None)
  649. # Iterate over the data subsets
  650. for (row_i, col_j, hue_k), data_ijk in self.facet_data():
  651. # If this subset is null, move on
  652. if not data_ijk.values.size:
  653. continue
  654. # Get the current axis
  655. modify_state = not str(func.__module__).startswith("seaborn")
  656. ax = self.facet_axis(row_i, col_j, modify_state)
  657. # Decide what color to plot with
  658. kwargs["color"] = self._facet_color(hue_k, kw_color)
  659. # Insert the other hue aesthetics if appropriate
  660. for kw, val_list in self.hue_kws.items():
  661. kwargs[kw] = val_list[hue_k]
  662. # Insert a label in the keyword arguments for the legend
  663. if self._hue_var is not None:
  664. kwargs["label"] = self.hue_names[hue_k]
  665. # Stick the facet dataframe into the kwargs
  666. if self._dropna:
  667. data_ijk = data_ijk.dropna()
  668. kwargs["data"] = data_ijk
  669. # Draw the plot
  670. self._facet_plot(func, ax, args, kwargs)
  671. # For axis labels, prefer to use positional args for backcompat
  672. # but also extract the x/y kwargs and use if no corresponding arg
  673. axis_labels = [kwargs.get("x", None), kwargs.get("y", None)]
  674. for i, val in enumerate(args[:2]):
  675. axis_labels[i] = val
  676. self._finalize_grid(axis_labels)
  677. return self
  678. def _facet_color(self, hue_index, kw_color):
  679. color = self._colors[hue_index]
  680. if kw_color is not None:
  681. return kw_color
  682. elif color is not None:
  683. return color
  684. def _facet_plot(self, func, ax, plot_args, plot_kwargs):
  685. # Draw the plot
  686. if str(func.__module__).startswith("seaborn"):
  687. plot_kwargs = plot_kwargs.copy()
  688. semantics = ["x", "y", "hue", "size", "style"]
  689. for key, val in zip(semantics, plot_args):
  690. plot_kwargs[key] = val
  691. plot_args = []
  692. plot_kwargs["ax"] = ax
  693. func(*plot_args, **plot_kwargs)
  694. # Sort out the supporting information
  695. self._update_legend_data(ax)
  696. def _finalize_grid(self, axlabels):
  697. """Finalize the annotations and layout."""
  698. self.set_axis_labels(*axlabels)
  699. self.tight_layout()
  700. def facet_axis(self, row_i, col_j, modify_state=True):
  701. """Make the axis identified by these indices active and return it."""
  702. # Calculate the actual indices of the axes to plot on
  703. if self._col_wrap is not None:
  704. ax = self.axes.flat[col_j]
  705. else:
  706. ax = self.axes[row_i, col_j]
  707. # Get a reference to the axes object we want, and make it active
  708. if modify_state:
  709. plt.sca(ax)
  710. return ax
  711. def despine(self, **kwargs):
  712. """Remove axis spines from the facets."""
  713. utils.despine(self._figure, **kwargs)
  714. return self
  715. def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs):
  716. """Set axis labels on the left column and bottom row of the grid."""
  717. if x_var is not None:
  718. self._x_var = x_var
  719. self.set_xlabels(x_var, clear_inner=clear_inner, **kwargs)
  720. if y_var is not None:
  721. self._y_var = y_var
  722. self.set_ylabels(y_var, clear_inner=clear_inner, **kwargs)
  723. return self
  724. def set_xlabels(self, label=None, clear_inner=True, **kwargs):
  725. """Label the x axis on the bottom row of the grid."""
  726. if label is None:
  727. label = self._x_var
  728. for ax in self._bottom_axes:
  729. ax.set_xlabel(label, **kwargs)
  730. if clear_inner:
  731. for ax in self._not_bottom_axes:
  732. ax.set_xlabel("")
  733. return self
  734. def set_ylabels(self, label=None, clear_inner=True, **kwargs):
  735. """Label the y axis on the left column of the grid."""
  736. if label is None:
  737. label = self._y_var
  738. for ax in self._left_axes:
  739. ax.set_ylabel(label, **kwargs)
  740. if clear_inner:
  741. for ax in self._not_left_axes:
  742. ax.set_ylabel("")
  743. return self
  744. def set_xticklabels(self, labels=None, step=None, **kwargs):
  745. """Set x axis tick labels of the grid."""
  746. for ax in self.axes.flat:
  747. curr_ticks = ax.get_xticks()
  748. ax.set_xticks(curr_ticks)
  749. if labels is None:
  750. curr_labels = [label.get_text() for label in ax.get_xticklabels()]
  751. if step is not None:
  752. xticks = ax.get_xticks()[::step]
  753. curr_labels = curr_labels[::step]
  754. ax.set_xticks(xticks)
  755. ax.set_xticklabels(curr_labels, **kwargs)
  756. else:
  757. ax.set_xticklabels(labels, **kwargs)
  758. return self
  759. def set_yticklabels(self, labels=None, **kwargs):
  760. """Set y axis tick labels on the left column of the grid."""
  761. for ax in self.axes.flat:
  762. curr_ticks = ax.get_yticks()
  763. ax.set_yticks(curr_ticks)
  764. if labels is None:
  765. curr_labels = [label.get_text() for label in ax.get_yticklabels()]
  766. ax.set_yticklabels(curr_labels, **kwargs)
  767. else:
  768. ax.set_yticklabels(labels, **kwargs)
  769. return self
  770. def set_titles(self, template=None, row_template=None, col_template=None, **kwargs):
  771. """Draw titles either above each facet or on the grid margins.
  772. Parameters
  773. ----------
  774. template : string
  775. Template for all titles with the formatting keys {col_var} and
  776. {col_name} (if using a `col` faceting variable) and/or {row_var}
  777. and {row_name} (if using a `row` faceting variable).
  778. row_template:
  779. Template for the row variable when titles are drawn on the grid
  780. margins. Must have {row_var} and {row_name} formatting keys.
  781. col_template:
  782. Template for the column variable when titles are drawn on the grid
  783. margins. Must have {col_var} and {col_name} formatting keys.
  784. Returns
  785. -------
  786. self: object
  787. Returns self.
  788. """
  789. args = dict(row_var=self._row_var, col_var=self._col_var)
  790. kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])
  791. # Establish default templates
  792. if row_template is None:
  793. row_template = "{row_var} = {row_name}"
  794. if col_template is None:
  795. col_template = "{col_var} = {col_name}"
  796. if template is None:
  797. if self._row_var is None:
  798. template = col_template
  799. elif self._col_var is None:
  800. template = row_template
  801. else:
  802. template = " | ".join([row_template, col_template])
  803. row_template = utils.to_utf8(row_template)
  804. col_template = utils.to_utf8(col_template)
  805. template = utils.to_utf8(template)
  806. if self._margin_titles:
  807. # Remove any existing title texts
  808. for text in self._margin_titles_texts:
  809. text.remove()
  810. self._margin_titles_texts = []
  811. if self.row_names is not None:
  812. # Draw the row titles on the right edge of the grid
  813. for i, row_name in enumerate(self.row_names):
  814. ax = self.axes[i, -1]
  815. args.update(dict(row_name=row_name))
  816. title = row_template.format(**args)
  817. text = ax.annotate(
  818. title, xy=(1.02, .5), xycoords="axes fraction",
  819. rotation=270, ha="left", va="center",
  820. **kwargs
  821. )
  822. self._margin_titles_texts.append(text)
  823. if self.col_names is not None:
  824. # Draw the column titles as normal titles
  825. for j, col_name in enumerate(self.col_names):
  826. args.update(dict(col_name=col_name))
  827. title = col_template.format(**args)
  828. self.axes[0, j].set_title(title, **kwargs)
  829. return self
  830. # Otherwise title each facet with all the necessary information
  831. if (self._row_var is not None) and (self._col_var is not None):
  832. for i, row_name in enumerate(self.row_names):
  833. for j, col_name in enumerate(self.col_names):
  834. args.update(dict(row_name=row_name, col_name=col_name))
  835. title = template.format(**args)
  836. self.axes[i, j].set_title(title, **kwargs)
  837. elif self.row_names is not None and len(self.row_names):
  838. for i, row_name in enumerate(self.row_names):
  839. args.update(dict(row_name=row_name))
  840. title = template.format(**args)
  841. self.axes[i, 0].set_title(title, **kwargs)
  842. elif self.col_names is not None and len(self.col_names):
  843. for i, col_name in enumerate(self.col_names):
  844. args.update(dict(col_name=col_name))
  845. title = template.format(**args)
  846. # Index the flat array so col_wrap works
  847. self.axes.flat[i].set_title(title, **kwargs)
  848. return self
  849. def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
  850. """Add a reference line(s) to each facet.
  851. Parameters
  852. ----------
  853. x, y : numeric
  854. Value(s) to draw the line(s) at.
  855. color : :mod:`matplotlib color <matplotlib.colors>`
  856. Specifies the color of the reference line(s). Pass ``color=None`` to
  857. use ``hue`` mapping.
  858. linestyle : str
  859. Specifies the style of the reference line(s).
  860. line_kws : key, value mappings
  861. Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
  862. when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
  863. is not None.
  864. Returns
  865. -------
  866. :class:`FacetGrid` instance
  867. Returns ``self`` for easy method chaining.
  868. """
  869. line_kws['color'] = color
  870. line_kws['linestyle'] = linestyle
  871. if x is not None:
  872. self.map(plt.axvline, x=x, **line_kws)
  873. if y is not None:
  874. self.map(plt.axhline, y=y, **line_kws)
  875. return self
  876. # ------ Properties that are part of the public API and documented by Sphinx
  877. @property
  878. def axes(self):
  879. """An array of the :class:`matplotlib.axes.Axes` objects in the grid."""
  880. return self._axes
  881. @property
  882. def ax(self):
  883. """The :class:`matplotlib.axes.Axes` when no faceting variables are assigned."""
  884. if self.axes.shape == (1, 1):
  885. return self.axes[0, 0]
  886. else:
  887. err = (
  888. "Use the `.axes` attribute when facet variables are assigned."
  889. )
  890. raise AttributeError(err)
  891. @property
  892. def axes_dict(self):
  893. """A mapping of facet names to corresponding :class:`matplotlib.axes.Axes`.
  894. If only one of ``row`` or ``col`` is assigned, each key is a string
  895. representing a level of that variable. If both facet dimensions are
  896. assigned, each key is a ``({row_level}, {col_level})`` tuple.
  897. """
  898. return self._axes_dict
  899. # ------ Private properties, that require some computation to get
  900. @property
  901. def _inner_axes(self):
  902. """Return a flat array of the inner axes."""
  903. if self._col_wrap is None:
  904. return self.axes[:-1, 1:].flat
  905. else:
  906. axes = []
  907. n_empty = self._nrow * self._ncol - self._n_facets
  908. for i, ax in enumerate(self.axes):
  909. append = (
  910. i % self._ncol
  911. and i < (self._ncol * (self._nrow - 1))
  912. and i < (self._ncol * (self._nrow - 1) - n_empty)
  913. )
  914. if append:
  915. axes.append(ax)
  916. return np.array(axes, object).flat
  917. @property
  918. def _left_axes(self):
  919. """Return a flat array of the left column of axes."""
  920. if self._col_wrap is None:
  921. return self.axes[:, 0].flat
  922. else:
  923. axes = []
  924. for i, ax in enumerate(self.axes):
  925. if not i % self._ncol:
  926. axes.append(ax)
  927. return np.array(axes, object).flat
  928. @property
  929. def _not_left_axes(self):
  930. """Return a flat array of axes that aren't on the left column."""
  931. if self._col_wrap is None:
  932. return self.axes[:, 1:].flat
  933. else:
  934. axes = []
  935. for i, ax in enumerate(self.axes):
  936. if i % self._ncol:
  937. axes.append(ax)
  938. return np.array(axes, object).flat
  939. @property
  940. def _bottom_axes(self):
  941. """Return a flat array of the bottom row of axes."""
  942. if self._col_wrap is None:
  943. return self.axes[-1, :].flat
  944. else:
  945. axes = []
  946. n_empty = self._nrow * self._ncol - self._n_facets
  947. for i, ax in enumerate(self.axes):
  948. append = (
  949. i >= (self._ncol * (self._nrow - 1))
  950. or i >= (self._ncol * (self._nrow - 1) - n_empty)
  951. )
  952. if append:
  953. axes.append(ax)
  954. return np.array(axes, object).flat
  955. @property
  956. def _not_bottom_axes(self):
  957. """Return a flat array of axes that aren't on the bottom row."""
  958. if self._col_wrap is None:
  959. return self.axes[:-1, :].flat
  960. else:
  961. axes = []
  962. n_empty = self._nrow * self._ncol - self._n_facets
  963. for i, ax in enumerate(self.axes):
  964. append = (
  965. i < (self._ncol * (self._nrow - 1))
  966. and i < (self._ncol * (self._nrow - 1) - n_empty)
  967. )
  968. if append:
  969. axes.append(ax)
  970. return np.array(axes, object).flat
class PairGrid(Grid):
    """Subplot grid for plotting pairwise relationships in a dataset.

    This object maps each variable in a dataset onto a column and row in a
    grid of multiple axes. Different axes-level plotting functions can be
    used to draw bivariate plots in the upper and lower triangles, and the
    marginal distribution of each variable can be shown on the diagonal.

    Several different common plots can be generated in a single line using
    :func:`pairplot`. Use :class:`PairGrid` when you need more flexibility.

    See the :ref:`tutorial <grid_tutorial>` for more information.

    """
    def __init__(
        self, data, *, hue=None, vars=None, x_vars=None, y_vars=None,
        hue_order=None, palette=None, hue_kws=None, corner=False, diag_sharey=True,
        height=2.5, aspect=1, layout_pad=.5, despine=True, dropna=False,
    ):
        """Initialize the plot figure and PairGrid object.

        Parameters
        ----------
        data : DataFrame
            Tidy (long-form) dataframe where each column is a variable and
            each row is an observation.
        hue : string (variable name)
            Variable in ``data`` to map plot aspects to different colors. This
            variable will be excluded from the default x and y variables.
        vars : list of variable names
            Variables within ``data`` to use, otherwise use every column with
            a numeric datatype. When given, it overrides ``x_vars`` and
            ``y_vars``.
        {x, y}_vars : lists of variable names
            Variables within ``data`` to use separately for the rows and
            columns of the figure; i.e. to make a non-square plot. ``x_vars``
            define the columns and ``y_vars`` define the rows.
        hue_order : list of strings
            Order for the levels of the hue variable in the palette
        palette : dict or seaborn color palette
            Set of colors for mapping the ``hue`` variable. If a dict, keys
            should be values in the ``hue`` variable.
        hue_kws : dictionary of param -> list of values mapping
            Other keyword arguments to insert into the plotting call to let
            other plot attributes vary across levels of the hue variable (e.g.
            the markers in a scatterplot).
        corner : bool
            If True, don't add axes to the upper (off-diagonal) triangle of the
            grid, making this a "corner" plot.
        diag_sharey : bool
            If True, the twin axes that :meth:`map_diag` creates on the
            diagonal share a common y axis.
        height : scalar
            Height (in inches) of each facet.
        aspect : scalar
            Aspect * height gives the width (in inches) of each facet.
        layout_pad : scalar
            Padding between axes; passed as ``pad`` to the grid's
            ``tight_layout``.
        despine : boolean
            Remove the top and right spines from the plots.
        dropna : boolean
            Drop missing values from the data before plotting.

        Raises
        ------
        ValueError
            If no variables are available for the grid columns or rows.

        See Also
        --------
        pairplot : Easily drawing common uses of :class:`PairGrid`.
        FacetGrid : Subplot grid for plotting conditional relationships.

        Examples
        --------

        .. include:: ../docstrings/PairGrid.rst

        """

        super().__init__()

        # Normalize the ``data`` argument with ``handle_data_source``
        data = handle_data_source(data)

        # Sort out the variables that define the grid: by default every
        # numeric column except the hue variable appears on both axes
        numeric_cols = self._find_numeric_cols(data)
        if hue in numeric_cols:
            numeric_cols.remove(hue)
        if vars is not None:
            x_vars = list(vars)
            y_vars = list(vars)
        if x_vars is None:
            x_vars = numeric_cols
        if y_vars is None:
            y_vars = numeric_cols

        # Accept a single variable name in place of a list
        if np.isscalar(x_vars):
            x_vars = [x_vars]
        if np.isscalar(y_vars):
            y_vars = [y_vars]

        self.x_vars = x_vars = list(x_vars)
        self.y_vars = y_vars = list(y_vars)
        # True when rows and columns hold the same variables in the same
        # order; map_offdiag uses this to decide how to find off-diagonal cells
        self.square_grid = self.x_vars == self.y_vars

        if not x_vars:
            raise ValueError("No variables found for grid columns.")
        if not y_vars:
            raise ValueError("No variables found for grid rows.")

        # Create the figure and the array of subplots
        # (one column per x variable, one row per y variable)
        figsize = len(x_vars) * height * aspect, len(y_vars) * height

        with _disable_autolayout():
            fig = plt.figure(figsize=figsize)

        axes = fig.subplots(len(y_vars), len(x_vars),
                            sharex="col", sharey="row",
                            squeeze=False)

        # Possibly remove upper axes to make a corner grid
        # Note: setting up the axes is usually the most time-intensive part
        # of using the PairGrid. We are foregoing the speed improvement that
        # we would get by just not setting up the hidden axes so that we can
        # avoid implementing fig.subplots ourselves. But worth thinking about.
        self._corner = corner
        if corner:
            hide_indices = np.triu_indices_from(axes, 1)
            for i, j in zip(*hide_indices):
                axes[i, j].remove()
                # None marks a removed cell; the plotting loops below skip it
                axes[i, j] = None

        self._figure = fig
        self.axes = axes
        self.data = data

        # Save what we are going to do with the diagonal
        # (the diagonal twin axes themselves are created lazily by map_diag)
        self.diag_sharey = diag_sharey
        self.diag_vars = None
        self.diag_axes = None

        self._dropna = dropna

        # Label the axes
        self._add_axis_labels()

        # Sort out the hue variable
        self._hue_var = hue
        if hue is None:
            # A single placeholder level lets the hue-iterating code paths
            # run unchanged when there is no hue variable
            self.hue_names = hue_order = ["_nolegend_"]
            self.hue_vals = pd.Series(["_nolegend_"] * len(data),
                                      index=data.index)
        else:
            # We need hue_order and hue_names because the former is used to control
            # the order of drawing and the latter is used to control the order of
            # the legend. hue_names can become string-typed while hue_order must
            # retain the type of the input data. This is messy but results from
            # the fact that PairGrid can implement the hue-mapping logic itself
            # (and was originally written exclusively that way) but now can delegate
            # to the axes-level functions, while always handling legend creation.
            # See GH2307
            hue_names = hue_order = categorical_order(data[hue], hue_order)
            if dropna:
                # Filter NA from the list of unique hue names
                hue_names = list(filter(pd.notnull, hue_names))
            self.hue_names = hue_names
            self.hue_vals = data[hue]

        # Additional dict of kwarg -> list of values for mapping the hue var
        self.hue_kws = hue_kws if hue_kws is not None else {}

        # Keep the caller's palette and level order so they can be passed
        # through to hue-aware plotting functions; self.palette is the
        # resolved palette indexed by level position in the iterating paths
        self._orig_palette = palette
        self._hue_order = hue_order
        self.palette = self._get_palette(data, hue, hue_order, palette)
        self._legend_data = {}

        # Make the plot look nice: hide x tick labels, offset text and axis
        # label everywhere except the bottom row...
        for ax in axes[:-1, :].flat:
            if ax is None:
                continue
            for label in ax.get_xticklabels():
                label.set_visible(False)
            ax.xaxis.offsetText.set_visible(False)
            ax.xaxis.label.set_visible(False)

        # ...and the y equivalents everywhere except the left column
        for ax in axes[:, 1:].flat:
            if ax is None:
                continue
            for label in ax.get_yticklabels():
                label.set_visible(False)
            ax.yaxis.offsetText.set_visible(False)
            ax.yaxis.label.set_visible(False)

        # Layout settings kept on the instance; presumably reused by the
        # inherited Grid layout/legend code (not visible here) — confirm
        self._tight_layout_rect = [.01, .01, .99, .99]
        self._tight_layout_pad = layout_pad
        self._despine = despine
        if despine:
            utils.despine(fig=fig)
        self.tight_layout(pad=layout_pad)

    def map(self, func, **kwargs):
        """Plot with the same function in every subplot.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``. Seaborn functions, and functions
            that accept a ``hue`` parameter, are instead called with ``x=``
            and ``y=`` keywords.

        Returns
        -------
        self
            The grid, for method chaining.

        """
        # Visit every (row, column) cell in row-major order; cells removed
        # for a corner plot are skipped inside _map_bivariate
        row_indices, col_indices = np.indices(self.axes.shape)
        indices = zip(row_indices.flat, col_indices.flat)

        self._map_bivariate(func, indices, **kwargs)

        return self

    def map_lower(self, func, **kwargs):
        """Plot with a bivariate function on the lower diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self
            The grid, for method chaining.

        """
        # Cells strictly below the main diagonal
        indices = zip(*np.tril_indices_from(self.axes, -1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_upper(self, func, **kwargs):
        """Plot with a bivariate function on the upper diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self
            The grid, for method chaining.

        """
        # Cells strictly above the main diagonal (all None in a corner grid)
        indices = zip(*np.triu_indices_from(self.axes, 1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_offdiag(self, func, **kwargs):
        """Plot with a bivariate function on the off-diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self
            The grid, for method chaining.

        """
        if self.square_grid:
            # Same variables on both axes: off-diagonal means not on the
            # geometric diagonal
            self.map_lower(func, **kwargs)
            if not self._corner:
                self.map_upper(func, **kwargs)
        else:
            # Non-square grid: skip any cell whose x and y variables match
            indices = []
            for i, (y_var) in enumerate(self.y_vars):
                for j, (x_var) in enumerate(self.x_vars):
                    if x_var != y_var:
                        indices.append((i, j))
            self._map_bivariate(func, indices, **kwargs)
        return self

    def map_diag(self, func, **kwargs):
        """Plot with a univariate function on each diagonal subplot.

        On the first call, a twin axes (``ax.twinx()``) is created on every
        cell whose x and y variables match; later calls reuse those axes.

        Parameters
        ----------
        func : callable plotting function
            Must take an x array as a positional argument and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self
            The grid, for method chaining.

        """
        # Add special diagonal axes for the univariate plot
        if self.diag_axes is None:
            diag_vars = []
            diag_axes = []
            for i, y_var in enumerate(self.y_vars):
                for j, x_var in enumerate(self.x_vars):
                    if x_var == y_var:

                        # Make the density axes
                        diag_vars.append(x_var)
                        ax = self.axes[i, j]
                        diag_ax = ax.twinx()
                        diag_ax.set_axis_off()
                        diag_axes.append(diag_ax)

                        # Work around matplotlib bug
                        # https://github.com/matplotlib/matplotlib/issues/15188
                        if not plt.rcParams.get("ytick.left", True):
                            for tick in ax.yaxis.majorTicks:
                                tick.tick1line.set_visible(False)

                        # Remove main y axis from density axes in a corner plot
                        if self._corner:
                            ax.yaxis.set_visible(False)
                            if self._despine:
                                utils.despine(ax=ax, left=True)
                            # TODO add optional density ticks (on the right)
                            # when drawing a corner plot?

            if self.diag_sharey and diag_axes:
                # Link every diagonal twin axes to the first one's y axis
                for ax in diag_axes[1:]:
                    share_axis(diag_axes[0], ax, "y")

            self.diag_vars = diag_vars
            self.diag_axes = diag_axes

        # Functions without a ``hue`` parameter are called once per hue level
        if "hue" not in signature(func).parameters:
            return self._map_diag_iter_hue(func, **kwargs)

        # Loop over diagonal variables and axes, making one plot in each
        for var, ax in zip(self.diag_vars, self.diag_axes):

            plot_kwargs = kwargs.copy()
            if str(func.__module__).startswith("seaborn"):
                plot_kwargs["ax"] = ax
            else:
                plt.sca(ax)

            vector = self.data[var]
            if self._hue_var is not None:
                hue = self.data[self._hue_var]
            else:
                hue = None

            if self._dropna:
                # Drop rows missing either the variable or the hue value
                not_na = vector.notna()
                if hue is not None:
                    not_na &= hue.notna()
                vector = vector[not_na]
                if hue is not None:
                    hue = hue[not_na]

            # Caller-supplied values take precedence over the grid's mapping
            plot_kwargs.setdefault("hue", hue)
            plot_kwargs.setdefault("hue_order", self._hue_order)
            plot_kwargs.setdefault("palette", self._orig_palette)
            func(x=vector, **plot_kwargs)
            # Drop any legend the function attached to this diagonal axes
            ax.legend_ = None

        self._add_axis_labels()
        return self

    def _map_diag_iter_hue(self, func, **kwargs):
        """Put marginal plot on each diagonal axes, iterating over hue.

        Used by :meth:`map_diag` when ``func`` has no ``hue`` parameter: the
        function is called once per hue level with that level's color and
        label. A ``color`` kwarg overrides the palette for every level.
        NOTE(review): unlike ``_plot_bivariate_iter_hue``, ``self.hue_kws``
        is not applied on this path.
        """
        # Plot on each of the diagonal axes
        fixed_color = kwargs.pop("color", None)

        for var, ax in zip(self.diag_vars, self.diag_axes):
            hue_grouped = self.data[var].groupby(self.hue_vals, observed=True)

            plot_kwargs = kwargs.copy()
            if str(func.__module__).startswith("seaborn"):
                plot_kwargs["ax"] = ax
            else:
                plt.sca(ax)

            for k, label_k in enumerate(self._hue_order):

                # Attempt to get data for this level, allowing for empty
                try:
                    data_k = hue_grouped.get_group(label_k)
                except KeyError:
                    data_k = pd.Series([], dtype=float)

                if fixed_color is None:
                    color = self.palette[k]
                else:
                    color = fixed_color

                if self._dropna:
                    data_k = utils.remove_na(data_k)

                # Seaborn functions take data by keyword, others positionally
                if str(func.__module__).startswith("seaborn"):
                    func(x=data_k, label=label_k, color=color, **plot_kwargs)
                else:
                    func(data_k, label=label_k, color=color, **plot_kwargs)

        self._add_axis_labels()

        return self

    def _map_bivariate(self, func, indices, **kwargs):
        """Draw a bivariate plot on the indicated axes.

        ``indices`` is an iterable of ``(row, col)`` pairs; cells whose axes
        is None (removed in a corner grid) are skipped.
        """
        # This is a hack to handle the fact that new distribution plots don't add
        # their artists onto the axes. This is probably superior in general, but
        # we'll need a better way to handle it in the axisgrid functions.
        from .distributions import histplot, kdeplot
        if func is histplot or func is kdeplot:
            self._extract_legend_handles = True

        kws = kwargs.copy()  # Use copy as we insert other kwargs
        for i, j in indices:
            x_var = self.x_vars[j]
            y_var = self.y_vars[i]
            ax = self.axes[i, j]
            if ax is None:  # i.e. we are in corner mode
                continue
            self._plot_bivariate(x_var, y_var, ax, func, **kws)
        self._add_axis_labels()

        # For hue-aware functions, take the hue names from the keys that
        # _update_legend_data (inherited, not shown here) collected
        if "hue" in signature(func).parameters:
            self.hue_names = list(self._legend_data)

    def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs):
        """Draw a bivariate plot on the specified axes.

        Functions that accept ``hue`` get the whole (optionally NA-filtered)
        data in one call; all others are delegated to
        :meth:`_plot_bivariate_iter_hue`.
        """
        if "hue" not in signature(func).parameters:
            self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs)
            return

        kwargs = kwargs.copy()
        if str(func.__module__).startswith("seaborn"):
            kwargs["ax"] = ax
        else:
            plt.sca(ax)

        # Select only the columns this panel needs, so dropna below removes
        # rows that are missing in these variables only
        if x_var == y_var:
            axes_vars = [x_var]
        else:
            axes_vars = [x_var, y_var]

        if self._hue_var is not None and self._hue_var not in axes_vars:
            axes_vars.append(self._hue_var)

        data = self.data[axes_vars]
        if self._dropna:
            data = data.dropna()

        x = data[x_var]
        y = data[y_var]
        if self._hue_var is None:
            hue = None
        else:
            hue = data.get(self._hue_var)

        # Respect a hue passed explicitly by the caller; otherwise use the
        # grid's hue variable, order and palette
        if "hue" not in kwargs:
            kwargs.update({
                "hue": hue, "hue_order": self._hue_order, "palette": self._orig_palette,
            })
        func(x=x, y=y, **kwargs)

        self._update_legend_data(ax)

    def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs):
        """Draw a bivariate plot while iterating over hue subsets.

        ``func`` is called once per level of ``self._hue_order`` with that
        level's rows, per-level ``hue_kws`` values, and a palette color
        unless ``color`` was supplied (directly or via ``hue_kws``).
        """
        kwargs = kwargs.copy()
        if str(func.__module__).startswith("seaborn"):
            kwargs["ax"] = ax
        else:
            plt.sca(ax)

        if x_var == y_var:
            axes_vars = [x_var]
        else:
            axes_vars = [x_var, y_var]

        hue_grouped = self.data.groupby(self.hue_vals, observed=True)
        for k, label_k in enumerate(self._hue_order):

            kws = kwargs.copy()

            # Attempt to get data for this level, allowing for empty
            try:
                data_k = hue_grouped.get_group(label_k)
            except KeyError:
                data_k = pd.DataFrame(columns=axes_vars,
                                      dtype=float)

            if self._dropna:
                data_k = data_k[axes_vars].dropna()

            x = data_k[x_var]
            y = data_k[y_var]

            # Per-level overrides, indexed by the level's position
            for kw, val_list in self.hue_kws.items():
                kws[kw] = val_list[k]
            kws.setdefault("color", self.palette[k])
            # Only label real hue levels, not the "_nolegend_" placeholder
            if self._hue_var is not None:
                kws["label"] = label_k

            if str(func.__module__).startswith("seaborn"):
                func(x=x, y=y, **kws)
            else:
                func(x, y, **kws)

        self._update_legend_data(ax)

    def _add_axis_labels(self):
        """Add labels to the left and bottom Axes."""
        # Bottom row gets the x variable names, left column the y names
        for ax, label in zip(self.axes[-1, :], self.x_vars):
            ax.set_xlabel(label)
        for ax, label in zip(self.axes[:, 0], self.y_vars):
            ax.set_ylabel(label)

    def _find_numeric_cols(self, data):
        """Find which variables in a DataFrame are numeric.

        Returns the column names, in column order, for which
        ``variable_type`` reports ``"numeric"``.
        """
        numeric_cols = []
        for col in data:
            if variable_type(data[col]) == "numeric":
                numeric_cols.append(col)
        return numeric_cols
  1382. class JointGrid(_BaseGrid):
  1383. """Grid for drawing a bivariate plot with marginal univariate plots.
  1384. Many plots can be drawn by using the figure-level interface :func:`jointplot`.
  1385. Use this class directly when you need more flexibility.
  1386. """
  1387. def __init__(
  1388. self, data=None, *,
  1389. x=None, y=None, hue=None,
  1390. height=6, ratio=5, space=.2,
  1391. palette=None, hue_order=None, hue_norm=None,
  1392. dropna=False, xlim=None, ylim=None, marginal_ticks=False,
  1393. ):
  1394. # Set up the subplot grid
  1395. f = plt.figure(figsize=(height, height))
  1396. gs = plt.GridSpec(ratio + 1, ratio + 1)
  1397. ax_joint = f.add_subplot(gs[1:, :-1])
  1398. ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
  1399. ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)
  1400. self._figure = f
  1401. self.ax_joint = ax_joint
  1402. self.ax_marg_x = ax_marg_x
  1403. self.ax_marg_y = ax_marg_y
  1404. # Turn off tick visibility for the measure axis on the marginal plots
  1405. plt.setp(ax_marg_x.get_xticklabels(), visible=False)
  1406. plt.setp(ax_marg_y.get_yticklabels(), visible=False)
  1407. plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
  1408. plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)
  1409. # Turn off the ticks on the density axis for the marginal plots
  1410. if not marginal_ticks:
  1411. plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
  1412. plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
  1413. plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
  1414. plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
  1415. plt.setp(ax_marg_x.get_yticklabels(), visible=False)
  1416. plt.setp(ax_marg_y.get_xticklabels(), visible=False)
  1417. plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
  1418. plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
  1419. ax_marg_x.yaxis.grid(False)
  1420. ax_marg_y.xaxis.grid(False)
  1421. # Process the input variables
  1422. p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))
  1423. plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]
  1424. # Possibly drop NA
  1425. if dropna:
  1426. plot_data = plot_data.dropna()
  1427. def get_var(var):
  1428. vector = plot_data.get(var, None)
  1429. if vector is not None:
  1430. vector = vector.rename(p.variables.get(var, None))
  1431. return vector
  1432. self.x = get_var("x")
  1433. self.y = get_var("y")
  1434. self.hue = get_var("hue")
  1435. for axis in "xy":
  1436. name = p.variables.get(axis, None)
  1437. if name is not None:
  1438. getattr(ax_joint, f"set_{axis}label")(name)
  1439. if xlim is not None:
  1440. ax_joint.set_xlim(xlim)
  1441. if ylim is not None:
  1442. ax_joint.set_ylim(ylim)
  1443. # Store the semantic mapping parameters for axes-level functions
  1444. self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)
  1445. # Make the grid look nice
  1446. utils.despine(f)
  1447. if not marginal_ticks:
  1448. utils.despine(ax=ax_marg_x, left=True)
  1449. utils.despine(ax=ax_marg_y, bottom=True)
  1450. for axes in [ax_marg_x, ax_marg_y]:
  1451. for axis in [axes.xaxis, axes.yaxis]:
  1452. axis.label.set_visible(False)
  1453. f.tight_layout()
  1454. f.subplots_adjust(hspace=space, wspace=space)
  1455. def _inject_kwargs(self, func, kws, params):
  1456. """Add params to kws if they are accepted by func."""
  1457. func_params = signature(func).parameters
  1458. for key, val in params.items():
  1459. if key in func_params:
  1460. kws.setdefault(key, val)
  1461. def plot(self, joint_func, marginal_func, **kwargs):
  1462. """Draw the plot by passing functions for joint and marginal axes.
  1463. This method passes the ``kwargs`` dictionary to both functions. If you
  1464. need more control, call :meth:`JointGrid.plot_joint` and
  1465. :meth:`JointGrid.plot_marginals` directly with specific parameters.
  1466. Parameters
  1467. ----------
  1468. joint_func, marginal_func : callables
  1469. Functions to draw the bivariate and univariate plots. See methods
  1470. referenced above for information about the required characteristics
  1471. of these functions.
  1472. kwargs
  1473. Additional keyword arguments are passed to both functions.
  1474. Returns
  1475. -------
  1476. :class:`JointGrid` instance
  1477. Returns ``self`` for easy method chaining.
  1478. """
  1479. self.plot_marginals(marginal_func, **kwargs)
  1480. self.plot_joint(joint_func, **kwargs)
  1481. return self
  1482. def plot_joint(self, func, **kwargs):
  1483. """Draw a bivariate plot on the joint axes of the grid.
  1484. Parameters
  1485. ----------
  1486. func : plotting callable
  1487. If a seaborn function, it should accept ``x`` and ``y``. Otherwise,
  1488. it must accept ``x`` and ``y`` vectors of data as the first two
  1489. positional arguments, and it must plot on the "current" axes.
  1490. If ``hue`` was defined in the class constructor, the function must
  1491. accept ``hue`` as a parameter.
  1492. kwargs
  1493. Keyword argument are passed to the plotting function.
  1494. Returns
  1495. -------
  1496. :class:`JointGrid` instance
  1497. Returns ``self`` for easy method chaining.
  1498. """
  1499. kwargs = kwargs.copy()
  1500. if str(func.__module__).startswith("seaborn"):
  1501. kwargs["ax"] = self.ax_joint
  1502. else:
  1503. plt.sca(self.ax_joint)
  1504. if self.hue is not None:
  1505. kwargs["hue"] = self.hue
  1506. self._inject_kwargs(func, kwargs, self._hue_params)
  1507. if str(func.__module__).startswith("seaborn"):
  1508. func(x=self.x, y=self.y, **kwargs)
  1509. else:
  1510. func(self.x, self.y, **kwargs)
  1511. return self
  1512. def plot_marginals(self, func, **kwargs):
  1513. """Draw univariate plots on each marginal axes.
  1514. Parameters
  1515. ----------
  1516. func : plotting callable
  1517. If a seaborn function, it should accept ``x`` and ``y`` and plot
  1518. when only one of them is defined. Otherwise, it must accept a vector
  1519. of data as the first positional argument and determine its orientation
  1520. using the ``vertical`` parameter, and it must plot on the "current" axes.
  1521. If ``hue`` was defined in the class constructor, it must accept ``hue``
  1522. as a parameter.
  1523. kwargs
  1524. Keyword argument are passed to the plotting function.
  1525. Returns
  1526. -------
  1527. :class:`JointGrid` instance
  1528. Returns ``self`` for easy method chaining.
  1529. """
  1530. seaborn_func = (
  1531. str(func.__module__).startswith("seaborn")
  1532. # deprecated distplot has a legacy API, special case it
  1533. and not func.__name__ == "distplot"
  1534. )
  1535. func_params = signature(func).parameters
  1536. kwargs = kwargs.copy()
  1537. if self.hue is not None:
  1538. kwargs["hue"] = self.hue
  1539. self._inject_kwargs(func, kwargs, self._hue_params)
  1540. if "legend" in func_params:
  1541. kwargs.setdefault("legend", False)
  1542. if "orientation" in func_params:
  1543. # e.g. plt.hist
  1544. orient_kw_x = {"orientation": "vertical"}
  1545. orient_kw_y = {"orientation": "horizontal"}
  1546. elif "vertical" in func_params:
  1547. # e.g. sns.distplot (also how did this get backwards?)
  1548. orient_kw_x = {"vertical": False}
  1549. orient_kw_y = {"vertical": True}
  1550. if seaborn_func:
  1551. func(x=self.x, ax=self.ax_marg_x, **kwargs)
  1552. else:
  1553. plt.sca(self.ax_marg_x)
  1554. func(self.x, **orient_kw_x, **kwargs)
  1555. if seaborn_func:
  1556. func(y=self.y, ax=self.ax_marg_y, **kwargs)
  1557. else:
  1558. plt.sca(self.ax_marg_y)
  1559. func(self.y, **orient_kw_y, **kwargs)
  1560. self.ax_marg_x.yaxis.get_label().set_visible(False)
  1561. self.ax_marg_y.xaxis.get_label().set_visible(False)
  1562. return self
  1563. def refline(
  1564. self, *, x=None, y=None, joint=True, marginal=True,
  1565. color='.5', linestyle='--', **line_kws
  1566. ):
  1567. """Add a reference line(s) to joint and/or marginal axes.
  1568. Parameters
  1569. ----------
  1570. x, y : numeric
  1571. Value(s) to draw the line(s) at.
  1572. joint, marginal : bools
  1573. Whether to add the reference line(s) to the joint/marginal axes.
  1574. color : :mod:`matplotlib color <matplotlib.colors>`
  1575. Specifies the color of the reference line(s).
  1576. linestyle : str
  1577. Specifies the style of the reference line(s).
  1578. line_kws : key, value mappings
  1579. Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
  1580. when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
  1581. is not None.
  1582. Returns
  1583. -------
  1584. :class:`JointGrid` instance
  1585. Returns ``self`` for easy method chaining.
  1586. """
  1587. line_kws['color'] = color
  1588. line_kws['linestyle'] = linestyle
  1589. if x is not None:
  1590. if joint:
  1591. self.ax_joint.axvline(x, **line_kws)
  1592. if marginal:
  1593. self.ax_marg_x.axvline(x, **line_kws)
  1594. if y is not None:
  1595. if joint:
  1596. self.ax_joint.axhline(y, **line_kws)
  1597. if marginal:
  1598. self.ax_marg_y.axhline(y, **line_kws)
  1599. return self
  1600. def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
  1601. """Set axis labels on the bivariate axes.
  1602. Parameters
  1603. ----------
  1604. xlabel, ylabel : strings
  1605. Label names for the x and y variables.
  1606. kwargs : key, value mappings
  1607. Other keyword arguments are passed to the following functions:
  1608. - :meth:`matplotlib.axes.Axes.set_xlabel`
  1609. - :meth:`matplotlib.axes.Axes.set_ylabel`
  1610. Returns
  1611. -------
  1612. :class:`JointGrid` instance
  1613. Returns ``self`` for easy method chaining.
  1614. """
  1615. self.ax_joint.set_xlabel(xlabel, **kwargs)
  1616. self.ax_joint.set_ylabel(ylabel, **kwargs)
  1617. return self
  1618. JointGrid.__init__.__doc__ = """\
  1619. Set up the grid of subplots and store data internally for easy plotting.
  1620. Parameters
  1621. ----------
  1622. {params.core.data}
  1623. {params.core.xy}
  1624. height : number
  1625. Size of each side of the figure in inches (it will be square).
  1626. ratio : number
  1627. Ratio of joint axes height to marginal axes height.
  1628. space : number
  1629. Space between the joint and marginal axes
  1630. dropna : bool
  1631. If True, remove missing observations before plotting.
  1632. {{x, y}}lim : pairs of numbers
  1633. Set axis limits to these values before plotting.
  1634. marginal_ticks : bool
  1635. If False, suppress ticks on the count/density axis of the marginal plots.
  1636. {params.core.hue}
  1637. Note: unlike in :class:`FacetGrid` or :class:`PairGrid`, the axes-level
  1638. functions must support ``hue`` to use it in :class:`JointGrid`.
  1639. {params.core.palette}
  1640. {params.core.hue_order}
  1641. {params.core.hue_norm}
  1642. See Also
  1643. --------
  1644. {seealso.jointplot}
  1645. {seealso.pairgrid}
  1646. {seealso.pairplot}
  1647. Examples
  1648. --------
  1649. .. include:: ../docstrings/JointGrid.rst
  1650. """.format(
  1651. params=_param_docs,
  1652. seealso=_core_docs["seealso"],
  1653. )
  1654. def pairplot(
  1655. data, *,
  1656. hue=None, hue_order=None, palette=None,
  1657. vars=None, x_vars=None, y_vars=None,
  1658. kind="scatter", diag_kind="auto", markers=None,
  1659. height=2.5, aspect=1, corner=False, dropna=False,
  1660. plot_kws=None, diag_kws=None, grid_kws=None, size=None,
  1661. ):
  1662. """Plot pairwise relationships in a dataset.
  1663. By default, this function will create a grid of Axes such that each numeric
  1664. variable in ``data`` will by shared across the y-axes across a single row and
  1665. the x-axes across a single column. The diagonal plots are treated
  1666. differently: a univariate distribution plot is drawn to show the marginal
  1667. distribution of the data in each column.
  1668. It is also possible to show a subset of variables or plot different
  1669. variables on the rows and columns.
  1670. This is a high-level interface for :class:`PairGrid` that is intended to
  1671. make it easy to draw a few common styles. You should use :class:`PairGrid`
  1672. directly if you need more flexibility.
  1673. Parameters
  1674. ----------
  1675. data : `pandas.DataFrame`
  1676. Tidy (long-form) dataframe where each column is a variable and
  1677. each row is an observation.
  1678. hue : name of variable in ``data``
  1679. Variable in ``data`` to map plot aspects to different colors.
  1680. hue_order : list of strings
  1681. Order for the levels of the hue variable in the palette
  1682. palette : dict or seaborn color palette
  1683. Set of colors for mapping the ``hue`` variable. If a dict, keys
  1684. should be values in the ``hue`` variable.
  1685. vars : list of variable names
  1686. Variables within ``data`` to use, otherwise use every column with
  1687. a numeric datatype.
  1688. {x, y}_vars : lists of variable names
  1689. Variables within ``data`` to use separately for the rows and
  1690. columns of the figure; i.e. to make a non-square plot.
  1691. kind : {'scatter', 'kde', 'hist', 'reg'}
  1692. Kind of plot to make.
  1693. diag_kind : {'auto', 'hist', 'kde', None}
  1694. Kind of plot for the diagonal subplots. If 'auto', choose based on
  1695. whether or not ``hue`` is used.
  1696. markers : single matplotlib marker code or list
  1697. Either the marker to use for all scatterplot points or a list of markers
  1698. with a length the same as the number of levels in the hue variable so that
  1699. differently colored points will also have different scatterplot
  1700. markers.
  1701. height : scalar
  1702. Height (in inches) of each facet.
  1703. aspect : scalar
  1704. Aspect * height gives the width (in inches) of each facet.
  1705. corner : bool
  1706. If True, don't add axes to the upper (off-diagonal) triangle of the
  1707. grid, making this a "corner" plot.
  1708. dropna : boolean
  1709. Drop missing values from the data before plotting.
  1710. {plot, diag, grid}_kws : dicts
  1711. Dictionaries of keyword arguments. ``plot_kws`` are passed to the
  1712. bivariate plotting function, ``diag_kws`` are passed to the univariate
  1713. plotting function, and ``grid_kws`` are passed to the :class:`PairGrid`
  1714. constructor.
  1715. Returns
  1716. -------
  1717. grid : :class:`PairGrid`
  1718. Returns the underlying :class:`PairGrid` instance for further tweaking.
  1719. See Also
  1720. --------
  1721. PairGrid : Subplot grid for more flexible plotting of pairwise relationships.
  1722. JointGrid : Grid for plotting joint and marginal distributions of two variables.
  1723. Examples
  1724. --------
  1725. .. include:: ../docstrings/pairplot.rst
  1726. """
  1727. # Avoid circular import
  1728. from .distributions import histplot, kdeplot
  1729. # Handle deprecations
  1730. if size is not None:
  1731. height = size
  1732. msg = ("The `size` parameter has been renamed to `height`; "
  1733. "please update your code.")
  1734. warnings.warn(msg, UserWarning)
  1735. if not isinstance(data, pd.DataFrame):
  1736. raise TypeError(
  1737. f"'data' must be pandas DataFrame object, not: {type(data)}")
  1738. plot_kws = {} if plot_kws is None else plot_kws.copy()
  1739. diag_kws = {} if diag_kws is None else diag_kws.copy()
  1740. grid_kws = {} if grid_kws is None else grid_kws.copy()
  1741. # Resolve "auto" diag kind
  1742. if diag_kind == "auto":
  1743. if hue is None:
  1744. diag_kind = "kde" if kind == "kde" else "hist"
  1745. else:
  1746. diag_kind = "hist" if kind == "hist" else "kde"
  1747. # Set up the PairGrid
  1748. grid_kws.setdefault("diag_sharey", diag_kind == "hist")
  1749. grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue,
  1750. hue_order=hue_order, palette=palette, corner=corner,
  1751. height=height, aspect=aspect, dropna=dropna, **grid_kws)
  1752. # Add the markers here as PairGrid has figured out how many levels of the
  1753. # hue variable are needed and we don't want to duplicate that process
  1754. if markers is not None:
  1755. if kind == "reg":
  1756. # Needed until regplot supports style
  1757. if grid.hue_names is None:
  1758. n_markers = 1
  1759. else:
  1760. n_markers = len(grid.hue_names)
  1761. if not isinstance(markers, list):
  1762. markers = [markers] * n_markers
  1763. if len(markers) != n_markers:
  1764. raise ValueError("markers must be a singleton or a list of "
  1765. "markers for each level of the hue variable")
  1766. grid.hue_kws = {"marker": markers}
  1767. elif kind == "scatter":
  1768. if isinstance(markers, str):
  1769. plot_kws["marker"] = markers
  1770. elif hue is not None:
  1771. plot_kws["style"] = data[hue]
  1772. plot_kws["markers"] = markers
  1773. # Draw the marginal plots on the diagonal
  1774. diag_kws = diag_kws.copy()
  1775. diag_kws.setdefault("legend", False)
  1776. if diag_kind == "hist":
  1777. grid.map_diag(histplot, **diag_kws)
  1778. elif diag_kind == "kde":
  1779. diag_kws.setdefault("fill", True)
  1780. diag_kws.setdefault("warn_singular", False)
  1781. grid.map_diag(kdeplot, **diag_kws)
  1782. # Maybe plot on the off-diagonals
  1783. if diag_kind is not None:
  1784. plotter = grid.map_offdiag
  1785. else:
  1786. plotter = grid.map
  1787. if kind == "scatter":
  1788. from .relational import scatterplot # Avoid circular import
  1789. plotter(scatterplot, **plot_kws)
  1790. elif kind == "reg":
  1791. from .regression import regplot # Avoid circular import
  1792. plotter(regplot, **plot_kws)
  1793. elif kind == "kde":
  1794. from .distributions import kdeplot # Avoid circular import
  1795. plot_kws.setdefault("warn_singular", False)
  1796. plotter(kdeplot, **plot_kws)
  1797. elif kind == "hist":
  1798. from .distributions import histplot # Avoid circular import
  1799. plotter(histplot, **plot_kws)
  1800. # Add a legend
  1801. if hue is not None:
  1802. grid.add_legend()
  1803. grid.tight_layout()
  1804. return grid
  1805. def jointplot(
  1806. data=None, *, x=None, y=None, hue=None, kind="scatter",
  1807. height=6, ratio=5, space=.2, dropna=False, xlim=None, ylim=None,
  1808. color=None, palette=None, hue_order=None, hue_norm=None, marginal_ticks=False,
  1809. joint_kws=None, marginal_kws=None,
  1810. **kwargs
  1811. ):
  1812. # Avoid circular imports
  1813. from .relational import scatterplot
  1814. from .regression import regplot, residplot
  1815. from .distributions import histplot, kdeplot, _freedman_diaconis_bins
  1816. if kwargs.pop("ax", None) is not None:
  1817. msg = "Ignoring `ax`; jointplot is a figure-level function."
  1818. warnings.warn(msg, UserWarning, stacklevel=2)
  1819. # Set up empty default kwarg dicts
  1820. joint_kws = {} if joint_kws is None else joint_kws.copy()
  1821. joint_kws.update(kwargs)
  1822. marginal_kws = {} if marginal_kws is None else marginal_kws.copy()
  1823. # Handle deprecations of distplot-specific kwargs
  1824. distplot_keys = [
  1825. "rug", "fit", "hist_kws", "norm_hist" "hist_kws", "rug_kws",
  1826. ]
  1827. unused_keys = []
  1828. for key in distplot_keys:
  1829. if key in marginal_kws:
  1830. unused_keys.append(key)
  1831. marginal_kws.pop(key)
  1832. if unused_keys and kind != "kde":
  1833. msg = (
  1834. "The marginal plotting function has changed to `histplot`,"
  1835. " which does not accept the following argument(s): {}."
  1836. ).format(", ".join(unused_keys))
  1837. warnings.warn(msg, UserWarning)
  1838. # Validate the plot kind
  1839. plot_kinds = ["scatter", "hist", "hex", "kde", "reg", "resid"]
  1840. _check_argument("kind", plot_kinds, kind)
  1841. # Raise early if using `hue` with a kind that does not support it
  1842. if hue is not None and kind in ["hex", "reg", "resid"]:
  1843. msg = f"Use of `hue` with `kind='{kind}'` is not currently supported."
  1844. raise ValueError(msg)
  1845. # Make a colormap based off the plot color
  1846. # (Currently used only for kind="hex")
  1847. if color is None:
  1848. color = "C0"
  1849. color_rgb = mpl.colors.colorConverter.to_rgb(color)
  1850. colors = [set_hls_values(color_rgb, l=val) for val in np.linspace(1, 0, 12)]
  1851. cmap = blend_palette(colors, as_cmap=True)
  1852. # Matplotlib's hexbin plot is not na-robust
  1853. if kind == "hex":
  1854. dropna = True
  1855. # Initialize the JointGrid object
  1856. grid = JointGrid(
  1857. data=data, x=x, y=y, hue=hue,
  1858. palette=palette, hue_order=hue_order, hue_norm=hue_norm,
  1859. dropna=dropna, height=height, ratio=ratio, space=space,
  1860. xlim=xlim, ylim=ylim, marginal_ticks=marginal_ticks,
  1861. )
  1862. if grid.hue is not None:
  1863. marginal_kws.setdefault("legend", False)
  1864. # Plot the data using the grid
  1865. if kind.startswith("scatter"):
  1866. joint_kws.setdefault("color", color)
  1867. grid.plot_joint(scatterplot, **joint_kws)
  1868. if grid.hue is None:
  1869. marg_func = histplot
  1870. else:
  1871. marg_func = kdeplot
  1872. marginal_kws.setdefault("warn_singular", False)
  1873. marginal_kws.setdefault("fill", True)
  1874. marginal_kws.setdefault("color", color)
  1875. grid.plot_marginals(marg_func, **marginal_kws)
  1876. elif kind.startswith("hist"):
  1877. # TODO process pair parameters for bins, etc. and pass
  1878. # to both joint and marginal plots
  1879. joint_kws.setdefault("color", color)
  1880. grid.plot_joint(histplot, **joint_kws)
  1881. marginal_kws.setdefault("kde", False)
  1882. marginal_kws.setdefault("color", color)
  1883. marg_x_kws = marginal_kws.copy()
  1884. marg_y_kws = marginal_kws.copy()
  1885. pair_keys = "bins", "binwidth", "binrange"
  1886. for key in pair_keys:
  1887. if isinstance(joint_kws.get(key), tuple):
  1888. x_val, y_val = joint_kws[key]
  1889. marg_x_kws.setdefault(key, x_val)
  1890. marg_y_kws.setdefault(key, y_val)
  1891. histplot(data=data, x=x, hue=hue, **marg_x_kws, ax=grid.ax_marg_x)
  1892. histplot(data=data, y=y, hue=hue, **marg_y_kws, ax=grid.ax_marg_y)
  1893. elif kind.startswith("kde"):
  1894. joint_kws.setdefault("color", color)
  1895. joint_kws.setdefault("warn_singular", False)
  1896. grid.plot_joint(kdeplot, **joint_kws)
  1897. marginal_kws.setdefault("color", color)
  1898. if "fill" in joint_kws:
  1899. marginal_kws.setdefault("fill", joint_kws["fill"])
  1900. grid.plot_marginals(kdeplot, **marginal_kws)
  1901. elif kind.startswith("hex"):
  1902. x_bins = min(_freedman_diaconis_bins(grid.x), 50)
  1903. y_bins = min(_freedman_diaconis_bins(grid.y), 50)
  1904. gridsize = int(np.mean([x_bins, y_bins]))
  1905. joint_kws.setdefault("gridsize", gridsize)
  1906. joint_kws.setdefault("cmap", cmap)
  1907. grid.plot_joint(plt.hexbin, **joint_kws)
  1908. marginal_kws.setdefault("kde", False)
  1909. marginal_kws.setdefault("color", color)
  1910. grid.plot_marginals(histplot, **marginal_kws)
  1911. elif kind.startswith("reg"):
  1912. marginal_kws.setdefault("color", color)
  1913. marginal_kws.setdefault("kde", True)
  1914. grid.plot_marginals(histplot, **marginal_kws)
  1915. joint_kws.setdefault("color", color)
  1916. grid.plot_joint(regplot, **joint_kws)
  1917. elif kind.startswith("resid"):
  1918. joint_kws.setdefault("color", color)
  1919. grid.plot_joint(residplot, **joint_kws)
  1920. x, y = grid.ax_joint.collections[0].get_offsets().T
  1921. marginal_kws.setdefault("color", color)
  1922. histplot(x=x, hue=hue, ax=grid.ax_marg_x, **marginal_kws)
  1923. histplot(y=y, hue=hue, ax=grid.ax_marg_y, **marginal_kws)
  1924. # Make the main axes active in the matplotlib state machine
  1925. plt.sca(grid.ax_joint)
  1926. return grid
  1927. jointplot.__doc__ = """\
  1928. Draw a plot of two variables with bivariate and univariate graphs.
  1929. This function provides a convenient interface to the :class:`JointGrid`
  1930. class, with several canned plot kinds. This is intended to be a fairly
  1931. lightweight wrapper; if you need more flexibility, you should use
  1932. :class:`JointGrid` directly.
  1933. Parameters
  1934. ----------
  1935. {params.core.data}
  1936. {params.core.xy}
  1937. {params.core.hue}
  1938. kind : {{ "scatter" | "kde" | "hist" | "hex" | "reg" | "resid" }}
  1939. Kind of plot to draw. See the examples for references to the underlying functions.
  1940. height : numeric
  1941. Size of the figure (it will be square).
  1942. ratio : numeric
  1943. Ratio of joint axes height to marginal axes height.
  1944. space : numeric
  1945. Space between the joint and marginal axes
  1946. dropna : bool
  1947. If True, remove observations that are missing from ``x`` and ``y``.
  1948. {{x, y}}lim : pairs of numbers
  1949. Axis limits to set before plotting.
  1950. {params.core.color}
  1951. {params.core.palette}
  1952. {params.core.hue_order}
  1953. {params.core.hue_norm}
  1954. marginal_ticks : bool
  1955. If False, suppress ticks on the count/density axis of the marginal plots.
  1956. {{joint, marginal}}_kws : dicts
  1957. Additional keyword arguments for the plot components.
  1958. kwargs
  1959. Additional keyword arguments are passed to the function used to
  1960. draw the plot on the joint Axes, superseding items in the
  1961. ``joint_kws`` dictionary.
  1962. Returns
  1963. -------
  1964. {returns.jointgrid}
  1965. See Also
  1966. --------
  1967. {seealso.jointgrid}
  1968. {seealso.pairgrid}
  1969. {seealso.pairplot}
  1970. Examples
  1971. --------
  1972. .. include:: ../docstrings/jointplot.rst
  1973. """.format(
  1974. params=_param_docs,
  1975. returns=_core_docs["returns"],
  1976. seealso=_core_docs["seealso"],
  1977. )