  1. """Plotting functions for linear models (broadly construed)."""
  2. import copy
  3. from textwrap import dedent
  4. import warnings
  5. import numpy as np
  6. import pandas as pd
  7. import matplotlib as mpl
  8. import matplotlib.pyplot as plt
  9. try:
  10. import statsmodels
  11. assert statsmodels
  12. _has_statsmodels = True
  13. except ImportError:
  14. _has_statsmodels = False
  15. from . import utils
  16. from . import algorithms as algo
  17. from .axisgrid import FacetGrid, _facet_docs
  18. __all__ = ["lmplot", "regplot", "residplot"]
  19. class _LinearPlotter:
  20. """Base class for plotting relational data in tidy format.
  21. To get anything useful done you'll have to inherit from this, but setup
  22. code that can be abstracted out should be put here.
  23. """
  24. def establish_variables(self, data, **kws):
  25. """Extract variables from data or use directly."""
  26. self.data = data
  27. # Validate the inputs
  28. any_strings = any([isinstance(v, str) for v in kws.values()])
  29. if any_strings and data is None:
  30. raise ValueError("Must pass `data` if using named variables.")
  31. # Set the variables
  32. for var, val in kws.items():
  33. if isinstance(val, str):
  34. vector = data[val]
  35. elif isinstance(val, list):
  36. vector = np.asarray(val)
  37. else:
  38. vector = val
  39. if vector is not None and vector.shape != (1,):
  40. vector = np.squeeze(vector)
  41. if np.ndim(vector) > 1:
  42. err = "regplot inputs must be 1d"
  43. raise ValueError(err)
  44. setattr(self, var, vector)
  45. def dropna(self, *vars):
  46. """Remove observations with missing data."""
  47. vals = [getattr(self, var) for var in vars]
  48. vals = [v for v in vals if v is not None]
  49. not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
  50. for var in vars:
  51. val = getattr(self, var)
  52. if val is not None:
  53. setattr(self, var, val[not_na])
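
    # Illustrative sketch (standalone, not library code): the mask built above
    # marks the rows where every provided vector is non-null, so one indexing
    # pass drops incomplete observations from all variables at once.
    #
    #     import numpy as np
    #     import pandas as pd
    #     a = pd.Series([1., np.nan, 3.])
    #     b = pd.Series([4., 5., np.nan])
    #     not_na = np.all(np.column_stack([pd.notnull(a), pd.notnull(b)]),
    #                     axis=1)
    #     # -> array([ True, False, False]); a[not_na] keeps only row 0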

    def plot(self, ax):
        raise NotImplementedError


class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.

    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):

        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label

        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)

        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            x_discrete, x_bins = self.bin_predictor(x_bins)
            self.x_discrete = x_discrete
        else:
            self.x_discrete = self.x

        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False

        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point."""
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value."""
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis
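
    # Illustrative sketch (standalone, not library code): the bootstrap above
    # amounts to resampling each bin with replacement and taking percentile
    # bounds of the resampled estimates (a 95% interval shown here).
    #
    #     import numpy as np
    #     rng = np.random.default_rng(0)
    #     _y = np.array([2.1, 2.5, 1.9, 2.8, 2.3])     # one bin's y values
    #     boots = [rng.choice(_y, len(_y), replace=True).mean()
    #              for _ in range(1000)]
    #     lo, hi = np.percentile(boots, [2.5, 97.5])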

    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it."""
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model."""
        self._check_statsmodels()

        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra."""
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots
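
    # Illustrative sketch (standalone, not library code): `fit_fast` is plain
    # least squares via the Moore-Penrose pseudoinverse. With an intercept
    # column in the design matrix, the coefficients minimize ||X @ beta - y||.
    #
    #     import numpy as np
    #     x = np.array([0., 1., 2., 3.])
    #     y = np.array([1., 3., 5., 7.])                # y = 1 + 2x exactly
    #     X = np.c_[np.ones(len(x)), x]                 # design matrix
    #     beta = np.linalg.pinv(X).dot(y)               # -> [1., 2.]
    #     grid = np.c_[np.ones(2), [0., 3.]]
    #     yhat = grid.dot(beta)                         # -> [1., 7.]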

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends."""
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots
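
    # Illustrative sketch (standalone, not library code): the polynomial path
    # fits coefficients with numpy and evaluates them on the prediction grid.
    #
    #     import numpy as np
    #     rng = np.random.default_rng(0)
    #     x = np.linspace(-2, 2, 50)
    #     y = x ** 2 + rng.normal(0, .1, 50)            # noisy quadratic
    #     coefs = np.polyfit(x, y, 2)                   # order-2 fit
    #     grid = np.linspace(-2, 2, 100)
    #     yhat = np.polyval(coefs, grid)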

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects."""
        import statsmodels.tools.sm_exceptions as sme
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]

        def reg_func(_x, _y):

            err_classes = (sme.PerfectSeparationError,)
            try:
                with warnings.catch_warnings():
                    if hasattr(sme, "PerfectSeparationWarning"):
                        # statsmodels>=0.14.0
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                        err_classes = (*err_classes, sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat

        yhat = reg_func(X, y)

        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat
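
    # Illustrative sketch (standalone, not library code): statsmodels' lowess
    # returns a sorted (x, smoothed y) array, so the smoother supplies its own
    # grid; this is also why no confidence band is computed for it.
    #
    #     import numpy as np
    #     from statsmodels.nonparametric.smoothers_lowess import lowess
    #     rng = np.random.default_rng(0)
    #     x = np.linspace(0, 10, 50)
    #     y = np.sin(x) + rng.normal(0, .2, 50)
    #     grid, yhat = lowess(y, x).T                   # two 1d arrays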

    def fit_logx(self, grid):
        """Fit the model in log-space."""
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        def reg_func(_x, _y):
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)

        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots
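
    # Illustrative sketch (standalone, not library code): the logx model is
    # ordinary least squares after replacing the predictor with its log,
    # i.e. y ~ b0 + b1 * log(x), predicted back on the original (positive) x.
    #
    #     import numpy as np
    #     x = np.array([1., 2., 4., 8.])
    #     y = np.array([0., 1., 2., 3.])                # y = log2(x)
    #     X = np.c_[np.ones(len(x)), np.log(x)]
    #     beta = np.linalg.pinv(X).dot(y)               # -> [0., 1/log(2)]
    #     grid = np.linspace(1, 8, 5)
    #     yhat = np.c_[np.ones(len(grid)), np.log(grid)].dot(beta)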

    def bin_predictor(self, bins):
        """Discretize a predictor by assigning value to closest bin."""
        x = np.asarray(self.x)
        if np.isscalar(bins):
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, percentiles)
        else:
            bins = np.ravel(bins)

        dist = np.abs(np.subtract.outer(x, bins))
        x_binned = bins[np.argmin(dist, axis=1)].ravel()

        return x_binned, bins
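
    # Illustrative sketch (standalone, not library code): each observation is
    # snapped to its nearest bin center via a pairwise distance matrix.
    #
    #     import numpy as np
    #     x = np.array([0., 1., 2., 9., 10.])
    #     centers = np.array([1., 9.5])
    #     dist = np.abs(np.subtract.outer(x, centers))  # shape (5, 2)
    #     snapped = centers[np.argmin(dist, axis=1)]    # [1, 1, 1, 9.5, 9.5]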

    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)
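
    # Illustrative sketch (standalone, not library code): after centering, the
    # projection of `a` onto `b` is subtracted, leaving the component of `a`
    # orthogonal to `b` (its residual is uncorrelated with `b`), then `a`'s
    # original mean is restored.
    #
    #     import numpy as np
    #     rng = np.random.default_rng(0)
    #     b = rng.normal(size=100)
    #     a = 2 * b + rng.normal(size=100)
    #     B = np.c_[b - b.mean()]
    #     a0 = a - a.mean()
    #     resid = a0 - B.dot(np.linalg.pinv(B).dot(a0))
    #     # np.corrcoef(resid, b)[0, 1] is ~0 up to floating point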

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot."""
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label

        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))

        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)

        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)

    def scatterplot(self, ax, kws):
        """Draw the data."""
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs),
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)

            if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            if [ci for ci in cis if ci is not None]:
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the model."""
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # Set the default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)

        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)


_regression_docs = dict(

    model_api=dedent("""\
    There are a number of mutually exclusive options for estimating the
    regression model. See the :ref:`tutorial <regression_tutorial>` for more
    information.\
    """),
    regplot_vs_lmplot=dedent("""\
    The :func:`regplot` and :func:`lmplot` functions are closely related, but
    the former is an axes-level function while the latter is a figure-level
    function that combines :func:`regplot` and :class:`FacetGrid`.\
    """),
    x_estimator=dedent("""\
    x_estimator : callable that maps vector -> scalar, optional
        Apply this function to each unique value of ``x`` and plot the
        resulting estimate. This is useful when ``x`` is a discrete variable.
        If ``x_ci`` is given, this estimate will be bootstrapped and a
        confidence interval will be drawn.\
    """),
    x_bins=dedent("""\
    x_bins : int or vector, optional
        Bin the ``x`` variable into discrete bins and then estimate the central
        tendency and a confidence interval. This binning only influences how
        the scatterplot is drawn; the regression is still fit to the original
        data. This parameter is interpreted either as the number of
        evenly-sized (though not necessarily evenly-spaced) bins or the
        positions of the bin centers. When this parameter is used, it implies
        that the default of ``x_estimator`` is ``numpy.mean``.\
    """),
    x_ci=dedent("""\
    x_ci : "ci", "sd", int in [0, 100] or None, optional
        Size of the confidence interval used when plotting a central tendency
        for discrete values of ``x``. If ``"ci"``, defer to the value of the
        ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
        standard deviation of the observations in each bin.\
    """),
    scatter=dedent("""\
    scatter : bool, optional
        If ``True``, draw a scatterplot with the underlying observations (or
        the ``x_estimator`` values).\
    """),
    fit_reg=dedent("""\
    fit_reg : bool, optional
        If ``True``, estimate and plot a regression model relating the ``x``
        and ``y`` variables.\
    """),
    ci=dedent("""\
    ci : int in [0, 100] or None, optional
        Size of the confidence interval for the regression estimate. This will
        be drawn using translucent bands around the regression line. The
        confidence interval is estimated using a bootstrap; for large
        datasets, it may be advisable to avoid that computation by setting
        this parameter to None.\
    """),
    n_boot=dedent("""\
    n_boot : int, optional
        Number of bootstrap resamples used to estimate the ``ci``. The default
        value attempts to balance time and stability; you may want to increase
        this value for "final" versions of plots.\
    """),
    units=dedent("""\
    units : variable name in ``data``, optional
        If the ``x`` and ``y`` observations are nested within sampling units,
        those can be specified here. This will be taken into account when
        computing the confidence intervals by performing a multilevel bootstrap
        that resamples both units and observations (within unit). This does not
        otherwise influence how the regression is estimated or drawn.\
    """),
    seed=dedent("""\
    seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
        Seed or random number generator for reproducible bootstrapping.\
    """),
    order=dedent("""\
    order : int, optional
        If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
        polynomial regression.\
    """),
    logistic=dedent("""\
    logistic : bool, optional
        If ``True``, assume that ``y`` is a binary variable and use
        ``statsmodels`` to estimate a logistic regression model. Note that this
        is substantially more computationally intensive than linear regression,
        so you may wish to decrease the number of bootstrap resamples
        (``n_boot``) or set ``ci`` to None.\
    """),
    lowess=dedent("""\
    lowess : bool, optional
        If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
        model (locally weighted linear regression). Note that confidence
        intervals cannot currently be drawn for this kind of model.\
    """),
    robust=dedent("""\
    robust : bool, optional
        If ``True``, use ``statsmodels`` to estimate a robust regression. This
        will de-weight outliers. Note that this is substantially more
        computationally intensive than standard linear regression, so you may
        wish to decrease the number of bootstrap resamples (``n_boot``) or set
        ``ci`` to None.\
    """),
    logx=dedent("""\
    logx : bool, optional
        If ``True``, estimate a linear regression of the form y ~ log(x), but
        plot the scatterplot and regression model in the input space. Note that
        ``x`` must be positive for this to work.\
    """),
    xy_partial=dedent("""\
    {x,y}_partial : strings in ``data`` or matrices
        Confounding variables to regress out of the ``x`` or ``y`` variables
        before plotting.\
    """),
    truncate=dedent("""\
    truncate : bool, optional
        If ``True``, the regression line is bounded by the data limits. If
        ``False``, it extends to the ``x`` axis limits.
    """),
    xy_jitter=dedent("""\
    {x,y}_jitter : floats, optional
        Add uniform random noise of this size to either the ``x`` or ``y``
        variables. The noise is added to a copy of the data after fitting the
        regression, and only influences the look of the scatterplot. This can
        be helpful when plotting variables that take discrete values.\
    """),
    scatter_line_kws=dedent("""\
    {scatter,line}_kws : dictionaries
        Additional keyword arguments to pass to ``plt.scatter`` and
        ``plt.plot``.\
    """),
)
_regression_docs.update(_facet_docs)


def lmplot(
    data=None, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):

    if facet_kws is None:
        facet_kws = {}

    def facet_kw_deprecation(key, val):
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = np.unique([a for a in need_cols if a is not None]).tolist()
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    def update_datalim(data, x, y, ax, **kws):
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets
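

# Illustrative usage sketch (not executed here; assumes the bundled "tips"
# example dataset): one regression fit per facet/hue subset.
#
#     import seaborn as sns
#     tips = sns.load_dataset("tips")
#     sns.lmplot(data=tips, x="total_bill", y="tip", col="time", hue="smoker")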


lmplot.__doc__ = dedent("""\
    Plot data and regression model fits across a FacetGrid.

    This function combines :func:`regplot` and :class:`FacetGrid`. It is
    intended as a convenient interface to fit regression models across
    conditional subsets of a dataset.

    When thinking about how to assign variables to different facets, a general
    rule is that it makes sense to use ``hue`` for the most important
    comparison, followed by ``col`` and ``row``. However, always think about
    your particular dataset and the goals of the visualization you are
    creating.

    {model_api}

    The parameters to this function span most of the options in
    :class:`FacetGrid`, although there may be occasional cases where you will
    want to use that class and :func:`regplot` directly.

    Parameters
    ----------
    {data}
    x, y : strings, optional
        Input variables; these should be column names in ``data``.
    hue, col, row : strings
        Variables that define subsets of the data, which will be drawn on
        separate facets in the grid. See the ``*_order`` parameters to control
        the order of levels of this variable.
    {palette}
    {col_wrap}
    {height}
    {aspect}
    markers : matplotlib marker code or list of marker codes, optional
        Markers for the scatterplot. If a list, each marker in the list will be
        used for each level of the ``hue`` variable.
    {share_xy}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {{hue,col,row}}_order : lists, optional
        Order for the levels of the faceting variables. By default, this will
        be the order that the levels appear in ``data`` or, if the variables
        are pandas categoricals, the category order.
    legend : bool, optional
        If ``True`` and there is a ``hue`` variable, add a legend.
    {legend_out}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    {scatter_line_kws}
    facet_kws : dict
        Dictionary of keyword arguments for :class:`FacetGrid`.

    See Also
    --------
    regplot : Plot data and a conditional model fit.
    FacetGrid : Subplot grid for plotting conditional relationships.
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).

    Notes
    -----

    {regplot_vs_lmplot}

    Examples
    --------

    .. include:: ../docstrings/lmplot.rst

    """).format(**_regression_docs)


def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):

    plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
                                 scatter, fit_reg, ci, n_boot, units, seed,
                                 order, logistic, lowess, robust, logx,
                                 x_partial, y_partial, truncate, dropna,
                                 x_jitter, y_jitter, color, label)

    if ax is None:
        ax = plt.gca()

    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
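

# Illustrative usage sketch (not executed here; assumes the bundled "tips"
# example dataset): an axes-level fit with a bootstrapped confidence band.
#
#     import matplotlib.pyplot as plt
#     import seaborn as sns
#     tips = sns.load_dataset("tips")
#     ax = sns.regplot(data=tips, x="total_bill", y="tip", n_boot=500)
#     plt.show()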


regplot.__doc__ = dedent("""\
    Plot data and a linear regression model fit.

    {model_api}

    Parameters
    ----------
    x, y : string, series, or vector array
        Input variables. If strings, these should correspond with column names
        in ``data``. When pandas objects are used, axes will be labeled with
        the series name.
    {data}
    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    label : string
        Label to apply to either the scatterplot or regression line (if
        ``scatter`` is ``False``) for use in a legend.
    color : matplotlib color
        Color to apply to all plot elements; will be superseded by colors
        passed in ``scatter_kws`` or ``line_kws``.
    marker : matplotlib marker code
        Marker to use for the scatterplot glyphs.
    {scatter_line_kws}
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        The Axes object containing the plot.

    See Also
    --------
    lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
             linear relationships in a dataset.
    jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
                ``kind="reg"``).
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).
    residplot : Plot the residuals of a linear regression model.

    Notes
    -----

    {regplot_vs_lmplot}

    It's also easy to combine :func:`regplot` and :class:`JointGrid` or
    :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
    functions, although these do not directly accept all of :func:`regplot`'s
    parameters.

    Examples
    --------

    .. include:: ../docstrings/regplot.rst

    """).format(**_regression_docs)


def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x, y}_partial : vectors or string(s), optional
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib axis, optional
        Plot into this axis, otherwise grab the current axis or make a new
        one if not existing.

    Returns
    -------
    ax : matplotlib axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    Examples
    --------

    .. include:: ../docstrings/residplot.rst

    """
    plotter = _RegressionPlotter(x, y, data, ci=None,
                                 order=order, robust=robust,
                                 x_partial=x_partial, y_partial=y_partial,
                                 dropna=dropna, color=color, label=label)

    if ax is None:
        ax = plt.gca()

    # Calculate the residual from a linear regression
    _, yhat, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - yhat

    # Set the regression option on the plotter
    if lowess:
        plotter.lowess = True
    else:
        plotter.fit_reg = False

    # Plot a horizontal line at 0
    ax.axhline(0, ls=":", c=".2")

    # Draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
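

# Illustrative usage sketch (not executed here; assumes the bundled "mpg"
# example dataset): residuals of a simple linear fit, with a lowess smoother
# to reveal any remaining structure.
#
#     import seaborn as sns
#     mpg = sns.load_dataset("mpg").dropna()
#     sns.residplot(data=mpg, x="horsepower", y="mpg", lowess=True)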