utils.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944
  1. """Utility functions, mostly for internal use."""
  2. import os
  3. import inspect
  4. import warnings
  5. import colorsys
  6. from contextlib import contextmanager
  7. from urllib.request import urlopen, urlretrieve
  8. from types import ModuleType
  9. import numpy as np
  10. import pandas as pd
  11. import matplotlib as mpl
  12. from matplotlib.colors import to_rgb
  13. import matplotlib.pyplot as plt
  14. from matplotlib.cbook import normalize_kwargs
  15. from seaborn._core.typing import deprecated
  16. from seaborn.external.version import Version
  17. from seaborn.external.appdirs import user_cache_dir
  18. __all__ = ["desaturate", "saturate", "set_hls_values", "move_legend",
  19. "despine", "get_dataset_names", "get_data_home", "load_dataset"]
  20. DATASET_SOURCE = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master"
  21. DATASET_NAMES_URL = f"{DATASET_SOURCE}/dataset_names.txt"
  22. def ci_to_errsize(cis, heights):
  23. """Convert intervals to error arguments relative to plot heights.
  24. Parameters
  25. ----------
  26. cis : 2 x n sequence
  27. sequence of confidence interval limits
  28. heights : n sequence
  29. sequence of plot heights
  30. Returns
  31. -------
  32. errsize : 2 x n array
  33. sequence of error size relative to height values in correct
  34. format as argument for plt.bar
  35. """
  36. cis = np.atleast_2d(cis).reshape(2, -1)
  37. heights = np.atleast_1d(heights)
  38. errsize = []
  39. for i, (low, high) in enumerate(np.transpose(cis)):
  40. h = heights[i]
  41. elow = h - low
  42. ehigh = high - h
  43. errsize.append([elow, ehigh])
  44. errsize = np.asarray(errsize).T
  45. return errsize
  46. def _normal_quantile_func(q):
  47. """
  48. Compute the quantile function of the standard normal distribution.
  49. This wrapper exists because we are dropping scipy as a mandatory dependency
  50. but statistics.NormalDist was added to the standard library in 3.8.
  51. """
  52. try:
  53. from statistics import NormalDist
  54. qf = np.vectorize(NormalDist().inv_cdf)
  55. except ImportError:
  56. try:
  57. from scipy.stats import norm
  58. qf = norm.ppf
  59. except ImportError:
  60. msg = (
  61. "Standard normal quantile functions require either Python>=3.8 or scipy"
  62. )
  63. raise RuntimeError(msg)
  64. return qf(q)
  65. def _draw_figure(fig):
  66. """Force draw of a matplotlib figure, accounting for back-compat."""
  67. # See https://github.com/matplotlib/matplotlib/issues/19197 for context
  68. fig.canvas.draw()
  69. if fig.stale:
  70. try:
  71. fig.draw(fig.canvas.get_renderer())
  72. except AttributeError:
  73. pass
  74. def _default_color(method, hue, color, kws, saturation=1):
  75. """If needed, get a default color by using the matplotlib property cycle."""
  76. if hue is not None:
  77. # This warning is probably user-friendly, but it's currently triggered
  78. # in a FacetGrid context and I don't want to mess with that logic right now
  79. # if color is not None:
  80. # msg = "`color` is ignored when `hue` is assigned."
  81. # warnings.warn(msg)
  82. return None
  83. kws = kws.copy()
  84. kws.pop("label", None)
  85. if color is not None:
  86. if saturation < 1:
  87. color = desaturate(color, saturation)
  88. return color
  89. elif method.__name__ == "plot":
  90. color = _normalize_kwargs(kws, mpl.lines.Line2D).get("color")
  91. scout, = method([], [], scalex=False, scaley=False, color=color)
  92. color = scout.get_color()
  93. scout.remove()
  94. elif method.__name__ == "scatter":
  95. # Matplotlib will raise if the size of x/y don't match s/c,
  96. # and the latter might be in the kws dict
  97. scout_size = max(
  98. np.atleast_1d(kws.get(key, [])).shape[0]
  99. for key in ["s", "c", "fc", "facecolor", "facecolors"]
  100. )
  101. scout_x = scout_y = np.full(scout_size, np.nan)
  102. scout = method(scout_x, scout_y, **kws)
  103. facecolors = scout.get_facecolors()
  104. if not len(facecolors):
  105. # Handle bug in matplotlib <= 3.2 (I think)
  106. # This will limit the ability to use non color= kwargs to specify
  107. # a color in versions of matplotlib with the bug, but trying to
  108. # work out what the user wanted by re-implementing the broken logic
  109. # of inspecting the kwargs is probably too brittle.
  110. single_color = False
  111. else:
  112. single_color = np.unique(facecolors, axis=0).shape[0] == 1
  113. # Allow the user to specify an array of colors through various kwargs
  114. if "c" not in kws and single_color:
  115. color = to_rgb(facecolors[0])
  116. scout.remove()
  117. elif method.__name__ == "bar":
  118. # bar() needs masked, not empty data, to generate a patch
  119. scout, = method([np.nan], [np.nan], **kws)
  120. color = to_rgb(scout.get_facecolor())
  121. scout.remove()
  122. # Axes.bar adds both a patch and a container
  123. method.__self__.containers.pop(-1)
  124. elif method.__name__ == "fill_between":
  125. kws = _normalize_kwargs(kws, mpl.collections.PolyCollection)
  126. scout = method([], [], **kws)
  127. facecolor = scout.get_facecolor()
  128. color = to_rgb(facecolor[0])
  129. scout.remove()
  130. if saturation < 1:
  131. color = desaturate(color, saturation)
  132. return color
  133. def desaturate(color, prop):
  134. """Decrease the saturation channel of a color by some percent.
  135. Parameters
  136. ----------
  137. color : matplotlib color
  138. hex, rgb-tuple, or html color name
  139. prop : float
  140. saturation channel of color will be multiplied by this value
  141. Returns
  142. -------
  143. new_color : rgb tuple
  144. desaturated color code in RGB tuple representation
  145. """
  146. # Check inputs
  147. if not 0 <= prop <= 1:
  148. raise ValueError("prop must be between 0 and 1")
  149. # Get rgb tuple rep
  150. rgb = to_rgb(color)
  151. # Short circuit to avoid floating point issues
  152. if prop == 1:
  153. return rgb
  154. # Convert to hls
  155. h, l, s = colorsys.rgb_to_hls(*rgb)
  156. # Desaturate the saturation channel
  157. s *= prop
  158. # Convert back to rgb
  159. new_color = colorsys.hls_to_rgb(h, l, s)
  160. return new_color
  161. def saturate(color):
  162. """Return a fully saturated color with the same hue.
  163. Parameters
  164. ----------
  165. color : matplotlib color
  166. hex, rgb-tuple, or html color name
  167. Returns
  168. -------
  169. new_color : rgb tuple
  170. saturated color code in RGB tuple representation
  171. """
  172. return set_hls_values(color, s=1)
  173. def set_hls_values(color, h=None, l=None, s=None): # noqa
  174. """Independently manipulate the h, l, or s channels of a color.
  175. Parameters
  176. ----------
  177. color : matplotlib color
  178. hex, rgb-tuple, or html color name
  179. h, l, s : floats between 0 and 1, or None
  180. new values for each channel in hls space
  181. Returns
  182. -------
  183. new_color : rgb tuple
  184. new color code in RGB tuple representation
  185. """
  186. # Get an RGB tuple representation
  187. rgb = to_rgb(color)
  188. vals = list(colorsys.rgb_to_hls(*rgb))
  189. for i, val in enumerate([h, l, s]):
  190. if val is not None:
  191. vals[i] = val
  192. rgb = colorsys.hls_to_rgb(*vals)
  193. return rgb
  194. def axlabel(xlabel, ylabel, **kwargs):
  195. """Grab current axis and label it.
  196. DEPRECATED: will be removed in a future version.
  197. """
  198. msg = "This function is deprecated and will be removed in a future version"
  199. warnings.warn(msg, FutureWarning)
  200. ax = plt.gca()
  201. ax.set_xlabel(xlabel, **kwargs)
  202. ax.set_ylabel(ylabel, **kwargs)
  203. def remove_na(vector):
  204. """Helper method for removing null values from data vectors.
  205. Parameters
  206. ----------
  207. vector : vector object
  208. Must implement boolean masking with [] subscript syntax.
  209. Returns
  210. -------
  211. clean_clean : same type as ``vector``
  212. Vector of data with null values removed. May be a copy or a view.
  213. """
  214. return vector[pd.notnull(vector)]
  215. def get_color_cycle():
  216. """Return the list of colors in the current matplotlib color cycle
  217. Parameters
  218. ----------
  219. None
  220. Returns
  221. -------
  222. colors : list
  223. List of matplotlib colors in the current cycle, or dark gray if
  224. the current color cycle is empty.
  225. """
  226. cycler = mpl.rcParams['axes.prop_cycle']
  227. return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"]
  228. def despine(fig=None, ax=None, top=True, right=True, left=False,
  229. bottom=False, offset=None, trim=False):
  230. """Remove the top and right spines from plot(s).
  231. fig : matplotlib figure, optional
  232. Figure to despine all axes of, defaults to the current figure.
  233. ax : matplotlib axes, optional
  234. Specific axes object to despine. Ignored if fig is provided.
  235. top, right, left, bottom : boolean, optional
  236. If True, remove that spine.
  237. offset : int or dict, optional
  238. Absolute distance, in points, spines should be moved away
  239. from the axes (negative values move spines inward). A single value
  240. applies to all spines; a dict can be used to set offset values per
  241. side.
  242. trim : bool, optional
  243. If True, limit spines to the smallest and largest major tick
  244. on each non-despined axis.
  245. Returns
  246. -------
  247. None
  248. """
  249. # Get references to the axes we want
  250. if fig is None and ax is None:
  251. axes = plt.gcf().axes
  252. elif fig is not None:
  253. axes = fig.axes
  254. elif ax is not None:
  255. axes = [ax]
  256. for ax_i in axes:
  257. for side in ["top", "right", "left", "bottom"]:
  258. # Toggle the spine objects
  259. is_visible = not locals()[side]
  260. ax_i.spines[side].set_visible(is_visible)
  261. if offset is not None and is_visible:
  262. try:
  263. val = offset.get(side, 0)
  264. except AttributeError:
  265. val = offset
  266. ax_i.spines[side].set_position(('outward', val))
  267. # Potentially move the ticks
  268. if left and not right:
  269. maj_on = any(
  270. t.tick1line.get_visible()
  271. for t in ax_i.yaxis.majorTicks
  272. )
  273. min_on = any(
  274. t.tick1line.get_visible()
  275. for t in ax_i.yaxis.minorTicks
  276. )
  277. ax_i.yaxis.set_ticks_position("right")
  278. for t in ax_i.yaxis.majorTicks:
  279. t.tick2line.set_visible(maj_on)
  280. for t in ax_i.yaxis.minorTicks:
  281. t.tick2line.set_visible(min_on)
  282. if bottom and not top:
  283. maj_on = any(
  284. t.tick1line.get_visible()
  285. for t in ax_i.xaxis.majorTicks
  286. )
  287. min_on = any(
  288. t.tick1line.get_visible()
  289. for t in ax_i.xaxis.minorTicks
  290. )
  291. ax_i.xaxis.set_ticks_position("top")
  292. for t in ax_i.xaxis.majorTicks:
  293. t.tick2line.set_visible(maj_on)
  294. for t in ax_i.xaxis.minorTicks:
  295. t.tick2line.set_visible(min_on)
  296. if trim:
  297. # clip off the parts of the spines that extend past major ticks
  298. xticks = np.asarray(ax_i.get_xticks())
  299. if xticks.size:
  300. firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
  301. xticks)[0]
  302. lasttick = np.compress(xticks <= max(ax_i.get_xlim()),
  303. xticks)[-1]
  304. ax_i.spines['bottom'].set_bounds(firsttick, lasttick)
  305. ax_i.spines['top'].set_bounds(firsttick, lasttick)
  306. newticks = xticks.compress(xticks <= lasttick)
  307. newticks = newticks.compress(newticks >= firsttick)
  308. ax_i.set_xticks(newticks)
  309. yticks = np.asarray(ax_i.get_yticks())
  310. if yticks.size:
  311. firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
  312. yticks)[0]
  313. lasttick = np.compress(yticks <= max(ax_i.get_ylim()),
  314. yticks)[-1]
  315. ax_i.spines['left'].set_bounds(firsttick, lasttick)
  316. ax_i.spines['right'].set_bounds(firsttick, lasttick)
  317. newticks = yticks.compress(yticks <= lasttick)
  318. newticks = newticks.compress(newticks >= firsttick)
  319. ax_i.set_yticks(newticks)
  320. def move_legend(obj, loc, **kwargs):
  321. """
  322. Recreate a plot's legend at a new location.
  323. The name is a slight misnomer. Matplotlib legends do not expose public
  324. control over their position parameters. So this function creates a new legend,
  325. copying over the data from the original object, which is then removed.
  326. Parameters
  327. ----------
  328. obj : the object with the plot
  329. This argument can be either a seaborn or matplotlib object:
  330. - :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
  331. - :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`
  332. loc : str or int
  333. Location argument, as in :meth:`matplotlib.axes.Axes.legend`.
  334. kwargs
  335. Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.
  336. Examples
  337. --------
  338. .. include:: ../docstrings/move_legend.rst
  339. """
  340. # This is a somewhat hackish solution that will hopefully be obviated by
  341. # upstream improvements to matplotlib legends that make them easier to
  342. # modify after creation.
  343. from seaborn.axisgrid import Grid # Avoid circular import
  344. # Locate the legend object and a method to recreate the legend
  345. if isinstance(obj, Grid):
  346. old_legend = obj.legend
  347. legend_func = obj.figure.legend
  348. elif isinstance(obj, mpl.axes.Axes):
  349. old_legend = obj.legend_
  350. legend_func = obj.legend
  351. elif isinstance(obj, mpl.figure.Figure):
  352. if obj.legends:
  353. old_legend = obj.legends[-1]
  354. else:
  355. old_legend = None
  356. legend_func = obj.legend
  357. else:
  358. err = "`obj` must be a seaborn Grid or matplotlib Axes or Figure instance."
  359. raise TypeError(err)
  360. if old_legend is None:
  361. err = f"{obj} has no legend attached."
  362. raise ValueError(err)
  363. # Extract the components of the legend we need to reuse
  364. # Import here to avoid a circular import
  365. from seaborn._compat import get_legend_handles
  366. handles = get_legend_handles(old_legend)
  367. labels = [t.get_text() for t in old_legend.get_texts()]
  368. # Handle the case where the user is trying to override the labels
  369. if (new_labels := kwargs.pop("labels", None)) is not None:
  370. if len(new_labels) != len(labels):
  371. err = "Length of new labels does not match existing legend."
  372. raise ValueError(err)
  373. labels = new_labels
  374. # Extract legend properties that can be passed to the recreation method
  375. # (Vexingly, these don't all round-trip)
  376. legend_kws = inspect.signature(mpl.legend.Legend).parameters
  377. props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}
  378. # Delegate default bbox_to_anchor rules to matplotlib
  379. props.pop("bbox_to_anchor")
  380. # Try to propagate the existing title and font properties; respect new ones too
  381. title = props.pop("title")
  382. if "title" in kwargs:
  383. title.set_text(kwargs.pop("title"))
  384. title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
  385. for key, val in title_kwargs.items():
  386. title.set(**{key[6:]: val})
  387. kwargs.pop(key)
  388. # Try to respect the frame visibility
  389. kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())
  390. # Remove the old legend and create the new one
  391. props.update(kwargs)
  392. old_legend.remove()
  393. new_legend = legend_func(handles, labels, loc=loc, **props)
  394. new_legend.set_title(title.get_text(), title.get_fontproperties())
  395. # Let the Grid object continue to track the correct legend object
  396. if isinstance(obj, Grid):
  397. obj._legend = new_legend
  398. def _kde_support(data, bw, gridsize, cut, clip):
  399. """Establish support for a kernel density estimate."""
  400. support_min = max(data.min() - bw * cut, clip[0])
  401. support_max = min(data.max() + bw * cut, clip[1])
  402. support = np.linspace(support_min, support_max, gridsize)
  403. return support
  404. def ci(a, which=95, axis=None):
  405. """Return a percentile range from an array of values."""
  406. p = 50 - which / 2, 50 + which / 2
  407. return np.nanpercentile(a, p, axis)
  408. def get_dataset_names():
  409. """Report available example datasets, useful for reporting issues.
  410. Requires an internet connection.
  411. """
  412. with urlopen(DATASET_NAMES_URL) as resp:
  413. txt = resp.read()
  414. dataset_names = [name.strip() for name in txt.decode().split("\n")]
  415. return list(filter(None, dataset_names))
  416. def get_data_home(data_home=None):
  417. """Return a path to the cache directory for example datasets.
  418. This directory is used by :func:`load_dataset`.
  419. If the ``data_home`` argument is not provided, it will use a directory
  420. specified by the `SEABORN_DATA` environment variable (if it exists)
  421. or otherwise default to an OS-appropriate user cache location.
  422. """
  423. if data_home is None:
  424. data_home = os.environ.get("SEABORN_DATA", user_cache_dir("seaborn"))
  425. data_home = os.path.expanduser(data_home)
  426. if not os.path.exists(data_home):
  427. os.makedirs(data_home)
  428. return data_home
  429. def load_dataset(name, cache=True, data_home=None, **kws):
  430. """Load an example dataset from the online repository (requires internet).
  431. This function provides quick access to a small number of example datasets
  432. that are useful for documenting seaborn or generating reproducible examples
  433. for bug reports. It is not necessary for normal usage.
  434. Note that some of the datasets have a small amount of preprocessing applied
  435. to define a proper ordering for categorical variables.
  436. Use :func:`get_dataset_names` to see a list of available datasets.
  437. Parameters
  438. ----------
  439. name : str
  440. Name of the dataset (``{name}.csv`` on
  441. https://github.com/mwaskom/seaborn-data).
  442. cache : boolean, optional
  443. If True, try to load from the local cache first, and save to the cache
  444. if a download is required.
  445. data_home : string, optional
  446. The directory in which to cache data; see :func:`get_data_home`.
  447. kws : keys and values, optional
  448. Additional keyword arguments are passed to passed through to
  449. :func:`pandas.read_csv`.
  450. Returns
  451. -------
  452. df : :class:`pandas.DataFrame`
  453. Tabular data, possibly with some preprocessing applied.
  454. """
  455. # A common beginner mistake is to assume that one's personal data needs
  456. # to be passed through this function to be usable with seaborn.
  457. # Let's provide a more helpful error than you would otherwise get.
  458. if isinstance(name, pd.DataFrame):
  459. err = (
  460. "This function accepts only strings (the name of an example dataset). "
  461. "You passed a pandas DataFrame. If you have your own dataset, "
  462. "it is not necessary to use this function before plotting."
  463. )
  464. raise TypeError(err)
  465. url = f"{DATASET_SOURCE}/{name}.csv"
  466. if cache:
  467. cache_path = os.path.join(get_data_home(data_home), os.path.basename(url))
  468. if not os.path.exists(cache_path):
  469. if name not in get_dataset_names():
  470. raise ValueError(f"'{name}' is not one of the example datasets.")
  471. urlretrieve(url, cache_path)
  472. full_path = cache_path
  473. else:
  474. full_path = url
  475. df = pd.read_csv(full_path, **kws)
  476. if df.iloc[-1].isnull().all():
  477. df = df.iloc[:-1]
  478. # Set some columns as a categorical type with ordered levels
  479. if name == "tips":
  480. df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
  481. df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
  482. df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
  483. df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])
  484. elif name == "flights":
  485. months = df["month"].str[:3]
  486. df["month"] = pd.Categorical(months, months.unique())
  487. elif name == "exercise":
  488. df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
  489. df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
  490. df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])
  491. elif name == "titanic":
  492. df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
  493. df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))
  494. elif name == "penguins":
  495. df["sex"] = df["sex"].str.title()
  496. elif name == "diamonds":
  497. df["color"] = pd.Categorical(
  498. df["color"], ["D", "E", "F", "G", "H", "I", "J"],
  499. )
  500. df["clarity"] = pd.Categorical(
  501. df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
  502. )
  503. df["cut"] = pd.Categorical(
  504. df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
  505. )
  506. elif name == "taxis":
  507. df["pickup"] = pd.to_datetime(df["pickup"])
  508. df["dropoff"] = pd.to_datetime(df["dropoff"])
  509. elif name == "seaice":
  510. df["Date"] = pd.to_datetime(df["Date"])
  511. elif name == "dowjones":
  512. df["Date"] = pd.to_datetime(df["Date"])
  513. return df
  514. def axis_ticklabels_overlap(labels):
  515. """Return a boolean for whether the list of ticklabels have overlaps.
  516. Parameters
  517. ----------
  518. labels : list of matplotlib ticklabels
  519. Returns
  520. -------
  521. overlap : boolean
  522. True if any of the labels overlap.
  523. """
  524. if not labels:
  525. return False
  526. try:
  527. bboxes = [l.get_window_extent() for l in labels]
  528. overlaps = [b.count_overlaps(bboxes) for b in bboxes]
  529. return max(overlaps) > 1
  530. except RuntimeError:
  531. # Issue on macos backend raises an error in the above code
  532. return False
  533. def axes_ticklabels_overlap(ax):
  534. """Return booleans for whether the x and y ticklabels on an Axes overlap.
  535. Parameters
  536. ----------
  537. ax : matplotlib Axes
  538. Returns
  539. -------
  540. x_overlap, y_overlap : booleans
  541. True when the labels on that axis overlap.
  542. """
  543. return (axis_ticklabels_overlap(ax.get_xticklabels()),
  544. axis_ticklabels_overlap(ax.get_yticklabels()))
  545. def locator_to_legend_entries(locator, limits, dtype):
  546. """Return levels and formatted levels for brief numeric legends."""
  547. raw_levels = locator.tick_values(*limits).astype(dtype)
  548. # The locator can return ticks outside the limits, clip them here
  549. raw_levels = [l for l in raw_levels if l >= limits[0] and l <= limits[1]]
  550. class dummy_axis:
  551. def get_view_interval(self):
  552. return limits
  553. if isinstance(locator, mpl.ticker.LogLocator):
  554. formatter = mpl.ticker.LogFormatter()
  555. else:
  556. formatter = mpl.ticker.ScalarFormatter()
  557. # Avoid having an offset/scientific notation which we don't currently
  558. # have any way of representing in the legend
  559. formatter.set_useOffset(False)
  560. formatter.set_scientific(False)
  561. formatter.axis = dummy_axis()
  562. # TODO: The following two lines should be replaced
  563. # once pinned matplotlib>=3.1.0 with:
  564. # formatted_levels = formatter.format_ticks(raw_levels)
  565. formatter.set_locs(raw_levels)
  566. formatted_levels = [formatter(x) for x in raw_levels]
  567. return raw_levels, formatted_levels
  568. def relative_luminance(color):
  569. """Calculate the relative luminance of a color according to W3C standards
  570. Parameters
  571. ----------
  572. color : matplotlib color or sequence of matplotlib colors
  573. Hex code, rgb-tuple, or html color name.
  574. Returns
  575. -------
  576. luminance : float(s) between 0 and 1
  577. """
  578. rgb = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
  579. rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
  580. lum = rgb.dot([.2126, .7152, .0722])
  581. try:
  582. return lum.item()
  583. except ValueError:
  584. return lum
  585. def to_utf8(obj):
  586. """Return a string representing a Python object.
  587. Strings (i.e. type ``str``) are returned unchanged.
  588. Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.
  589. For other objects, the method ``__str__()`` is called, and the result is
  590. returned as a string.
  591. Parameters
  592. ----------
  593. obj : object
  594. Any Python object
  595. Returns
  596. -------
  597. s : str
  598. UTF-8-decoded string representation of ``obj``
  599. """
  600. if isinstance(obj, str):
  601. return obj
  602. try:
  603. return obj.decode(encoding="utf-8")
  604. except AttributeError: # obj is not bytes-like
  605. return str(obj)
  606. def _normalize_kwargs(kws, artist):
  607. """Wrapper for mpl.cbook.normalize_kwargs that supports <= 3.2.1."""
  608. _alias_map = {
  609. 'color': ['c'],
  610. 'linewidth': ['lw'],
  611. 'linestyle': ['ls'],
  612. 'facecolor': ['fc'],
  613. 'edgecolor': ['ec'],
  614. 'markerfacecolor': ['mfc'],
  615. 'markeredgecolor': ['mec'],
  616. 'markeredgewidth': ['mew'],
  617. 'markersize': ['ms']
  618. }
  619. try:
  620. kws = normalize_kwargs(kws, artist)
  621. except AttributeError:
  622. kws = normalize_kwargs(kws, _alias_map)
  623. return kws
  624. def _check_argument(param, options, value, prefix=False):
  625. """Raise if value for param is not in options."""
  626. if prefix and value is not None:
  627. failure = not any(value.startswith(p) for p in options if isinstance(p, str))
  628. else:
  629. failure = value not in options
  630. if failure:
  631. raise ValueError(
  632. f"The value for `{param}` must be one of {options}, "
  633. f"but {repr(value)} was passed."
  634. )
  635. return value
  636. def _assign_default_kwargs(kws, call_func, source_func):
  637. """Assign default kwargs for call_func using values from source_func."""
  638. # This exists so that axes-level functions and figure-level functions can
  639. # both call a Plotter method while having the default kwargs be defined in
  640. # the signature of the axes-level function.
  641. # An alternative would be to have a decorator on the method that sets its
  642. # defaults based on those defined in the axes-level function.
  643. # Then the figure-level function would not need to worry about defaults.
  644. # I am not sure which is better.
  645. needed = inspect.signature(call_func).parameters
  646. defaults = inspect.signature(source_func).parameters
  647. for param in needed:
  648. if param in defaults and param not in kws:
  649. kws[param] = defaults[param].default
  650. return kws
  651. def adjust_legend_subtitles(legend):
  652. """
  653. Make invisible-handle "subtitles" entries look more like titles.
  654. Note: This function is not part of the public API and may be changed or removed.
  655. """
  656. # Legend title not in rcParams until 3.0
  657. font_size = plt.rcParams.get("legend.title_fontsize", None)
  658. hpackers = legend.findobj(mpl.offsetbox.VPacker)[0].get_children()
  659. for hpack in hpackers:
  660. draw_area, text_area = hpack.get_children()
  661. handles = draw_area.get_children()
  662. if not all(artist.get_visible() for artist in handles):
  663. draw_area.set_width(0)
  664. for text in text_area.get_children():
  665. if font_size is not None:
  666. text.set_size(font_size)
  667. def _deprecate_ci(errorbar, ci):
  668. """
  669. Warn on usage of ci= and convert to appropriate errorbar= arg.
  670. ci was deprecated when errorbar was added in 0.12. It should not be removed
  671. completely for some time, but it can be moved out of function definitions
  672. (and extracted from kwargs) after one cycle.
  673. """
  674. if ci is not deprecated and ci != "deprecated":
  675. if ci is None:
  676. errorbar = None
  677. elif ci == "sd":
  678. errorbar = "sd"
  679. else:
  680. errorbar = ("ci", ci)
  681. msg = (
  682. "\n\nThe `ci` parameter is deprecated. "
  683. f"Use `errorbar={repr(errorbar)}` for the same effect.\n"
  684. )
  685. warnings.warn(msg, FutureWarning, stacklevel=3)
  686. return errorbar
  687. def _get_transform_functions(ax, axis):
  688. """Return the forward and inverse transforms for a given axis."""
  689. axis_obj = getattr(ax, f"{axis}axis")
  690. transform = axis_obj.get_transform()
  691. return transform.transform, transform.inverted().transform
  692. @contextmanager
  693. def _disable_autolayout():
  694. """Context manager for preventing rc-controlled auto-layout behavior."""
  695. # This is a workaround for an issue in matplotlib, for details see
  696. # https://github.com/mwaskom/seaborn/issues/2914
  697. # The only affect of this rcParam is to set the default value for
  698. # layout= in plt.figure, so we could just do that instead.
  699. # But then we would need to own the complexity of the transition
  700. # from tight_layout=True -> layout="tight". This seems easier,
  701. # but can be removed when (if) that is simpler on the matplotlib side,
  702. # or if the layout algorithms are improved to handle figure legends.
  703. orig_val = mpl.rcParams["figure.autolayout"]
  704. try:
  705. mpl.rcParams["figure.autolayout"] = False
  706. yield
  707. finally:
  708. mpl.rcParams["figure.autolayout"] = orig_val
  709. def _version_predates(lib: ModuleType, version: str) -> bool:
  710. """Helper function for checking version compatibility."""
  711. return Version(lib.__version__) < Version(version)
  712. def _scatter_legend_artist(**kws):
  713. kws = _normalize_kwargs(kws, mpl.collections.PathCollection)
  714. edgecolor = kws.pop("edgecolor", None)
  715. rc = mpl.rcParams
  716. line_kws = {
  717. "linestyle": "",
  718. "marker": kws.pop("marker", "o"),
  719. "markersize": np.sqrt(kws.pop("s", rc["lines.markersize"] ** 2)),
  720. "markerfacecolor": kws.pop("facecolor", kws.get("color")),
  721. "markeredgewidth": kws.pop("linewidth", 0),
  722. **kws,
  723. }
  724. if edgecolor is not None:
  725. if edgecolor == "face":
  726. line_kws["markeredgecolor"] = line_kws["markerfacecolor"]
  727. else:
  728. line_kws["markeredgecolor"] = edgecolor
  729. return mpl.lines.Line2D([], [], **line_kws)
  730. def _get_patch_legend_artist(fill):
  731. def legend_artist(**kws):
  732. color = kws.pop("color", None)
  733. if color is not None:
  734. if fill:
  735. kws["facecolor"] = color
  736. else:
  737. kws["edgecolor"] = color
  738. kws["facecolor"] = "none"
  739. return mpl.patches.Rectangle((0, 0), 0, 0, **kws)
  740. return legend_artist