common.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. """
  2. Module consolidating common testing functions for checking plotting.
  3. """
  4. from __future__ import annotations
  5. from typing import (
  6. TYPE_CHECKING,
  7. Sequence,
  8. )
  9. import numpy as np
  10. from pandas.util._decorators import cache_readonly
  11. import pandas.util._test_decorators as td
  12. from pandas.core.dtypes.api import is_list_like
  13. import pandas as pd
  14. from pandas import Series
  15. import pandas._testing as tm
  16. if TYPE_CHECKING:
  17. from matplotlib.axes import Axes
  18. @td.skip_if_no_mpl
  19. class TestPlotBase:
  20. """
  21. This is a common base class used for various plotting tests
  22. """
  23. def setup_method(self):
  24. import matplotlib as mpl
  25. mpl.rcdefaults()
  26. def teardown_method(self):
  27. tm.close()
  28. @cache_readonly
  29. def plt(self):
  30. import matplotlib.pyplot as plt
  31. return plt
  32. @cache_readonly
  33. def colorconverter(self):
  34. from matplotlib import colors
  35. return colors.colorConverter
  36. def _check_legend_labels(self, axes, labels=None, visible=True):
  37. """
  38. Check each axes has expected legend labels
  39. Parameters
  40. ----------
  41. axes : matplotlib Axes object, or its list-like
  42. labels : list-like
  43. expected legend labels
  44. visible : bool
  45. expected legend visibility. labels are checked only when visible is
  46. True
  47. """
  48. if visible and (labels is None):
  49. raise ValueError("labels must be specified when visible is True")
  50. axes = self._flatten_visible(axes)
  51. for ax in axes:
  52. if visible:
  53. assert ax.get_legend() is not None
  54. self._check_text_labels(ax.get_legend().get_texts(), labels)
  55. else:
  56. assert ax.get_legend() is None
  57. def _check_legend_marker(self, ax, expected_markers=None, visible=True):
  58. """
  59. Check ax has expected legend markers
  60. Parameters
  61. ----------
  62. ax : matplotlib Axes object
  63. expected_markers : list-like
  64. expected legend markers
  65. visible : bool
  66. expected legend visibility. labels are checked only when visible is
  67. True
  68. """
  69. if visible and (expected_markers is None):
  70. raise ValueError("Markers must be specified when visible is True")
  71. if visible:
  72. handles, _ = ax.get_legend_handles_labels()
  73. markers = [handle.get_marker() for handle in handles]
  74. assert markers == expected_markers
  75. else:
  76. assert ax.get_legend() is None
  77. def _check_data(self, xp, rs):
  78. """
  79. Check each axes has identical lines
  80. Parameters
  81. ----------
  82. xp : matplotlib Axes object
  83. rs : matplotlib Axes object
  84. """
  85. xp_lines = xp.get_lines()
  86. rs_lines = rs.get_lines()
  87. assert len(xp_lines) == len(rs_lines)
  88. for xpl, rsl in zip(xp_lines, rs_lines):
  89. xpdata = xpl.get_xydata()
  90. rsdata = rsl.get_xydata()
  91. tm.assert_almost_equal(xpdata, rsdata)
  92. tm.close()
  93. def _check_visible(self, collections, visible=True):
  94. """
  95. Check each artist is visible or not
  96. Parameters
  97. ----------
  98. collections : matplotlib Artist or its list-like
  99. target Artist or its list or collection
  100. visible : bool
  101. expected visibility
  102. """
  103. from matplotlib.collections import Collection
  104. if not isinstance(collections, Collection) and not is_list_like(collections):
  105. collections = [collections]
  106. for patch in collections:
  107. assert patch.get_visible() == visible
  108. def _check_patches_all_filled(
  109. self, axes: Axes | Sequence[Axes], filled: bool = True
  110. ) -> None:
  111. """
  112. Check for each artist whether it is filled or not
  113. Parameters
  114. ----------
  115. axes : matplotlib Axes object, or its list-like
  116. filled : bool
  117. expected filling
  118. """
  119. axes = self._flatten_visible(axes)
  120. for ax in axes:
  121. for patch in ax.patches:
  122. assert patch.fill == filled
  123. def _get_colors_mapped(self, series, colors):
  124. unique = series.unique()
  125. # unique and colors length can be differed
  126. # depending on slice value
  127. mapped = dict(zip(unique, colors))
  128. return [mapped[v] for v in series.values]
  129. def _check_colors(
  130. self, collections, linecolors=None, facecolors=None, mapping=None
  131. ):
  132. """
  133. Check each artist has expected line colors and face colors
  134. Parameters
  135. ----------
  136. collections : list-like
  137. list or collection of target artist
  138. linecolors : list-like which has the same length as collections
  139. list of expected line colors
  140. facecolors : list-like which has the same length as collections
  141. list of expected face colors
  142. mapping : Series
  143. Series used for color grouping key
  144. used for andrew_curves, parallel_coordinates, radviz test
  145. """
  146. from matplotlib.collections import (
  147. Collection,
  148. LineCollection,
  149. PolyCollection,
  150. )
  151. from matplotlib.lines import Line2D
  152. conv = self.colorconverter
  153. if linecolors is not None:
  154. if mapping is not None:
  155. linecolors = self._get_colors_mapped(mapping, linecolors)
  156. linecolors = linecolors[: len(collections)]
  157. assert len(collections) == len(linecolors)
  158. for patch, color in zip(collections, linecolors):
  159. if isinstance(patch, Line2D):
  160. result = patch.get_color()
  161. # Line2D may contains string color expression
  162. result = conv.to_rgba(result)
  163. elif isinstance(patch, (PolyCollection, LineCollection)):
  164. result = tuple(patch.get_edgecolor()[0])
  165. else:
  166. result = patch.get_edgecolor()
  167. expected = conv.to_rgba(color)
  168. assert result == expected
  169. if facecolors is not None:
  170. if mapping is not None:
  171. facecolors = self._get_colors_mapped(mapping, facecolors)
  172. facecolors = facecolors[: len(collections)]
  173. assert len(collections) == len(facecolors)
  174. for patch, color in zip(collections, facecolors):
  175. if isinstance(patch, Collection):
  176. # returned as list of np.array
  177. result = patch.get_facecolor()[0]
  178. else:
  179. result = patch.get_facecolor()
  180. if isinstance(result, np.ndarray):
  181. result = tuple(result)
  182. expected = conv.to_rgba(color)
  183. assert result == expected
  184. def _check_text_labels(self, texts, expected):
  185. """
  186. Check each text has expected labels
  187. Parameters
  188. ----------
  189. texts : matplotlib Text object, or its list-like
  190. target text, or its list
  191. expected : str or list-like which has the same length as texts
  192. expected text label, or its list
  193. """
  194. if not is_list_like(texts):
  195. assert texts.get_text() == expected
  196. else:
  197. labels = [t.get_text() for t in texts]
  198. assert len(labels) == len(expected)
  199. for label, e in zip(labels, expected):
  200. assert label == e
  201. def _check_ticks_props(
  202. self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
  203. ):
  204. """
  205. Check each axes has expected tick properties
  206. Parameters
  207. ----------
  208. axes : matplotlib Axes object, or its list-like
  209. xlabelsize : number
  210. expected xticks font size
  211. xrot : number
  212. expected xticks rotation
  213. ylabelsize : number
  214. expected yticks font size
  215. yrot : number
  216. expected yticks rotation
  217. """
  218. from matplotlib.ticker import NullFormatter
  219. axes = self._flatten_visible(axes)
  220. for ax in axes:
  221. if xlabelsize is not None or xrot is not None:
  222. if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
  223. # If minor ticks has NullFormatter, rot / fontsize are not
  224. # retained
  225. labels = ax.get_xticklabels()
  226. else:
  227. labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
  228. for label in labels:
  229. if xlabelsize is not None:
  230. tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
  231. if xrot is not None:
  232. tm.assert_almost_equal(label.get_rotation(), xrot)
  233. if ylabelsize is not None or yrot is not None:
  234. if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
  235. labels = ax.get_yticklabels()
  236. else:
  237. labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
  238. for label in labels:
  239. if ylabelsize is not None:
  240. tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
  241. if yrot is not None:
  242. tm.assert_almost_equal(label.get_rotation(), yrot)
  243. def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
  244. """
  245. Check each axes has expected scales
  246. Parameters
  247. ----------
  248. axes : matplotlib Axes object, or its list-like
  249. xaxis : {'linear', 'log'}
  250. expected xaxis scale
  251. yaxis : {'linear', 'log'}
  252. expected yaxis scale
  253. """
  254. axes = self._flatten_visible(axes)
  255. for ax in axes:
  256. assert ax.xaxis.get_scale() == xaxis
  257. assert ax.yaxis.get_scale() == yaxis
  258. def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
  259. """
  260. Check expected number of axes is drawn in expected layout
  261. Parameters
  262. ----------
  263. axes : matplotlib Axes object, or its list-like
  264. axes_num : number
  265. expected number of axes. Unnecessary axes should be set to
  266. invisible.
  267. layout : tuple
  268. expected layout, (expected number of rows , columns)
  269. figsize : tuple
  270. expected figsize. default is matplotlib default
  271. """
  272. from pandas.plotting._matplotlib.tools import flatten_axes
  273. if figsize is None:
  274. figsize = (6.4, 4.8)
  275. visible_axes = self._flatten_visible(axes)
  276. if axes_num is not None:
  277. assert len(visible_axes) == axes_num
  278. for ax in visible_axes:
  279. # check something drawn on visible axes
  280. assert len(ax.get_children()) > 0
  281. if layout is not None:
  282. result = self._get_axes_layout(flatten_axes(axes))
  283. assert result == layout
  284. tm.assert_numpy_array_equal(
  285. visible_axes[0].figure.get_size_inches(),
  286. np.array(figsize, dtype=np.float64),
  287. )
  288. def _get_axes_layout(self, axes):
  289. x_set = set()
  290. y_set = set()
  291. for ax in axes:
  292. # check axes coordinates to estimate layout
  293. points = ax.get_position().get_points()
  294. x_set.add(points[0][0])
  295. y_set.add(points[0][1])
  296. return (len(y_set), len(x_set))
  297. def _flatten_visible(self, axes):
  298. """
  299. Flatten axes, and filter only visible
  300. Parameters
  301. ----------
  302. axes : matplotlib Axes object, or its list-like
  303. """
  304. from pandas.plotting._matplotlib.tools import flatten_axes
  305. axes = flatten_axes(axes)
  306. axes = [ax for ax in axes if ax.get_visible()]
  307. return axes
  308. def _check_has_errorbars(self, axes, xerr=0, yerr=0):
  309. """
  310. Check axes has expected number of errorbars
  311. Parameters
  312. ----------
  313. axes : matplotlib Axes object, or its list-like
  314. xerr : number
  315. expected number of x errorbar
  316. yerr : number
  317. expected number of y errorbar
  318. """
  319. axes = self._flatten_visible(axes)
  320. for ax in axes:
  321. containers = ax.containers
  322. xerr_count = 0
  323. yerr_count = 0
  324. for c in containers:
  325. has_xerr = getattr(c, "has_xerr", False)
  326. has_yerr = getattr(c, "has_yerr", False)
  327. if has_xerr:
  328. xerr_count += 1
  329. if has_yerr:
  330. yerr_count += 1
  331. assert xerr == xerr_count
  332. assert yerr == yerr_count
  333. def _check_box_return_type(
  334. self, returned, return_type, expected_keys=None, check_ax_title=True
  335. ):
  336. """
  337. Check box returned type is correct
  338. Parameters
  339. ----------
  340. returned : object to be tested, returned from boxplot
  341. return_type : str
  342. return_type passed to boxplot
  343. expected_keys : list-like, optional
  344. group labels in subplot case. If not passed,
  345. the function checks assuming boxplot uses single ax
  346. check_ax_title : bool
  347. Whether to check the ax.title is the same as expected_key
  348. Intended to be checked by calling from ``boxplot``.
  349. Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
  350. """
  351. from matplotlib.axes import Axes
  352. types = {"dict": dict, "axes": Axes, "both": tuple}
  353. if expected_keys is None:
  354. # should be fixed when the returning default is changed
  355. if return_type is None:
  356. return_type = "dict"
  357. assert isinstance(returned, types[return_type])
  358. if return_type == "both":
  359. assert isinstance(returned.ax, Axes)
  360. assert isinstance(returned.lines, dict)
  361. else:
  362. # should be fixed when the returning default is changed
  363. if return_type is None:
  364. for r in self._flatten_visible(returned):
  365. assert isinstance(r, Axes)
  366. return
  367. assert isinstance(returned, Series)
  368. assert sorted(returned.keys()) == sorted(expected_keys)
  369. for key, value in returned.items():
  370. assert isinstance(value, types[return_type])
  371. # check returned dict has correct mapping
  372. if return_type == "axes":
  373. if check_ax_title:
  374. assert value.get_title() == key
  375. elif return_type == "both":
  376. if check_ax_title:
  377. assert value.ax.get_title() == key
  378. assert isinstance(value.ax, Axes)
  379. assert isinstance(value.lines, dict)
  380. elif return_type == "dict":
  381. line = value["medians"][0]
  382. axes = line.axes
  383. if check_ax_title:
  384. assert axes.get_title() == key
  385. else:
  386. raise AssertionError
  387. def _check_grid_settings(self, obj, kinds, kws={}):
  388. # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
  389. import matplotlib as mpl
  390. def is_grid_on():
  391. xticks = self.plt.gca().xaxis.get_major_ticks()
  392. yticks = self.plt.gca().yaxis.get_major_ticks()
  393. xoff = all(not g.gridline.get_visible() for g in xticks)
  394. yoff = all(not g.gridline.get_visible() for g in yticks)
  395. return not (xoff and yoff)
  396. spndx = 1
  397. for kind in kinds:
  398. self.plt.subplot(1, 4 * len(kinds), spndx)
  399. spndx += 1
  400. mpl.rc("axes", grid=False)
  401. obj.plot(kind=kind, **kws)
  402. assert not is_grid_on()
  403. self.plt.clf()
  404. self.plt.subplot(1, 4 * len(kinds), spndx)
  405. spndx += 1
  406. mpl.rc("axes", grid=True)
  407. obj.plot(kind=kind, grid=False, **kws)
  408. assert not is_grid_on()
  409. self.plt.clf()
  410. if kind not in ["pie", "hexbin", "scatter"]:
  411. self.plt.subplot(1, 4 * len(kinds), spndx)
  412. spndx += 1
  413. mpl.rc("axes", grid=True)
  414. obj.plot(kind=kind, **kws)
  415. assert is_grid_on()
  416. self.plt.clf()
  417. self.plt.subplot(1, 4 * len(kinds), spndx)
  418. spndx += 1
  419. mpl.rc("axes", grid=False)
  420. obj.plot(kind=kind, grid=True, **kws)
  421. assert is_grid_on()
  422. self.plt.clf()
  423. def _unpack_cycler(self, rcParams, field="color"):
  424. """
  425. Auxiliary function for correctly unpacking cycler after MPL >= 1.5
  426. """
  427. return [v[field] for v in rcParams["axes.prop_cycle"]]
  428. def get_x_axis(self, ax):
  429. return ax._shared_axes["x"]
  430. def get_y_axis(self, ax):
  431. return ax._shared_axes["y"]
  432. def _check_plot_works(f, default_axes=False, **kwargs):
  433. """
  434. Create plot and ensure that plot return object is valid.
  435. Parameters
  436. ----------
  437. f : func
  438. Plotting function.
  439. default_axes : bool, optional
  440. If False (default):
  441. - If `ax` not in `kwargs`, then create subplot(211) and plot there
  442. - Create new subplot(212) and plot there as well
  443. - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
  444. If True:
  445. - Simply run plotting function with kwargs provided
  446. - All required axes instances will be created automatically
  447. - It is recommended to use it when the plotting function
  448. creates multiple axes itself. It helps avoid warnings like
  449. 'UserWarning: To output multiple subplots,
  450. the figure containing the passed axes is being cleared'
  451. **kwargs
  452. Keyword arguments passed to the plotting function.
  453. Returns
  454. -------
  455. Plot object returned by the last plotting.
  456. """
  457. import matplotlib.pyplot as plt
  458. if default_axes:
  459. gen_plots = _gen_default_plot
  460. else:
  461. gen_plots = _gen_two_subplots
  462. ret = None
  463. try:
  464. fig = kwargs.get("figure", plt.gcf())
  465. plt.clf()
  466. for ret in gen_plots(f, fig, **kwargs):
  467. tm.assert_is_valid_plot_return_object(ret)
  468. with tm.ensure_clean(return_filelike=True) as path:
  469. plt.savefig(path)
  470. finally:
  471. tm.close(fig)
  472. return ret
  473. def _gen_default_plot(f, fig, **kwargs):
  474. """
  475. Create plot in a default way.
  476. """
  477. yield f(**kwargs)
  478. def _gen_two_subplots(f, fig, **kwargs):
  479. """
  480. Create plot on two subplots forcefully created.
  481. """
  482. if "ax" not in kwargs:
  483. fig.add_subplot(211)
  484. yield f(**kwargs)
  485. if f is pd.plotting.bootstrap_plot:
  486. assert "ax" not in kwargs
  487. else:
  488. kwargs["ax"] = fig.add_subplot(212)
  489. yield f(**kwargs)