- """Plotting functions for linear models (broadly construed)."""
- import copy
- from textwrap import dedent
- import warnings
- import numpy as np
- import pandas as pd
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- try:
- import statsmodels
- assert statsmodels  # reference the module so linters don't flag the unused import
- _has_statsmodels = True
- except ImportError:
- _has_statsmodels = False
- from . import utils
- from . import algorithms as algo
- from .axisgrid import FacetGrid, _facet_docs
- __all__ = ["lmplot", "regplot", "residplot"]
- class _LinearPlotter:
- """Base class for plotting relational data in tidy format.
- To get anything useful done you'll have to inherit from this, but setup
- code that can be abstracted out should be put here.
- """
- def establish_variables(self, data, **kws):
- """Extract variables from data or use directly."""
- self.data = data
- # Validate the inputs
- any_strings = any(isinstance(v, str) for v in kws.values())
- if any_strings and data is None:
- raise ValueError("Must pass `data` if using named variables.")
- # Set the variables
- for var, val in kws.items():
- if isinstance(val, str):
- vector = data[val]
- elif isinstance(val, list):
- vector = np.asarray(val)
- else:
- vector = val
- if vector is not None and vector.shape != (1,):
- vector = np.squeeze(vector)
- if np.ndim(vector) > 1:
- err = "regplot inputs must be 1d"
- raise ValueError(err)
- setattr(self, var, vector)
- def dropna(self, *vars):
- """Remove observations with missing data."""
- vals = [getattr(self, var) for var in vars]
- vals = [v for v in vals if v is not None]
- not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
- for var in vars:
- val = getattr(self, var)
- if val is not None:
- setattr(self, var, val[not_na])
- def plot(self, ax):
- raise NotImplementedError
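- # A minimal, hypothetical sketch of how the `establish_variables` / `dropna`
- # machinery is meant to be used by subclasses; `_DemoPlotter` and this helper
- # are illustrative only, not part of the library.
- def _demo_linear_plotter():
-     class _DemoPlotter(_LinearPlotter):
-         def plot(self, ax):
-             ax.scatter(self.x, self.y)
-     df = pd.DataFrame({"a": [1.0, 2.0, np.nan], "b": [4.0, 5.0, 6.0]})
-     p = _DemoPlotter()
-     p.establish_variables(df, x="a", y="b")  # strings resolve to columns of df
-     p.dropna("x", "y")  # removes the row where "a" is missing
-     return p.x, p.y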
- class _RegressionPlotter(_LinearPlotter):
- """Plotter for numeric independent variables with regression model.
- This does the computations and drawing for the `regplot` function, and
- is thus also used indirectly by `lmplot`.
- """
- def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
- x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
- units=None, seed=None, order=1, logistic=False, lowess=False,
- robust=False, logx=False, x_partial=None, y_partial=None,
- truncate=False, dropna=True, x_jitter=None, y_jitter=None,
- color=None, label=None):
- # Set member attributes
- self.x_estimator = x_estimator
- self.ci = ci
- self.x_ci = ci if x_ci == "ci" else x_ci
- self.n_boot = n_boot
- self.seed = seed
- self.scatter = scatter
- self.fit_reg = fit_reg
- self.order = order
- self.logistic = logistic
- self.lowess = lowess
- self.robust = robust
- self.logx = logx
- self.truncate = truncate
- self.x_jitter = x_jitter
- self.y_jitter = y_jitter
- self.color = color
- self.label = label
- # Validate the regression options:
- if sum((order > 1, logistic, robust, lowess, logx)) > 1:
- raise ValueError("Mutually exclusive regression options.")
- # Extract the data vals from the arguments or passed dataframe
- self.establish_variables(data, x=x, y=y, units=units,
- x_partial=x_partial, y_partial=y_partial)
- # Drop null observations
- if dropna:
- self.dropna("x", "y", "units", "x_partial", "y_partial")
- # Regress nuisance variables out of the data
- if self.x_partial is not None:
- self.x = self.regress_out(self.x, self.x_partial)
- if self.y_partial is not None:
- self.y = self.regress_out(self.y, self.y_partial)
- # Possibly bin the predictor variable, which implies a point estimate
- if x_bins is not None:
- self.x_estimator = np.mean if x_estimator is None else x_estimator
- x_discrete, x_bins = self.bin_predictor(x_bins)
- self.x_discrete = x_discrete
- else:
- self.x_discrete = self.x
- # Disable regression in case of singleton inputs
- if len(self.x) <= 1:
- self.fit_reg = False
- # Save the range of the x variable for the grid later
- if self.fit_reg:
- self.x_range = self.x.min(), self.x.max()
- @property
- def scatter_data(self):
- """Data where each observation is a point."""
- x_j = self.x_jitter
- if x_j is None:
- x = self.x
- else:
- x = self.x + np.random.uniform(-x_j, x_j, len(self.x))
- y_j = self.y_jitter
- if y_j is None:
- y = self.y
- else:
- y = self.y + np.random.uniform(-y_j, y_j, len(self.y))
- return x, y
- @property
- def estimate_data(self):
- """Data with a point estimate and CI for each discrete x value."""
- x, y = self.x_discrete, self.y
- vals = sorted(np.unique(x))
- points, cis = [], []
- for val in vals:
- # Get the point estimate of the y variable
- _y = y[x == val]
- est = self.x_estimator(_y)
- points.append(est)
- # Compute the confidence interval for this estimate
- if self.x_ci is None:
- cis.append(None)
- else:
- units = None
- if self.x_ci == "sd":
- sd = np.std(_y)
- _ci = est - sd, est + sd
- else:
- if self.units is not None:
- units = self.units[x == val]
- boots = algo.bootstrap(_y,
- func=self.x_estimator,
- n_boot=self.n_boot,
- units=units,
- seed=self.seed)
- _ci = utils.ci(boots, self.x_ci)
- cis.append(_ci)
- return vals, points, cis
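- # Hedged, numpy-only sketch of the computation above: a point estimate of
- # ``y`` for each unique ``x`` plus a simple percentile bootstrap interval.
- # The helper name and the naive resampling loop are illustrative; the real
- # implementation defers to ``algo.bootstrap`` and ``utils.ci``.
- def _demo_estimate_data(x, y, n_boot=1000, width=95, seed=0):
-     rng = np.random.RandomState(seed)
-     x, y = np.asarray(x), np.asarray(y)
-     results = []
-     for val in np.unique(x):
-         _y = y[x == val]
-         boots = [np.mean(rng.choice(_y, len(_y))) for _ in range(n_boot)]
-         pcts = 50 - width / 2, 50 + width / 2
-         results.append((val, np.mean(_y), np.percentile(boots, pcts)))
-     return results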
- def _check_statsmodels(self):
- """Check whether statsmodels is installed if any boolean options require it."""
- options = "logistic", "robust", "lowess"
- err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
- for option in options:
- if getattr(self, option) and not _has_statsmodels:
- raise RuntimeError(err.format(option))
- def fit_regression(self, ax=None, x_range=None, grid=None):
- """Fit the regression model."""
- self._check_statsmodels()
- # Create the grid for the regression
- if grid is None:
- if self.truncate:
- x_min, x_max = self.x_range
- else:
- if ax is None:
- x_min, x_max = x_range
- else:
- x_min, x_max = ax.get_xlim()
- grid = np.linspace(x_min, x_max, 100)
- ci = self.ci
- # Fit the regression
- if self.order > 1:
- yhat, yhat_boots = self.fit_poly(grid, self.order)
- elif self.logistic:
- from statsmodels.genmod.generalized_linear_model import GLM
- from statsmodels.genmod.families import Binomial
- yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
- family=Binomial())
- elif self.lowess:
- ci = None
- grid, yhat = self.fit_lowess()
- elif self.robust:
- from statsmodels.robust.robust_linear_model import RLM
- yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
- elif self.logx:
- yhat, yhat_boots = self.fit_logx(grid)
- else:
- yhat, yhat_boots = self.fit_fast(grid)
- # Compute the confidence interval at each grid point
- if ci is None:
- err_bands = None
- else:
- err_bands = utils.ci(yhat_boots, ci, axis=0)
- return grid, yhat, err_bands
- def fit_fast(self, grid):
- """Low-level regression and prediction using linear algebra."""
- def reg_func(_x, _y):
- return np.linalg.pinv(_x).dot(_y)
- X, y = np.c_[np.ones(len(self.x)), self.x], self.y
- grid = np.c_[np.ones(len(grid)), grid]
- yhat = grid.dot(reg_func(X, y))
- if self.ci is None:
- return yhat, None
- beta_boots = algo.bootstrap(X, y,
- func=reg_func,
- n_boot=self.n_boot,
- units=self.units,
- seed=self.seed).T
- yhat_boots = grid.dot(beta_boots).T
- return yhat, yhat_boots
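- # Hedged standalone sketch of the fast path above: with design matrix
- # X = [1, x], the OLS coefficients are pinv(X) @ y, and predictions on a
- # grid are [1, grid] @ beta. The function name is illustrative.
- def _demo_fit_fast(x, y, grid):
-     X = np.c_[np.ones(len(x)), x]
-     beta = np.linalg.pinv(X).dot(y)  # least squares via the pseudo-inverse
-     return np.c_[np.ones(len(grid)), grid].dot(beta)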
- def fit_poly(self, grid, order):
- """Regression using numpy polyfit for higher-order trends."""
- def reg_func(_x, _y):
- return np.polyval(np.polyfit(_x, _y, order), grid)
- x, y = self.x, self.y
- yhat = reg_func(x, y)
- if self.ci is None:
- return yhat, None
- yhat_boots = algo.bootstrap(x, y,
- func=reg_func,
- n_boot=self.n_boot,
- units=self.units,
- seed=self.seed)
- return yhat, yhat_boots
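- # Hedged sketch of the polynomial path: ``np.polyfit`` returns
- # highest-degree-first coefficients, which ``np.polyval`` then evaluates on
- # the grid. The helper name is illustrative.
- def _demo_fit_poly(x, y, grid, order=2):
-     coefs = np.polyfit(x, y, order)  # e.g. [a, b, c] for a*x**2 + b*x + c
-     return np.polyval(coefs, grid)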
- def fit_statsmodels(self, grid, model, **kwargs):
- """More general regression function using statsmodels objects."""
- import statsmodels.tools.sm_exceptions as sme
- X, y = np.c_[np.ones(len(self.x)), self.x], self.y
- grid = np.c_[np.ones(len(grid)), grid]
- def reg_func(_x, _y):
- err_classes = (sme.PerfectSeparationError,)
- try:
- with warnings.catch_warnings():
- if hasattr(sme, "PerfectSeparationWarning"):
- # statsmodels>=0.14.0
- warnings.simplefilter("error", sme.PerfectSeparationWarning)
- err_classes = (*err_classes, sme.PerfectSeparationWarning)
- yhat = model(_y, _x, **kwargs).fit().predict(grid)
- except err_classes:
- yhat = np.empty(len(grid))
- yhat.fill(np.nan)
- return yhat
- yhat = reg_func(X, y)
- if self.ci is None:
- return yhat, None
- yhat_boots = algo.bootstrap(X, y,
- func=reg_func,
- n_boot=self.n_boot,
- units=self.units,
- seed=self.seed)
- return yhat, yhat_boots
- def fit_lowess(self):
- """Fit a locally-weighted regression, which returns its own grid."""
- from statsmodels.nonparametric.smoothers_lowess import lowess
- grid, yhat = lowess(self.y, self.x).T
- return grid, yhat
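- # Hedged sketch of the lowess path: statsmodels returns a sorted (x, yhat)
- # array, which is why ``fit_lowess`` supplies its own grid and why no
- # confidence band is drawn for this model. The helper name is illustrative.
- def _demo_fit_lowess(x, y):
-     from statsmodels.nonparametric.smoothers_lowess import lowess
-     smoothed = lowess(y, x)  # shape (n, 2): sorted x, then fitted y
-     return smoothed[:, 0], smoothed[:, 1]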
- def fit_logx(self, grid):
- """Fit the model in log-space."""
- X, y = np.c_[np.ones(len(self.x)), self.x], self.y
- grid = np.c_[np.ones(len(grid)), np.log(grid)]
- def reg_func(_x, _y):
- _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
- return np.linalg.pinv(_x).dot(_y)
- yhat = grid.dot(reg_func(X, y))
- if self.ci is None:
- return yhat, None
- beta_boots = algo.bootstrap(X, y,
- func=reg_func,
- n_boot=self.n_boot,
- units=self.units,
- seed=self.seed).T
- yhat_boots = grid.dot(beta_boots).T
- return yhat, yhat_boots
- def bin_predictor(self, bins):
- """Discretize a predictor by assigning value to closest bin."""
- x = np.asarray(self.x)
- if np.isscalar(bins):
- percentiles = np.linspace(0, 100, bins + 2)[1:-1]
- bins = np.percentile(x, percentiles)
- else:
- bins = np.ravel(bins)
- dist = np.abs(np.subtract.outer(x, bins))
- x_binned = bins[np.argmin(dist, axis=1)].ravel()
- return x_binned, bins
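- # Hedged standalone sketch of the binning rule above: each observation is
- # assigned to the nearest bin center via a full distance matrix. E.g.
- # _demo_bin([0.1, 0.9, 2.2], [0, 1, 2]) -> array([0, 1, 2]).
- def _demo_bin(x, bin_centers):
-     x = np.asarray(x, dtype=float)
-     bins = np.ravel(bin_centers)
-     dist = np.abs(np.subtract.outer(x, bins))  # shape (len(x), len(bins))
-     return bins[np.argmin(dist, axis=1)]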
- def regress_out(self, a, b):
- """Regress b from a keeping a's original mean."""
- a_mean = a.mean()
- a = a - a_mean
- b = b - b.mean()
- b = np.c_[b]
- a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
- return np.asarray(a_prime + a_mean).reshape(a.shape)
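- # Hedged standalone sketch of the partialling step: remove the component of
- # ``a`` that is linearly predictable from ``b`` (both centered), then
- # restore the original mean of ``a``. Names are illustrative.
- def _demo_regress_out(a, b):
-     a, b = np.asarray(a, float), np.asarray(b, float)
-     a_mean = a.mean()
-     a_c, b_c = a - a_mean, np.c_[b - b.mean()]
-     resid = a_c - b_c.dot(np.linalg.pinv(b_c).dot(a_c))  # residualize a on b
-     return resid + a_mean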
- def plot(self, ax, scatter_kws, line_kws):
- """Draw the full plot."""
- # Insert the plot label into the correct set of keyword arguments
- if self.scatter:
- scatter_kws["label"] = self.label
- else:
- line_kws["label"] = self.label
- # Use the current color cycle state as a default
- if self.color is None:
- lines, = ax.plot([], [])
- color = lines.get_color()
- lines.remove()
- else:
- color = self.color
- # Ensure that color is hex to avoid matplotlib weirdness
- color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))
- # Let color in keyword arguments override overall plot color
- scatter_kws.setdefault("color", color)
- line_kws.setdefault("color", color)
- # Draw the constituent plots
- if self.scatter:
- self.scatterplot(ax, scatter_kws)
- if self.fit_reg:
- self.lineplot(ax, line_kws)
- # Label the axes
- if hasattr(self.x, "name"):
- ax.set_xlabel(self.x.name)
- if hasattr(self.y, "name"):
- ax.set_ylabel(self.y.name)
- def scatterplot(self, ax, kws):
- """Draw the data."""
- # Treat the line-based markers specially, explicitly setting a larger
- # linewidth than is provided by the seaborn style defaults.
- # This would ideally be handled better in matplotlib (i.e., by
- # distinguishing between edgewidth for solid glyphs and linewidth for
- # line glyphs), but this should do for now.
- line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
- if self.x_estimator is None:
- if "marker" in kws and kws["marker"] in line_markers:
- lw = mpl.rcParams["lines.linewidth"]
- else:
- lw = mpl.rcParams["lines.markeredgewidth"]
- kws.setdefault("linewidths", lw)
- if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4:
- kws.setdefault("alpha", .8)
- x, y = self.scatter_data
- ax.scatter(x, y, **kws)
- else:
- # TODO abstraction
- ci_kws = {"color": kws["color"]}
- if "alpha" in kws:
- ci_kws["alpha"] = kws["alpha"]
- ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
- kws.setdefault("s", 50)
- xs, ys, cis = self.estimate_data
- if any(ci is not None for ci in cis):
- for x, ci in zip(xs, cis):
- ax.plot([x, x], ci, **ci_kws)
- ax.scatter(xs, ys, **kws)
- def lineplot(self, ax, kws):
- """Draw the model."""
- # Fit the regression model
- grid, yhat, err_bands = self.fit_regression(ax)
- edges = grid[0], grid[-1]
- # Set default aesthetics
- fill_color = kws["color"]
- lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
- kws.setdefault("linewidth", lw)
- # Draw the regression line and confidence interval
- line, = ax.plot(grid, yhat, **kws)
- if not self.truncate:
- line.sticky_edges.x[:] = edges # Prevent mpl from adding margin
- if err_bands is not None:
- ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
- _regression_docs = dict(
- model_api=dedent("""\
- There are a number of mutually exclusive options for estimating the
- regression model. See the :ref:`tutorial <regression_tutorial>` for more
- information.\
- """),
- regplot_vs_lmplot=dedent("""\
- The :func:`regplot` and :func:`lmplot` functions are closely related, but
- the former is an axes-level function while the latter is a figure-level
- function that combines :func:`regplot` and :class:`FacetGrid`.\
- """),
- x_estimator=dedent("""\
- x_estimator : callable that maps vector -> scalar, optional
- Apply this function to each unique value of ``x`` and plot the
- resulting estimate. This is useful when ``x`` is a discrete variable.
- If ``x_ci`` is given, this estimate will be bootstrapped and a
- confidence interval will be drawn.\
- """),
- x_bins=dedent("""\
- x_bins : int or vector, optional
- Bin the ``x`` variable into discrete bins and then estimate the central
- tendency and a confidence interval. This binning only influences how
- the scatterplot is drawn; the regression is still fit to the original
- data. This parameter is interpreted either as the number of
- evenly-sized (but not necessarily evenly-spaced) bins or the positions of the bin
- centers. When this parameter is used, it implies that the default of
- ``x_estimator`` is ``numpy.mean``.\
- """),
- x_ci=dedent("""\
- x_ci : "ci", "sd", int in [0, 100] or None, optional
- Size of the confidence interval used when plotting a central tendency
- for discrete values of ``x``. If ``"ci"``, defer to the value of the
- ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
- standard deviation of the observations in each bin.\
- """),
- scatter=dedent("""\
- scatter : bool, optional
- If ``True``, draw a scatterplot with the underlying observations (or
- the ``x_estimator`` values).\
- """),
- fit_reg=dedent("""\
- fit_reg : bool, optional
- If ``True``, estimate and plot a regression model relating the ``x``
- and ``y`` variables.\
- """),
- ci=dedent("""\
- ci : int in [0, 100] or None, optional
- Size of the confidence interval for the regression estimate. This will
- be drawn using translucent bands around the regression line. The
- confidence interval is estimated using a bootstrap; for large
- datasets, it may be advisable to avoid that computation by setting
- this parameter to None.\
- """),
- n_boot=dedent("""\
- n_boot : int, optional
- Number of bootstrap resamples used to estimate the ``ci``. The default
- value attempts to balance time and stability; you may want to increase
- this value for "final" versions of plots.\
- """),
- units=dedent("""\
- units : variable name in ``data``, optional
- If the ``x`` and ``y`` observations are nested within sampling units,
- those can be specified here. This will be taken into account when
- computing the confidence intervals by performing a multilevel bootstrap
- that resamples both units and observations (within unit). This does not
- otherwise influence how the regression is estimated or drawn.\
- """),
- seed=dedent("""\
- seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
- Seed or random number generator for reproducible bootstrapping.\
- """),
- order=dedent("""\
- order : int, optional
- If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
- polynomial regression.\
- """),
- logistic=dedent("""\
- logistic : bool, optional
- If ``True``, assume that ``y`` is a binary variable and use
- ``statsmodels`` to estimate a logistic regression model. Note that this
- is substantially more computationally intensive than linear regression,
- so you may wish to decrease the number of bootstrap resamples
- (``n_boot``) or set ``ci`` to None.\
- """),
- lowess=dedent("""\
- lowess : bool, optional
- If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
- model (locally weighted linear regression). Note that confidence
- intervals cannot currently be drawn for this kind of model.\
- """),
- robust=dedent("""\
- robust : bool, optional
- If ``True``, use ``statsmodels`` to estimate a robust regression. This
- will de-weight outliers. Note that this is substantially more
- computationally intensive than standard linear regression, so you may
- wish to decrease the number of bootstrap resamples (``n_boot``) or set
- ``ci`` to None.\
- """),
- logx=dedent("""\
- logx : bool, optional
- If ``True``, estimate a linear regression of the form y ~ log(x), but
- plot the scatterplot and regression model in the input space. Note that
- ``x`` must be positive for this to work.\
- """),
- xy_partial=dedent("""\
- {x,y}_partial : strings in ``data`` or matrices
- Confounding variables to regress out of the ``x`` or ``y`` variables
- before plotting.\
- """),
- truncate=dedent("""\
- truncate : bool, optional
- If ``True``, the regression line is bounded by the data limits. If
- ``False``, it extends to the ``x`` axis limits.\
- """),
- xy_jitter=dedent("""\
- {x,y}_jitter : floats, optional
- Add uniform random noise of this size to either the ``x`` or ``y``
- variables. The noise is added to a copy of the data after fitting the
- regression, and only influences the look of the scatterplot. This can
- be helpful when plotting variables that take discrete values.\
- """),
- scatter_line_kws=dedent("""\
- {scatter,line}_kws : dictionaries
- Additional keyword arguments to pass to ``plt.scatter`` and
- ``plt.plot``.\
- """),
- )
- _regression_docs.update(_facet_docs)
- def lmplot(
- data=None, *,
- x=None, y=None, hue=None, col=None, row=None,
- palette=None, col_wrap=None, height=5, aspect=1, markers="o",
- sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
- legend=True, legend_out=None, x_estimator=None, x_bins=None,
- x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
- units=None, seed=None, order=1, logistic=False, lowess=False,
- robust=False, logx=False, x_partial=None, y_partial=None,
- truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
- line_kws=None, facet_kws=None,
- ):
- if facet_kws is None:
- facet_kws = {}
- def facet_kw_deprecation(key, val):
- msg = (
- f"{key} is deprecated from the `lmplot` function signature. "
- "Please update your code to pass it using `facet_kws`."
- )
- if val is not None:
- warnings.warn(msg, UserWarning)
- facet_kws[key] = val
- facet_kw_deprecation("sharex", sharex)
- facet_kw_deprecation("sharey", sharey)
- facet_kw_deprecation("legend_out", legend_out)
- if data is None:
- raise TypeError("Missing required keyword argument `data`.")
- # Reduce the dataframe to only needed columns
- need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
- cols = np.unique([a for a in need_cols if a is not None]).tolist()
- data = data[cols]
- # Initialize the grid
- facets = FacetGrid(
- data, row=row, col=col, hue=hue,
- palette=palette,
- row_order=row_order, col_order=col_order, hue_order=hue_order,
- height=height, aspect=aspect, col_wrap=col_wrap,
- **facet_kws,
- )
- # Add the markers here, as FacetGrid has already worked out how many
- # levels the hue variable has, and we don't want to duplicate that process
- if facets.hue_names is None:
- n_markers = 1
- else:
- n_markers = len(facets.hue_names)
- if not isinstance(markers, list):
- markers = [markers] * n_markers
- if len(markers) != n_markers:
- raise ValueError("markers must be a singleton or a list of markers "
- "for each level of the hue variable")
- facets.hue_kws = {"marker": markers}
- def update_datalim(data, x, y, ax, **kws):
- xys = data[[x, y]].to_numpy().astype(float)
- ax.update_datalim(xys, updatey=False)
- ax.autoscale_view(scaley=False)
- facets.map_dataframe(update_datalim, x=x, y=y)
- # Draw the regression plot on each facet
- regplot_kws = dict(
- x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
- scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
- seed=seed, order=order, logistic=logistic, lowess=lowess,
- robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
- truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
- scatter_kws=scatter_kws, line_kws=line_kws,
- )
- facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
- facets.set_axis_labels(x, y)
- # Add a legend
- if legend and (hue is not None) and (hue not in [col, row]):
- facets.add_legend()
- return facets
- lmplot.__doc__ = dedent("""\
- Plot data and regression model fits across a FacetGrid.
- This function combines :func:`regplot` and :class:`FacetGrid`. It is
- intended as a convenient interface to fit regression models across
- conditional subsets of a dataset.
- When thinking about how to assign variables to different facets, a general
- rule is that it makes sense to use ``hue`` for the most important
- comparison, followed by ``col`` and ``row``. However, always think about
- your particular dataset and the goals of the visualization you are
- creating.
- {model_api}
- The parameters to this function span most of the options in
- :class:`FacetGrid`, although there may be occasional cases where you will
- want to use that class and :func:`regplot` directly.
- Parameters
- ----------
- {data}
- x, y : strings, optional
- Input variables; these should be column names in ``data``.
- hue, col, row : strings
- Variables that define subsets of the data, which will be drawn on
- separate facets in the grid. See the ``*_order`` parameters to control
- the order of levels of this variable.
- {palette}
- {col_wrap}
- {height}
- {aspect}
- markers : matplotlib marker code or list of marker codes, optional
- Markers for the scatterplot. If a list, each marker will be used for
- the corresponding level of the ``hue`` variable.
- {share_xy}
- .. deprecated:: 0.12.0
- Pass using the `facet_kws` dictionary.
- {{hue,col,row}}_order : lists, optional
- Order for the levels of the faceting variables. By default, this will
- be the order that the levels appear in ``data`` or, if the variables
- are pandas categoricals, the category order.
- legend : bool, optional
- If ``True`` and there is a ``hue`` variable, add a legend.
- {legend_out}
- .. deprecated:: 0.12.0
- Pass using the `facet_kws` dictionary.
- {x_estimator}
- {x_bins}
- {x_ci}
- {scatter}
- {fit_reg}
- {ci}
- {n_boot}
- {units}
- {seed}
- {order}
- {logistic}
- {lowess}
- {robust}
- {logx}
- {xy_partial}
- {truncate}
- {xy_jitter}
- {scatter_line_kws}
- facet_kws : dict
- Dictionary of keyword arguments for :class:`FacetGrid`.
- See Also
- --------
- regplot : Plot data and a conditional model fit.
- FacetGrid : Subplot grid for plotting conditional relationships.
- pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
- ``kind="reg"``).
- Notes
- -----
- {regplot_vs_lmplot}
- Examples
- --------
- .. include:: ../docstrings/lmplot.rst
- """).format(**_regression_docs)
- def regplot(
- data=None, *, x=None, y=None,
- x_estimator=None, x_bins=None, x_ci="ci",
- scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
- seed=None, order=1, logistic=False, lowess=False, robust=False,
- logx=False, x_partial=None, y_partial=None,
- truncate=True, dropna=True, x_jitter=None, y_jitter=None,
- label=None, color=None, marker="o",
- scatter_kws=None, line_kws=None, ax=None
- ):
- plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
- scatter, fit_reg, ci, n_boot, units, seed,
- order, logistic, lowess, robust, logx,
- x_partial, y_partial, truncate, dropna,
- x_jitter, y_jitter, color, label)
- if ax is None:
- ax = plt.gca()
- scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
- scatter_kws["marker"] = marker
- line_kws = {} if line_kws is None else copy.copy(line_kws)
- plotter.plot(ax, scatter_kws, line_kws)
- return ax
- regplot.__doc__ = dedent("""\
- Plot data and a linear regression model fit.
- {model_api}
- Parameters
- ----------
- x, y : string, series, or vector array
- Input variables. If strings, these should correspond with column names
- in ``data``. When pandas objects are used, axes will be labeled with
- the series name.
- {data}
- {x_estimator}
- {x_bins}
- {x_ci}
- {scatter}
- {fit_reg}
- {ci}
- {n_boot}
- {units}
- {seed}
- {order}
- {logistic}
- {lowess}
- {robust}
- {logx}
- {xy_partial}
- {truncate}
- {xy_jitter}
- label : string
- Label to apply to either the scatterplot or regression line (if
- ``scatter`` is ``False``) for use in a legend.
- color : matplotlib color
- Color to apply to all plot elements; will be superseded by colors
- passed in ``scatter_kws`` or ``line_kws``.
- marker : matplotlib marker code
- Marker to use for the scatterplot glyphs.
- {scatter_line_kws}
- ax : matplotlib Axes, optional
- Axes object to draw the plot onto, otherwise uses the current Axes.
- Returns
- -------
- ax : matplotlib Axes
- The Axes object containing the plot.
- See Also
- --------
- lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
- linear relationships in a dataset.
- jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
- ``kind="reg"``).
- pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
- ``kind="reg"``).
- residplot : Plot the residuals of a linear regression model.
- Notes
- -----
- {regplot_vs_lmplot}
- It's also easy to combine :func:`regplot` and :class:`JointGrid` or
- :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
- functions, although these do not directly accept all of :func:`regplot`'s
- parameters.
- Examples
- --------
- .. include:: ../docstrings/regplot.rst
- """).format(**_regression_docs)
- def residplot(
- data=None, *, x=None, y=None,
- x_partial=None, y_partial=None, lowess=False,
- order=1, robust=False, dropna=True, label=None, color=None,
- scatter_kws=None, line_kws=None, ax=None
- ):
- """Plot the residuals of a linear regression.
- This function will regress y on x (possibly as a robust or polynomial
- regression) and then draw a scatterplot of the residuals. You can
- optionally fit a lowess smoother to the residual plot, which can
- help in determining if there is structure to the residuals.
- Parameters
- ----------
- data : DataFrame, optional
- DataFrame to use if `x` and `y` are column names.
- x : vector or string
- Data or column name in `data` for the predictor variable.
- y : vector or string
- Data or column name in `data` for the response variable.
- {x,y}_partial : vectors or strings, optional
- These variables are treated as confounding and are removed from
- the `x` or `y` variables before plotting.
- lowess : boolean, optional
- Fit a lowess smoother to the residual scatterplot.
- order : int, optional
- Order of the polynomial to fit when calculating the residuals.
- robust : boolean, optional
- Fit a robust linear regression when calculating the residuals.
- dropna : boolean, optional
- If True, ignore observations with missing data when fitting and
- plotting.
- label : string, optional
- Label that will be used in any plot legends.
- color : matplotlib color, optional
- Color to use for all elements of the plot.
- {scatter,line}_kws : dictionaries, optional
- Additional keyword arguments passed to scatter() and plot() for drawing
- the components of the plot.
- ax : matplotlib axis, optional
- Plot into this axis, otherwise grab the current axis or make a new
- one if it does not exist.
- Returns
- -------
- ax : matplotlib axes
- Axes with the regression plot.
- See Also
- --------
- regplot : Plot a simple linear regression model.
- jointplot : Draw a :func:`residplot` with univariate marginal distributions
- (when used with ``kind="resid"``).
- Examples
- --------
- .. include:: ../docstrings/residplot.rst
- """
- plotter = _RegressionPlotter(x, y, data, ci=None,
- order=order, robust=robust,
- x_partial=x_partial, y_partial=y_partial,
- dropna=dropna, color=color, label=label)
- if ax is None:
- ax = plt.gca()
- # Calculate the residual from a linear regression
- _, yhat, _ = plotter.fit_regression(grid=plotter.x)
- plotter.y = plotter.y - yhat
- # Set the regression option on the plotter
- if lowess:
- plotter.lowess = True
- else:
- plotter.fit_reg = False
- # Plot a horizontal line at 0
- ax.axhline(0, ls=":", c=".2")
- # Draw the scatterplot
- scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
- line_kws = {} if line_kws is None else line_kws.copy()
- plotter.plot(ax, scatter_kws, line_kws)
- return ax
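- # Hedged usage sketch for :func:`residplot`; a well-specified model should
- # leave residuals scattered symmetrically around the y=0 reference line.
- # The dataset and options are illustrative.
- def _demo_residplot():
-     from seaborn import load_dataset
-     mpg = load_dataset("mpg").dropna()
-     return residplot(data=mpg, x="horsepower", y="mpg", lowess=True)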