tools.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. # being a bit too dynamic
  2. from __future__ import annotations
  3. from math import ceil
  4. from typing import (
  5. TYPE_CHECKING,
  6. Iterable,
  7. Sequence,
  8. )
  9. import warnings
  10. from matplotlib import ticker
  11. import matplotlib.table
  12. import numpy as np
  13. from pandas.util._exceptions import find_stack_level
  14. from pandas.core.dtypes.common import is_list_like
  15. from pandas.core.dtypes.generic import (
  16. ABCDataFrame,
  17. ABCIndex,
  18. ABCSeries,
  19. )
  20. if TYPE_CHECKING:
  21. from matplotlib.axes import Axes
  22. from matplotlib.axis import Axis
  23. from matplotlib.figure import Figure
  24. from matplotlib.lines import Line2D
  25. from matplotlib.table import Table
  26. from pandas import (
  27. DataFrame,
  28. Series,
  29. )
  30. def do_adjust_figure(fig: Figure) -> bool:
  31. """Whether fig has constrained_layout enabled."""
  32. if not hasattr(fig, "get_constrained_layout"):
  33. return False
  34. return not fig.get_constrained_layout()
  35. def maybe_adjust_figure(fig: Figure, *args, **kwargs) -> None:
  36. """Call fig.subplots_adjust unless fig has constrained_layout enabled."""
  37. if do_adjust_figure(fig):
  38. fig.subplots_adjust(*args, **kwargs)
  39. def format_date_labels(ax: Axes, rot) -> None:
  40. # mini version of autofmt_xdate
  41. for label in ax.get_xticklabels():
  42. label.set_ha("right")
  43. label.set_rotation(rot)
  44. fig = ax.get_figure()
  45. maybe_adjust_figure(fig, bottom=0.2)
  46. def table(
  47. ax, data: DataFrame | Series, rowLabels=None, colLabels=None, **kwargs
  48. ) -> Table:
  49. if isinstance(data, ABCSeries):
  50. data = data.to_frame()
  51. elif isinstance(data, ABCDataFrame):
  52. pass
  53. else:
  54. raise ValueError("Input data must be DataFrame or Series")
  55. if rowLabels is None:
  56. rowLabels = data.index
  57. if colLabels is None:
  58. colLabels = data.columns
  59. cellText = data.values
  60. return matplotlib.table.table(
  61. ax, cellText=cellText, rowLabels=rowLabels, colLabels=colLabels, **kwargs
  62. )
  63. def _get_layout(
  64. nplots: int,
  65. layout: tuple[int, int] | None = None,
  66. layout_type: str = "box",
  67. ) -> tuple[int, int]:
  68. if layout is not None:
  69. if not isinstance(layout, (tuple, list)) or len(layout) != 2:
  70. raise ValueError("Layout must be a tuple of (rows, columns)")
  71. nrows, ncols = layout
  72. if nrows == -1 and ncols > 0:
  73. layout = nrows, ncols = (ceil(nplots / ncols), ncols)
  74. elif ncols == -1 and nrows > 0:
  75. layout = nrows, ncols = (nrows, ceil(nplots / nrows))
  76. elif ncols <= 0 and nrows <= 0:
  77. msg = "At least one dimension of layout must be positive"
  78. raise ValueError(msg)
  79. if nrows * ncols < nplots:
  80. raise ValueError(
  81. f"Layout of {nrows}x{ncols} must be larger than required size {nplots}"
  82. )
  83. return layout
  84. if layout_type == "single":
  85. return (1, 1)
  86. elif layout_type == "horizontal":
  87. return (1, nplots)
  88. elif layout_type == "vertical":
  89. return (nplots, 1)
  90. layouts = {1: (1, 1), 2: (1, 2), 3: (2, 2), 4: (2, 2)}
  91. try:
  92. return layouts[nplots]
  93. except KeyError:
  94. k = 1
  95. while k**2 < nplots:
  96. k += 1
  97. if (k - 1) * k >= nplots:
  98. return k, (k - 1)
  99. else:
  100. return k, k
  101. # copied from matplotlib/pyplot.py and modified for pandas.plotting
  102. def create_subplots(
  103. naxes: int,
  104. sharex: bool = False,
  105. sharey: bool = False,
  106. squeeze: bool = True,
  107. subplot_kw=None,
  108. ax=None,
  109. layout=None,
  110. layout_type: str = "box",
  111. **fig_kw,
  112. ):
  113. """
  114. Create a figure with a set of subplots already made.
  115. This utility wrapper makes it convenient to create common layouts of
  116. subplots, including the enclosing figure object, in a single call.
  117. Parameters
  118. ----------
  119. naxes : int
  120. Number of required axes. Exceeded axes are set invisible. Default is
  121. nrows * ncols.
  122. sharex : bool
  123. If True, the X axis will be shared amongst all subplots.
  124. sharey : bool
  125. If True, the Y axis will be shared amongst all subplots.
  126. squeeze : bool
  127. If True, extra dimensions are squeezed out from the returned axis object:
  128. - if only one subplot is constructed (nrows=ncols=1), the resulting
  129. single Axis object is returned as a scalar.
  130. - for Nx1 or 1xN subplots, the returned object is a 1-d numpy object
  131. array of Axis objects are returned as numpy 1-d arrays.
  132. - for NxM subplots with N>1 and M>1 are returned as a 2d array.
  133. If False, no squeezing is done: the returned axis object is always
  134. a 2-d array containing Axis instances, even if it ends up being 1x1.
  135. subplot_kw : dict
  136. Dict with keywords passed to the add_subplot() call used to create each
  137. subplots.
  138. ax : Matplotlib axis object, optional
  139. layout : tuple
  140. Number of rows and columns of the subplot grid.
  141. If not specified, calculated from naxes and layout_type
  142. layout_type : {'box', 'horizontal', 'vertical'}, default 'box'
  143. Specify how to layout the subplot grid.
  144. fig_kw : Other keyword arguments to be passed to the figure() call.
  145. Note that all keywords not recognized above will be
  146. automatically included here.
  147. Returns
  148. -------
  149. fig, ax : tuple
  150. - fig is the Matplotlib Figure object
  151. - ax can be either a single axis object or an array of axis objects if
  152. more than one subplot was created. The dimensions of the resulting array
  153. can be controlled with the squeeze keyword, see above.
  154. Examples
  155. --------
  156. x = np.linspace(0, 2*np.pi, 400)
  157. y = np.sin(x**2)
  158. # Just a figure and one subplot
  159. f, ax = plt.subplots()
  160. ax.plot(x, y)
  161. ax.set_title('Simple plot')
  162. # Two subplots, unpack the output array immediately
  163. f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)
  164. ax1.plot(x, y)
  165. ax1.set_title('Sharing Y axis')
  166. ax2.scatter(x, y)
  167. # Four polar axes
  168. plt.subplots(2, 2, subplot_kw=dict(polar=True))
  169. """
  170. import matplotlib.pyplot as plt
  171. if subplot_kw is None:
  172. subplot_kw = {}
  173. if ax is None:
  174. fig = plt.figure(**fig_kw)
  175. else:
  176. if is_list_like(ax):
  177. if squeeze:
  178. ax = flatten_axes(ax)
  179. if layout is not None:
  180. warnings.warn(
  181. "When passing multiple axes, layout keyword is ignored.",
  182. UserWarning,
  183. stacklevel=find_stack_level(),
  184. )
  185. if sharex or sharey:
  186. warnings.warn(
  187. "When passing multiple axes, sharex and sharey "
  188. "are ignored. These settings must be specified when creating axes.",
  189. UserWarning,
  190. stacklevel=find_stack_level(),
  191. )
  192. if ax.size == naxes:
  193. fig = ax.flat[0].get_figure()
  194. return fig, ax
  195. else:
  196. raise ValueError(
  197. f"The number of passed axes must be {naxes}, the "
  198. "same as the output plot"
  199. )
  200. fig = ax.get_figure()
  201. # if ax is passed and a number of subplots is 1, return ax as it is
  202. if naxes == 1:
  203. if squeeze:
  204. return fig, ax
  205. else:
  206. return fig, flatten_axes(ax)
  207. else:
  208. warnings.warn(
  209. "To output multiple subplots, the figure containing "
  210. "the passed axes is being cleared.",
  211. UserWarning,
  212. stacklevel=find_stack_level(),
  213. )
  214. fig.clear()
  215. nrows, ncols = _get_layout(naxes, layout=layout, layout_type=layout_type)
  216. nplots = nrows * ncols
  217. # Create empty object array to hold all axes. It's easiest to make it 1-d
  218. # so we can just append subplots upon creation, and then
  219. axarr = np.empty(nplots, dtype=object)
  220. # Create first subplot separately, so we can share it if requested
  221. ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
  222. if sharex:
  223. subplot_kw["sharex"] = ax0
  224. if sharey:
  225. subplot_kw["sharey"] = ax0
  226. axarr[0] = ax0
  227. # Note off-by-one counting because add_subplot uses the MATLAB 1-based
  228. # convention.
  229. for i in range(1, nplots):
  230. kwds = subplot_kw.copy()
  231. # Set sharex and sharey to None for blank/dummy axes, these can
  232. # interfere with proper axis limits on the visible axes if
  233. # they share axes e.g. issue #7528
  234. if i >= naxes:
  235. kwds["sharex"] = None
  236. kwds["sharey"] = None
  237. ax = fig.add_subplot(nrows, ncols, i + 1, **kwds)
  238. axarr[i] = ax
  239. if naxes != nplots:
  240. for ax in axarr[naxes:]:
  241. ax.set_visible(False)
  242. handle_shared_axes(axarr, nplots, naxes, nrows, ncols, sharex, sharey)
  243. if squeeze:
  244. # Reshape the array to have the final desired dimension (nrow,ncol),
  245. # though discarding unneeded dimensions that equal 1. If we only have
  246. # one subplot, just return it instead of a 1-element array.
  247. if nplots == 1:
  248. axes = axarr[0]
  249. else:
  250. axes = axarr.reshape(nrows, ncols).squeeze()
  251. else:
  252. # returned axis array will be always 2-d, even if nrows=ncols=1
  253. axes = axarr.reshape(nrows, ncols)
  254. return fig, axes
  255. def _remove_labels_from_axis(axis: Axis) -> None:
  256. for t in axis.get_majorticklabels():
  257. t.set_visible(False)
  258. # set_visible will not be effective if
  259. # minor axis has NullLocator and NullFormatter (default)
  260. if isinstance(axis.get_minor_locator(), ticker.NullLocator):
  261. axis.set_minor_locator(ticker.AutoLocator())
  262. if isinstance(axis.get_minor_formatter(), ticker.NullFormatter):
  263. axis.set_minor_formatter(ticker.FormatStrFormatter(""))
  264. for t in axis.get_minorticklabels():
  265. t.set_visible(False)
  266. axis.get_label().set_visible(False)
  267. def _has_externally_shared_axis(ax1: Axes, compare_axis: str) -> bool:
  268. """
  269. Return whether an axis is externally shared.
  270. Parameters
  271. ----------
  272. ax1 : matplotlib.axes.Axes
  273. Axis to query.
  274. compare_axis : str
  275. `"x"` or `"y"` according to whether the X-axis or Y-axis is being
  276. compared.
  277. Returns
  278. -------
  279. bool
  280. `True` if the axis is externally shared. Otherwise `False`.
  281. Notes
  282. -----
  283. If two axes with different positions are sharing an axis, they can be
  284. referred to as *externally* sharing the common axis.
  285. If two axes sharing an axis also have the same position, they can be
  286. referred to as *internally* sharing the common axis (a.k.a twinning).
  287. _handle_shared_axes() is only interested in axes externally sharing an
  288. axis, regardless of whether either of the axes is also internally sharing
  289. with a third axis.
  290. """
  291. if compare_axis == "x":
  292. axes = ax1.get_shared_x_axes()
  293. elif compare_axis == "y":
  294. axes = ax1.get_shared_y_axes()
  295. else:
  296. raise ValueError(
  297. "_has_externally_shared_axis() needs 'x' or 'y' as a second parameter"
  298. )
  299. axes = axes.get_siblings(ax1)
  300. # Retain ax1 and any of its siblings which aren't in the same position as it
  301. ax1_points = ax1.get_position().get_points()
  302. for ax2 in axes:
  303. if not np.array_equal(ax1_points, ax2.get_position().get_points()):
  304. return True
  305. return False
def handle_shared_axes(
    axarr: Iterable[Axes],
    nplots: int,
    naxes: int,
    nrows: int,
    ncols: int,
    sharex: bool,
    sharey: bool,
) -> None:
    """
    Hide redundant tick labels and axis labels on a grid of subplots.

    X labels are removed from axes that have a visible axes directly below
    them. Y labels are removed from axes that are not in the first column.
    In both cases this only happens where the axis is shared: either
    ``sharex`` / ``sharey`` was requested, or the axis is externally shared
    with another axes (see ``_has_externally_shared_axis``).

    Parameters
    ----------
    axarr : iterable of matplotlib.axes.Axes
        Axes to process; ``get_subplotspec()`` is called on each one.
        NOTE(review): it is iterated several times, so a one-shot iterator
        would not work as intended -- callers pass an ndarray.
    nplots : int
        Total number of cells in the grid; nothing is done unless it is > 1.
    naxes : int
        Number of axes actually requested. Not read by this function.
    nrows, ncols : int
        Grid dimensions. The X pass runs only when ``nrows > 1`` and the Y
        pass only when ``ncols > 1``.
    sharex, sharey : bool
        Remove labels for the respective axis even when it is not externally
        shared.
    """
    if nplots > 1:
        # Grid position helpers based on each axes' SubplotSpec.
        row_num = lambda x: x.get_subplotspec().rowspan.start
        col_num = lambda x: x.get_subplotspec().colspan.start
        is_first_col = lambda x: x.get_subplotspec().is_first_col()

        if nrows > 1:
            try:
                # First record which grid cells hold a visible axes, so that
                # gaps (hidden or missing axes) are handled correctly. The
                # extra row/column means the "cell below" lookup for the
                # bottom row lands on a cell that stays False.
                layout = np.zeros((nrows + 1, ncols + 1), dtype=np.bool_)
                for ax in axarr:
                    layout[row_num(ax), col_num(ax)] = ax.get_visible()

                for ax in axarr:
                    # Only axes with a visible axes directly below them lose
                    # their x labels. An axes that is last in its column
                    # (nothing, or a gap, below it) keeps them.
                    if not layout[row_num(ax) + 1, col_num(ax)]:
                        continue
                    if sharex or _has_externally_shared_axis(ax, "x"):
                        _remove_labels_from_axis(ax.xaxis)

            except IndexError:
                # If a gridspec is used, an axes' row/column may fall outside
                # the (nrows, ncols) shape and the lookups above fail. In
                # that case, keep x labels only on axes in the last row of
                # their gridspec. Labels already removed by the loop above,
                # before the IndexError, stay removed.
                is_last_row = lambda x: x.get_subplotspec().is_last_row()
                for ax in axarr:
                    if is_last_row(ax):
                        continue
                    if sharex or _has_externally_shared_axis(ax, "x"):
                        _remove_labels_from_axis(ax.xaxis)

        if ncols > 1:
            for ax in axarr:
                # Only the first column keeps y labels; all other columns
                # lose them. No gap/layout test is done here (the original
                # author assumes a subplot always exists in the first column).
                if is_first_col(ax):
                    continue
                if sharey or _has_externally_shared_axis(ax, "y"):
                    _remove_labels_from_axis(ax.yaxis)
  352. def flatten_axes(axes: Axes | Sequence[Axes]) -> np.ndarray:
  353. if not is_list_like(axes):
  354. return np.array([axes])
  355. elif isinstance(axes, (np.ndarray, ABCIndex)):
  356. return np.asarray(axes).ravel()
  357. return np.array(axes)
  358. def set_ticks_props(
  359. axes: Axes | Sequence[Axes],
  360. xlabelsize=None,
  361. xrot=None,
  362. ylabelsize=None,
  363. yrot=None,
  364. ):
  365. import matplotlib.pyplot as plt
  366. for ax in flatten_axes(axes):
  367. if xlabelsize is not None:
  368. plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
  369. if xrot is not None:
  370. plt.setp(ax.get_xticklabels(), rotation=xrot)
  371. if ylabelsize is not None:
  372. plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
  373. if yrot is not None:
  374. plt.setp(ax.get_yticklabels(), rotation=yrot)
  375. return axes
  376. def get_all_lines(ax: Axes) -> list[Line2D]:
  377. lines = ax.get_lines()
  378. if hasattr(ax, "right_ax"):
  379. lines += ax.right_ax.get_lines()
  380. if hasattr(ax, "left_ax"):
  381. lines += ax.left_ax.get_lines()
  382. return lines
  383. def get_xlim(lines: Iterable[Line2D]) -> tuple[float, float]:
  384. left, right = np.inf, -np.inf
  385. for line in lines:
  386. x = line.get_xdata(orig=False)
  387. left = min(np.nanmin(x), left)
  388. right = max(np.nanmax(x), right)
  389. return left, right