relational.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960
  1. from functools import partial
  2. import warnings
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib as mpl
  6. import matplotlib.pyplot as plt
  7. from ._base import (
  8. VectorPlotter,
  9. )
  10. from .utils import (
  11. adjust_legend_subtitles,
  12. _default_color,
  13. _deprecate_ci,
  14. _get_transform_functions,
  15. _normalize_kwargs,
  16. _scatter_legend_artist,
  17. )
  18. from ._statistics import EstimateAggregator
  19. from .axisgrid import FacetGrid, _facet_docs
  20. from ._docstrings import DocstringComponents, _core_docs
  21. __all__ = ["relplot", "scatterplot", "lineplot"]
  22. _relational_narrative = DocstringComponents(dict(
  23. # --- Introductory prose
  24. main_api="""
  25. The relationship between `x` and `y` can be shown for different subsets
  26. of the data using the `hue`, `size`, and `style` parameters. These
  27. parameters control what visual semantics are used to identify the different
  28. subsets. It is possible to show up to three dimensions independently by
  29. using all three semantic types, but this style of plot can be hard to
  30. interpret and is often ineffective. Using redundant semantics (i.e. both
  31. `hue` and `style` for the same variable) can be helpful for making
  32. graphics more accessible.
  33. See the :ref:`tutorial <relational_tutorial>` for more information.
  34. """,
  35. relational_semantic="""
  36. The default treatment of the `hue` (and to a lesser extent, `size`)
  37. semantic, if present, depends on whether the variable is inferred to
  38. represent "numeric" or "categorical" data. In particular, numeric variables
  39. are represented with a sequential colormap by default, and the legend
  40. entries show regular "ticks" with values that may or may not exist in the
  41. data. This behavior can be controlled through various parameters, as
  42. described and illustrated below.
  43. """,
  44. ))
  45. _relational_docs = dict(
  46. # --- Shared function parameters
  47. data_vars="""
  48. x, y : names of variables in `data` or vector data
  49. Input data variables; must be numeric. Can pass data directly or
  50. reference columns in `data`.
  51. """,
  52. data="""
  53. data : DataFrame, array, or list of arrays
  54. Input data structure. If `x` and `y` are specified as names, this
  55. should be a "long-form" DataFrame containing those columns. Otherwise
  56. it is treated as "wide-form" data and grouping variables are ignored.
  57. See the examples for the various ways this parameter can be specified
  58. and the different effects of each.
  59. """,
  60. palette="""
  61. palette : string, list, dict, or matplotlib colormap
  62. An object that determines how colors are chosen when `hue` is used.
  63. It can be the name of a seaborn palette or matplotlib colormap, a list
  64. of colors (anything matplotlib understands), a dict mapping levels
  65. of the `hue` variable to colors, or a matplotlib colormap object.
  66. """,
  67. hue_order="""
  68. hue_order : list
  69. Specified order for the appearance of the `hue` variable levels,
  70. otherwise they are determined from the data. Not relevant when the
  71. `hue` variable is numeric.
  72. """,
  73. hue_norm="""
  74. hue_norm : tuple or :class:`matplotlib.colors.Normalize` object
  75. Normalization in data units for colormap applied to the `hue`
  76. variable when it is numeric. Not relevant if `hue` is categorical.
  77. """,
  78. sizes="""
  79. sizes : list, dict, or tuple
  80. An object that determines how sizes are chosen when `size` is used.
  81. List or dict arguments should provide a size for each unique data value,
  82. which forces a categorical interpretation. The argument may also be a
  83. min, max tuple.
  84. """,
  85. size_order="""
  86. size_order : list
  87. Specified order for appearance of the `size` variable levels,
  88. otherwise they are determined from the data. Not relevant when the
  89. `size` variable is numeric.
  90. """,
  91. size_norm="""
  92. size_norm : tuple or Normalize object
  93. Normalization in data units for scaling plot objects when the
  94. `size` variable is numeric.
  95. """,
  96. dashes="""
  97. dashes : boolean, list, or dictionary
  98. Object determining how to draw the lines for different levels of the
  99. `style` variable. Setting to `True` will use default dash codes, or
  100. you can pass a list of dash codes or a dictionary mapping levels of the
  101. `style` variable to dash codes. Setting to `False` will use solid
  102. lines for all subsets. Dashes are specified as in matplotlib: a tuple
  103. of `(segment, gap)` lengths, or an empty string to draw a solid line.
  104. """,
  105. markers="""
  106. markers : boolean, list, or dictionary
  107. Object determining how to draw the markers for different levels of the
  108. `style` variable. Setting to `True` will use default markers, or
  109. you can pass a list of markers or a dictionary mapping levels of the
  110. `style` variable to markers. Setting to `False` will draw
  111. marker-less lines. Markers are specified as in matplotlib.
  112. """,
  113. style_order="""
  114. style_order : list
  115. Specified order for appearance of the `style` variable levels
  116. otherwise they are determined from the data. Not relevant when the
  117. `style` variable is numeric.
  118. """,
  119. units="""
  120. units : vector or key in `data`
  121. Grouping variable identifying sampling units. When used, a separate
  122. line will be drawn for each unit with appropriate semantics, but no
  123. legend entry will be added. Useful for showing distribution of
  124. experimental replicates when exact identities are not needed.
  125. """,
  126. estimator="""
  127. estimator : name of pandas method or callable or None
  128. Method for aggregating across multiple observations of the `y`
  129. variable at the same `x` level. If `None`, all observations will
  130. be drawn.
  131. """,
  132. ci="""
  133. ci : int or "sd" or None
  134. Size of the confidence interval to draw when aggregating.
  135. .. deprecated:: 0.12.0
  136. Use the new `errorbar` parameter for more flexibility.
  137. """,
  138. n_boot="""
  139. n_boot : int
  140. Number of bootstraps to use for computing the confidence interval.
  141. """,
  142. seed="""
  143. seed : int, numpy.random.Generator, or numpy.random.RandomState
  144. Seed or random number generator for reproducible bootstrapping.
  145. """,
  146. legend="""
  147. legend : "auto", "brief", "full", or False
  148. How to draw the legend. If "brief", numeric `hue` and `size`
  149. variables will be represented with a sample of evenly spaced values.
  150. If "full", every group will get an entry in the legend. If "auto",
  151. choose between brief or full representation based on number of levels.
  152. If `False`, no legend data is added and no legend is drawn.
  153. """,
  154. ax_in="""
  155. ax : matplotlib Axes
  156. Axes object to draw the plot onto, otherwise uses the current Axes.
  157. """,
  158. ax_out="""
  159. ax : matplotlib Axes
  160. Returns the Axes object with the plot drawn onto it.
  161. """,
  162. )
  163. _param_docs = DocstringComponents.from_nested_components(
  164. core=_core_docs["params"],
  165. facets=DocstringComponents(_facet_docs),
  166. rel=DocstringComponents(_relational_docs),
  167. stat=DocstringComponents.from_function_params(EstimateAggregator.__init__),
  168. )
  169. class _RelationalPlotter(VectorPlotter):
  170. wide_structure = {
  171. "x": "@index", "y": "@values", "hue": "@columns", "style": "@columns",
  172. }
  173. # TODO where best to define default parameters?
  174. sort = True
  175. class _LinePlotter(_RelationalPlotter):
  176. _legend_attributes = ["color", "linewidth", "marker", "dashes"]
  177. def __init__(
  178. self, *,
  179. data=None, variables={},
  180. estimator=None, n_boot=None, seed=None, errorbar=None,
  181. sort=True, orient="x", err_style=None, err_kws=None, legend=None
  182. ):
  183. # TODO this is messy, we want the mapping to be agnostic about
  184. # the kind of plot to draw, but for the time being we need to set
  185. # this information so the SizeMapping can use it
  186. self._default_size_range = (
  187. np.r_[.5, 2] * mpl.rcParams["lines.linewidth"]
  188. )
  189. super().__init__(data=data, variables=variables)
  190. self.estimator = estimator
  191. self.errorbar = errorbar
  192. self.n_boot = n_boot
  193. self.seed = seed
  194. self.sort = sort
  195. self.orient = orient
  196. self.err_style = err_style
  197. self.err_kws = {} if err_kws is None else err_kws
  198. self.legend = legend
  199. def plot(self, ax, kws):
  200. """Draw the plot onto an axes, passing matplotlib kwargs."""
  201. # Draw a test plot, using the passed in kwargs. The goal here is to
  202. # honor both (a) the current state of the plot cycler and (b) the
  203. # specified kwargs on all the lines we will draw, overriding when
  204. # relevant with the data semantics. Note that we won't cycle
  205. # internally; in other words, if `hue` is not used, all elements will
  206. # have the same color, but they will have the color that you would have
  207. # gotten from the corresponding matplotlib function, and calling the
  208. # function will advance the axes property cycle.
  209. kws = _normalize_kwargs(kws, mpl.lines.Line2D)
  210. kws.setdefault("markeredgewidth", 0.75)
  211. kws.setdefault("markeredgecolor", "w")
  212. # Set default error kwargs
  213. err_kws = self.err_kws.copy()
  214. if self.err_style == "band":
  215. err_kws.setdefault("alpha", .2)
  216. elif self.err_style == "bars":
  217. pass
  218. elif self.err_style is not None:
  219. err = "`err_style` must be 'band' or 'bars', not {}"
  220. raise ValueError(err.format(self.err_style))
  221. # Initialize the aggregation object
  222. agg = EstimateAggregator(
  223. self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
  224. )
  225. # TODO abstract variable to aggregate over here-ish. Better name?
  226. orient = self.orient
  227. if orient not in {"x", "y"}:
  228. err = f"`orient` must be either 'x' or 'y', not {orient!r}."
  229. raise ValueError(err)
  230. other = {"x": "y", "y": "x"}[orient]
  231. # TODO How to handle NA? We don't want NA to propagate through to the
  232. # estimate/CI when some values are present, but we would also like
  233. # matplotlib to show "gaps" in the line when all values are missing.
  234. # This is straightforward absent aggregation, but complicated with it.
  235. # If we want to use nas, we need to conditionalize dropna in iter_data.
  236. # Loop over the semantic subsets and add to the plot
  237. grouping_vars = "hue", "size", "style"
  238. for sub_vars, sub_data in self.iter_data(grouping_vars, from_comp_data=True):
  239. if self.sort:
  240. sort_vars = ["units", orient, other]
  241. sort_cols = [var for var in sort_vars if var in self.variables]
  242. sub_data = sub_data.sort_values(sort_cols)
  243. if (
  244. self.estimator is not None
  245. and sub_data[orient].value_counts().max() > 1
  246. ):
  247. if "units" in self.variables:
  248. # TODO eventually relax this constraint
  249. err = "estimator must be None when specifying units"
  250. raise ValueError(err)
  251. grouped = sub_data.groupby(orient, sort=self.sort)
  252. # Could pass as_index=False instead of reset_index,
  253. # but that fails on a corner case with older pandas.
  254. sub_data = grouped.apply(agg, other).reset_index()
  255. else:
  256. sub_data[f"{other}min"] = np.nan
  257. sub_data[f"{other}max"] = np.nan
  258. # Apply inverse axis scaling
  259. for var in "xy":
  260. _, inv = _get_transform_functions(ax, var)
  261. for col in sub_data.filter(regex=f"^{var}"):
  262. sub_data[col] = inv(sub_data[col])
  263. # --- Draw the main line(s)
  264. if "units" in self.variables: # XXX why not add to grouping variables?
  265. lines = []
  266. for _, unit_data in sub_data.groupby("units"):
  267. lines.extend(ax.plot(unit_data["x"], unit_data["y"], **kws))
  268. else:
  269. lines = ax.plot(sub_data["x"], sub_data["y"], **kws)
  270. for line in lines:
  271. if "hue" in sub_vars:
  272. line.set_color(self._hue_map(sub_vars["hue"]))
  273. if "size" in sub_vars:
  274. line.set_linewidth(self._size_map(sub_vars["size"]))
  275. if "style" in sub_vars:
  276. attributes = self._style_map(sub_vars["style"])
  277. if "dashes" in attributes:
  278. line.set_dashes(attributes["dashes"])
  279. if "marker" in attributes:
  280. line.set_marker(attributes["marker"])
  281. line_color = line.get_color()
  282. line_alpha = line.get_alpha()
  283. line_capstyle = line.get_solid_capstyle()
  284. # --- Draw the confidence intervals
  285. if self.estimator is not None and self.errorbar is not None:
  286. # TODO handling of orientation will need to happen here
  287. if self.err_style == "band":
  288. func = {"x": ax.fill_between, "y": ax.fill_betweenx}[orient]
  289. func(
  290. sub_data[orient],
  291. sub_data[f"{other}min"], sub_data[f"{other}max"],
  292. color=line_color, **err_kws
  293. )
  294. elif self.err_style == "bars":
  295. error_param = {
  296. f"{other}err": (
  297. sub_data[other] - sub_data[f"{other}min"],
  298. sub_data[f"{other}max"] - sub_data[other],
  299. )
  300. }
  301. ebars = ax.errorbar(
  302. sub_data["x"], sub_data["y"], **error_param,
  303. linestyle="", color=line_color, alpha=line_alpha,
  304. **err_kws
  305. )
  306. # Set the capstyle properly on the error bars
  307. for obj in ebars.get_children():
  308. if isinstance(obj, mpl.collections.LineCollection):
  309. obj.set_capstyle(line_capstyle)
  310. # Finalize the axes details
  311. self._add_axis_labels(ax)
  312. if self.legend:
  313. legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
  314. attrs = {"hue": "color", "size": "linewidth", "style": None}
  315. self.add_legend_data(ax, legend_artist, kws, attrs)
  316. handles, _ = ax.get_legend_handles_labels()
  317. if handles:
  318. legend = ax.legend(title=self.legend_title)
  319. adjust_legend_subtitles(legend)
  320. class _ScatterPlotter(_RelationalPlotter):
  321. _legend_attributes = ["color", "s", "marker"]
  322. def __init__(self, *, data=None, variables={}, legend=None):
  323. # TODO this is messy, we want the mapping to be agnostic about
  324. # the kind of plot to draw, but for the time being we need to set
  325. # this information so the SizeMapping can use it
  326. self._default_size_range = (
  327. np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"])
  328. )
  329. super().__init__(data=data, variables=variables)
  330. self.legend = legend
  331. def plot(self, ax, kws):
  332. # --- Determine the visual attributes of the plot
  333. data = self.comp_data.dropna()
  334. if data.empty:
  335. return
  336. kws = _normalize_kwargs(kws, mpl.collections.PathCollection)
  337. # Define the vectors of x and y positions
  338. empty = np.full(len(data), np.nan)
  339. x = data.get("x", empty)
  340. y = data.get("y", empty)
  341. # Apply inverse scaling to the coordinate variables
  342. _, inv_x = _get_transform_functions(ax, "x")
  343. _, inv_y = _get_transform_functions(ax, "y")
  344. x, y = inv_x(x), inv_y(y)
  345. if "style" in self.variables:
  346. # Use a representative marker so scatter sets the edgecolor
  347. # properly for line art markers. We currently enforce either
  348. # all or none line art so this works.
  349. example_level = self._style_map.levels[0]
  350. example_marker = self._style_map(example_level, "marker")
  351. kws.setdefault("marker", example_marker)
  352. # Conditionally set the marker edgecolor based on whether the marker is "filled"
  353. # See https://github.com/matplotlib/matplotlib/issues/17849 for context
  354. m = kws.get("marker", mpl.rcParams.get("marker", "o"))
  355. if not isinstance(m, mpl.markers.MarkerStyle):
  356. # TODO in more recent matplotlib (which?) can pass a MarkerStyle here
  357. m = mpl.markers.MarkerStyle(m)
  358. if m.is_filled():
  359. kws.setdefault("edgecolor", "w")
  360. # Draw the scatter plot
  361. points = ax.scatter(x=x, y=y, **kws)
  362. # Apply the mapping from semantic variables to artist attributes
  363. if "hue" in self.variables:
  364. points.set_facecolors(self._hue_map(data["hue"]))
  365. if "size" in self.variables:
  366. points.set_sizes(self._size_map(data["size"]))
  367. if "style" in self.variables:
  368. p = [self._style_map(val, "path") for val in data["style"]]
  369. points.set_paths(p)
  370. # Apply dependent default attributes
  371. if "linewidth" not in kws:
  372. sizes = points.get_sizes()
  373. linewidth = .08 * np.sqrt(np.percentile(sizes, 10))
  374. points.set_linewidths(linewidth)
  375. kws["linewidth"] = linewidth
  376. # Finalize the axes details
  377. self._add_axis_labels(ax)
  378. if self.legend:
  379. attrs = {"hue": "color", "size": "s", "style": None}
  380. self.add_legend_data(ax, _scatter_legend_artist, kws, attrs)
  381. handles, _ = ax.get_legend_handles_labels()
  382. if handles:
  383. legend = ax.legend(title=self.legend_title)
  384. adjust_legend_subtitles(legend)
  385. def lineplot(
  386. data=None, *,
  387. x=None, y=None, hue=None, size=None, style=None, units=None,
  388. palette=None, hue_order=None, hue_norm=None,
  389. sizes=None, size_order=None, size_norm=None,
  390. dashes=True, markers=None, style_order=None,
  391. estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None,
  392. orient="x", sort=True, err_style="band", err_kws=None,
  393. legend="auto", ci="deprecated", ax=None, **kwargs
  394. ):
  395. # Handle deprecation of ci parameter
  396. errorbar = _deprecate_ci(errorbar, ci)
  397. p = _LinePlotter(
  398. data=data,
  399. variables=dict(x=x, y=y, hue=hue, size=size, style=style, units=units),
  400. estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
  401. sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
  402. legend=legend,
  403. )
  404. p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
  405. p.map_size(sizes=sizes, order=size_order, norm=size_norm)
  406. p.map_style(markers=markers, dashes=dashes, order=style_order)
  407. if ax is None:
  408. ax = plt.gca()
  409. if "style" not in p.variables and not {"ls", "linestyle"} & set(kwargs): # XXX
  410. kwargs["dashes"] = "" if dashes is None or isinstance(dashes, bool) else dashes
  411. if not p.has_xy_data:
  412. return ax
  413. p._attach(ax)
  414. # Other functions have color as an explicit param,
  415. # and we should probably do that here too
  416. color = kwargs.pop("color", kwargs.pop("c", None))
  417. kwargs["color"] = _default_color(ax.plot, hue, color, kwargs)
  418. p.plot(ax, kwargs)
  419. return ax
  420. lineplot.__doc__ = """\
  421. Draw a line plot with possibility of several semantic groupings.
  422. {narrative.main_api}
  423. {narrative.relational_semantic}
  424. By default, the plot aggregates over multiple `y` values at each value of
  425. `x` and shows an estimate of the central tendency and a confidence
  426. interval for that estimate.
  427. Parameters
  428. ----------
  429. {params.core.data}
  430. {params.core.xy}
  431. hue : vector or key in `data`
  432. Grouping variable that will produce lines with different colors.
  433. Can be either categorical or numeric, although color mapping will
  434. behave differently in latter case.
  435. size : vector or key in `data`
  436. Grouping variable that will produce lines with different widths.
  437. Can be either categorical or numeric, although size mapping will
  438. behave differently in latter case.
  439. style : vector or key in `data`
  440. Grouping variable that will produce lines with different dashes
  441. and/or markers. Can have a numeric dtype but will always be treated
  442. as categorical.
  443. {params.rel.units}
  444. {params.core.palette}
  445. {params.core.hue_order}
  446. {params.core.hue_norm}
  447. {params.rel.sizes}
  448. {params.rel.size_order}
  449. {params.rel.size_norm}
  450. {params.rel.dashes}
  451. {params.rel.markers}
  452. {params.rel.style_order}
  453. {params.rel.estimator}
  454. {params.stat.errorbar}
  455. {params.rel.n_boot}
  456. {params.rel.seed}
  457. orient : "x" or "y"
  458. Dimension along which the data are sorted / aggregated. Equivalently,
  459. the "independent variable" of the resulting function.
  460. sort : boolean
  461. If True, the data will be sorted by the x and y variables, otherwise
  462. lines will connect points in the order they appear in the dataset.
  463. err_style : "band" or "bars"
  464. Whether to draw the confidence intervals with translucent error bands
  465. or discrete error bars.
  466. err_kws : dict of keyword arguments
  467. Additional parameters to control the aesthetics of the error bars. The
  468. kwargs are passed either to :meth:`matplotlib.axes.Axes.fill_between`
  469. or :meth:`matplotlib.axes.Axes.errorbar`, depending on `err_style`.
  470. {params.rel.legend}
  471. {params.rel.ci}
  472. {params.core.ax}
  473. kwargs : key, value mappings
  474. Other keyword arguments are passed down to
  475. :meth:`matplotlib.axes.Axes.plot`.
  476. Returns
  477. -------
  478. {returns.ax}
  479. See Also
  480. --------
  481. {seealso.scatterplot}
  482. {seealso.pointplot}
  483. Examples
  484. --------
  485. .. include:: ../docstrings/lineplot.rst
  486. """.format(
  487. narrative=_relational_narrative,
  488. params=_param_docs,
  489. returns=_core_docs["returns"],
  490. seealso=_core_docs["seealso"],
  491. )
  492. def scatterplot(
  493. data=None, *,
  494. x=None, y=None, hue=None, size=None, style=None,
  495. palette=None, hue_order=None, hue_norm=None,
  496. sizes=None, size_order=None, size_norm=None,
  497. markers=True, style_order=None, legend="auto", ax=None,
  498. **kwargs
  499. ):
  500. p = _ScatterPlotter(
  501. data=data,
  502. variables=dict(x=x, y=y, hue=hue, size=size, style=style),
  503. legend=legend
  504. )
  505. p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
  506. p.map_size(sizes=sizes, order=size_order, norm=size_norm)
  507. p.map_style(markers=markers, order=style_order)
  508. if ax is None:
  509. ax = plt.gca()
  510. if not p.has_xy_data:
  511. return ax
  512. p._attach(ax)
  513. color = kwargs.pop("color", None)
  514. kwargs["color"] = _default_color(ax.scatter, hue, color, kwargs)
  515. p.plot(ax, kwargs)
  516. return ax
  517. scatterplot.__doc__ = """\
  518. Draw a scatter plot with possibility of several semantic groupings.
  519. {narrative.main_api}
  520. {narrative.relational_semantic}
  521. Parameters
  522. ----------
  523. {params.core.data}
  524. {params.core.xy}
  525. hue : vector or key in `data`
  526. Grouping variable that will produce points with different colors.
  527. Can be either categorical or numeric, although color mapping will
  528. behave differently in latter case.
  529. size : vector or key in `data`
  530. Grouping variable that will produce points with different sizes.
  531. Can be either categorical or numeric, although size mapping will
  532. behave differently in latter case.
  533. style : vector or key in `data`
  534. Grouping variable that will produce points with different markers.
  535. Can have a numeric dtype but will always be treated as categorical.
  536. {params.core.palette}
  537. {params.core.hue_order}
  538. {params.core.hue_norm}
  539. {params.rel.sizes}
  540. {params.rel.size_order}
  541. {params.rel.size_norm}
  542. {params.rel.markers}
  543. {params.rel.style_order}
  544. {params.rel.legend}
  545. {params.core.ax}
  546. kwargs : key, value mappings
  547. Other keyword arguments are passed down to
  548. :meth:`matplotlib.axes.Axes.scatter`.
  549. Returns
  550. -------
  551. {returns.ax}
  552. See Also
  553. --------
  554. {seealso.lineplot}
  555. {seealso.stripplot}
  556. {seealso.swarmplot}
  557. Examples
  558. --------
  559. .. include:: ../docstrings/scatterplot.rst
  560. """.format(
  561. narrative=_relational_narrative,
  562. params=_param_docs,
  563. returns=_core_docs["returns"],
  564. seealso=_core_docs["seealso"],
  565. )
  566. def relplot(
  567. data=None, *,
  568. x=None, y=None, hue=None, size=None, style=None, units=None,
  569. row=None, col=None, col_wrap=None, row_order=None, col_order=None,
  570. palette=None, hue_order=None, hue_norm=None,
  571. sizes=None, size_order=None, size_norm=None,
  572. markers=None, dashes=None, style_order=None,
  573. legend="auto", kind="scatter", height=5, aspect=1, facet_kws=None,
  574. **kwargs
  575. ):
  576. if kind == "scatter":
  577. Plotter = _ScatterPlotter
  578. func = scatterplot
  579. markers = True if markers is None else markers
  580. elif kind == "line":
  581. Plotter = _LinePlotter
  582. func = lineplot
  583. dashes = True if dashes is None else dashes
  584. else:
  585. err = f"Plot kind {kind} not recognized"
  586. raise ValueError(err)
  587. # Check for attempt to plot onto specific axes and warn
  588. if "ax" in kwargs:
  589. msg = (
  590. "relplot is a figure-level function and does not accept "
  591. "the `ax` parameter. You may wish to try {}".format(kind + "plot")
  592. )
  593. warnings.warn(msg, UserWarning)
  594. kwargs.pop("ax")
  595. # Use the full dataset to map the semantics
  596. variables = dict(x=x, y=y, hue=hue, size=size, style=style)
  597. if kind == "line":
  598. variables["units"] = units
  599. elif units is not None:
  600. msg = "The `units` parameter of `relplot` has no effect with kind='scatter'"
  601. warnings.warn(msg, stacklevel=2)
  602. p = Plotter(
  603. data=data,
  604. variables=variables,
  605. legend=legend,
  606. )
  607. p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
  608. p.map_size(sizes=sizes, order=size_order, norm=size_norm)
  609. p.map_style(markers=markers, dashes=dashes, order=style_order)
  610. # Extract the semantic mappings
  611. if "hue" in p.variables:
  612. palette = p._hue_map.lookup_table
  613. hue_order = p._hue_map.levels
  614. hue_norm = p._hue_map.norm
  615. else:
  616. palette = hue_order = hue_norm = None
  617. if "size" in p.variables:
  618. sizes = p._size_map.lookup_table
  619. size_order = p._size_map.levels
  620. size_norm = p._size_map.norm
  621. if "style" in p.variables:
  622. style_order = p._style_map.levels
  623. if markers:
  624. markers = {k: p._style_map(k, "marker") for k in style_order}
  625. else:
  626. markers = None
  627. if dashes:
  628. dashes = {k: p._style_map(k, "dashes") for k in style_order}
  629. else:
  630. dashes = None
  631. else:
  632. markers = dashes = style_order = None
  633. # Now extract the data that would be used to draw a single plot
  634. variables = p.variables
  635. plot_data = p.plot_data
  636. # Define the common plotting parameters
  637. plot_kws = dict(
  638. palette=palette, hue_order=hue_order, hue_norm=hue_norm,
  639. sizes=sizes, size_order=size_order, size_norm=size_norm,
  640. markers=markers, dashes=dashes, style_order=style_order,
  641. legend=False,
  642. )
  643. plot_kws.update(kwargs)
  644. if kind == "scatter":
  645. plot_kws.pop("dashes")
  646. # Add the grid semantics onto the plotter
  647. grid_variables = dict(
  648. x=x, y=y, row=row, col=col,
  649. hue=hue, size=size, style=style,
  650. )
  651. if kind == "line":
  652. grid_variables["units"] = units
  653. p.assign_variables(data, grid_variables)
  654. # Define the named variables for plotting on each facet
  655. # Rename the variables with a leading underscore to avoid
  656. # collisions with faceting variable names
  657. plot_variables = {v: f"_{v}" for v in variables}
  658. plot_kws.update(plot_variables)
  659. # Pass the row/col variables to FacetGrid with their original
  660. # names so that the axes titles render correctly
  661. for var in ["row", "col"]:
  662. # Handle faceting variables that lack name information
  663. if var in p.variables and p.variables[var] is None:
  664. p.variables[var] = f"_{var}_"
  665. grid_kws = {v: p.variables.get(v) for v in ["row", "col"]}
  666. # Rename the columns of the plot_data structure appropriately
  667. new_cols = plot_variables.copy()
  668. new_cols.update(grid_kws)
  669. full_data = p.plot_data.rename(columns=new_cols)
  670. # Set up the FacetGrid object
  671. facet_kws = {} if facet_kws is None else facet_kws.copy()
  672. g = FacetGrid(
  673. data=full_data.dropna(axis=1, how="all"),
  674. **grid_kws,
  675. col_wrap=col_wrap, row_order=row_order, col_order=col_order,
  676. height=height, aspect=aspect, dropna=False,
  677. **facet_kws
  678. )
  679. # Draw the plot
  680. g.map_dataframe(func, **plot_kws)
  681. # Label the axes, using the original variables
  682. # Pass "" when the variable name is None to overwrite internal variables
  683. g.set_axis_labels(variables.get("x") or "", variables.get("y") or "")
  684. if legend:
  685. # Replace the original plot data so the legend uses numeric data with
  686. # the correct type, since we force a categorical mapping above.
  687. p.plot_data = plot_data
  688. # Handle the additional non-semantic keyword arguments out here.
  689. # We're selective because some kwargs may be seaborn function specific
  690. # and not relevant to the matplotlib artists going into the legend.
  691. # Ideally, we will have a better solution where we don't need to re-make
  692. # the legend out here and will have parity with the axes-level functions.
  693. keys = ["c", "color", "alpha", "m", "marker"]
  694. if kind == "scatter":
  695. legend_artist = _scatter_legend_artist
  696. keys += ["s", "facecolor", "fc", "edgecolor", "ec", "linewidth", "lw"]
  697. else:
  698. legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
  699. keys += [
  700. "markersize", "ms",
  701. "markeredgewidth", "mew",
  702. "markeredgecolor", "mec",
  703. "linestyle", "ls",
  704. "linewidth", "lw",
  705. ]
  706. common_kws = {k: v for k, v in kwargs.items() if k in keys}
  707. attrs = {"hue": "color", "style": None}
  708. if kind == "scatter":
  709. attrs["size"] = "s"
  710. elif kind == "line":
  711. attrs["size"] = "linewidth"
  712. p.add_legend_data(g.axes.flat[0], legend_artist, common_kws, attrs)
  713. if p.legend_data:
  714. g.add_legend(legend_data=p.legend_data,
  715. label_order=p.legend_order,
  716. title=p.legend_title,
  717. adjust_subtitles=True)
  718. # Rename the columns of the FacetGrid's `data` attribute
  719. # to match the original column names
  720. orig_cols = {
  721. f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items()
  722. }
  723. grid_data = g.data.rename(columns=orig_cols)
  724. if data is not None and (x is not None or y is not None):
  725. if not isinstance(data, pd.DataFrame):
  726. data = pd.DataFrame(data)
  727. g.data = pd.merge(
  728. data,
  729. grid_data[grid_data.columns.difference(data.columns)],
  730. left_index=True,
  731. right_index=True,
  732. )
  733. else:
  734. g.data = grid_data
  735. return g
  736. relplot.__doc__ = """\
  737. Figure-level interface for drawing relational plots onto a FacetGrid.
  738. This function provides access to several different axes-level functions
  739. that show the relationship between two variables with semantic mappings
  740. of subsets. The `kind` parameter selects the underlying axes-level
  741. function to use:
  742. - :func:`scatterplot` (with `kind="scatter"`; the default)
  743. - :func:`lineplot` (with `kind="line"`)
  744. Extra keyword arguments are passed to the underlying function, so you
  745. should refer to the documentation for each to see kind-specific options.
  746. {narrative.main_api}
  747. {narrative.relational_semantic}
  748. After plotting, the :class:`FacetGrid` with the plot is returned and can
  749. be used directly to tweak supporting plot details or add other layers.
  750. Parameters
  751. ----------
  752. {params.core.data}
  753. {params.core.xy}
  754. hue : vector or key in `data`
  755. Grouping variable that will produce elements with different colors.
  756. Can be either categorical or numeric, although color mapping will
  757. behave differently in latter case.
  758. size : vector or key in `data`
  759. Grouping variable that will produce elements with different sizes.
  760. Can be either categorical or numeric, although size mapping will
  761. behave differently in latter case.
  762. style : vector or key in `data`
  763. Grouping variable that will produce elements with different styles.
  764. Can have a numeric dtype but will always be treated as categorical.
  765. {params.rel.units}
  766. {params.facets.rowcol}
  767. {params.facets.col_wrap}
  768. row_order, col_order : lists of strings
  769. Order to organize the rows and/or columns of the grid in, otherwise the
  770. orders are inferred from the data objects.
  771. {params.core.palette}
  772. {params.core.hue_order}
  773. {params.core.hue_norm}
  774. {params.rel.sizes}
  775. {params.rel.size_order}
  776. {params.rel.size_norm}
  777. {params.rel.style_order}
  778. {params.rel.dashes}
  779. {params.rel.markers}
  780. {params.rel.legend}
  781. kind : string
  782. Kind of plot to draw, corresponding to a seaborn relational plot.
  783. Options are `"scatter"` or `"line"`.
  784. {params.facets.height}
  785. {params.facets.aspect}
  786. facet_kws : dict
  787. Dictionary of other keyword arguments to pass to :class:`FacetGrid`.
  788. kwargs : key, value pairings
  789. Other keyword arguments are passed through to the underlying plotting
  790. function.
  791. Returns
  792. -------
  793. {returns.facetgrid}
  794. Examples
  795. --------
  796. .. include:: ../docstrings/relplot.rst
  797. """.format(
  798. narrative=_relational_narrative,
  799. params=_param_docs,
  800. returns=_core_docs["returns"],
  801. )