misc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. from __future__ import annotations
  2. import random
  3. from typing import (
  4. TYPE_CHECKING,
  5. Hashable,
  6. )
  7. from matplotlib import patches
  8. import matplotlib.lines as mlines
  9. import numpy as np
  10. from pandas.core.dtypes.missing import notna
  11. from pandas.io.formats.printing import pprint_thing
  12. from pandas.plotting._matplotlib.style import get_standard_colors
  13. from pandas.plotting._matplotlib.tools import (
  14. create_subplots,
  15. do_adjust_figure,
  16. maybe_adjust_figure,
  17. set_ticks_props,
  18. )
  19. if TYPE_CHECKING:
  20. from matplotlib.axes import Axes
  21. from matplotlib.figure import Figure
  22. from pandas import (
  23. DataFrame,
  24. Index,
  25. Series,
  26. )
def scatter_matrix(
    frame: DataFrame,
    alpha: float = 0.5,
    figsize=None,
    ax=None,
    grid: bool = False,
    diagonal: str = "hist",
    marker: str = ".",
    density_kwds=None,
    hist_kwds=None,
    range_padding: float = 0.05,
    **kwds,
):
    """
    Draw an ``n x n`` grid of pairwise plots for the numeric columns of ``frame``.

    Off-diagonal cell ``(i, j)`` scatters column ``j`` (x axis) against column
    ``i`` (y axis), using only rows where both values are non-null. Diagonal
    cells show a single column's distribution according to ``diagonal``.

    Parameters
    ----------
    frame : DataFrame
        Source data; only its numeric columns (``_get_numeric_data``) are used.
    alpha : float
        Transparency passed to every ``ax.scatter`` call.
    figsize, ax
        Forwarded to ``create_subplots``.
    grid : bool
        NOTE(review): not referenced anywhere in this function body.
    diagonal : str
        ``"hist"`` draws ``ax.hist``; ``"kde"`` or ``"density"`` draws a
        ``scipy.stats.gaussian_kde`` estimate on 1000 points. Any other value
        leaves the diagonal cells without data (no error is raised here).
    marker : str
        Scatter marker; replaced by ``"o"`` via ``_get_marker_compat`` when
        matplotlib does not recognise it.
    density_kwds, hist_kwds : dict, optional
        Extra keyword arguments for the density ``ax.plot`` / ``ax.hist`` calls.
    range_padding : float
        Fraction of each column's data range added to its axis limits, split
        evenly between the low and the high end.
    **kwds
        Forwarded to every ``ax.scatter`` call; ``edgecolors`` defaults to
        ``"none"``.

    Returns
    -------
    The ``axes`` array returned by ``create_subplots(..., squeeze=False)``,
    indexed as ``axes[i, j]``.
    """
    df = frame._get_numeric_data()
    n = df.columns.size
    naxes = n * n
    fig, axes = create_subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False)

    # no gaps between subplots
    maybe_adjust_figure(fig, wspace=0, hspace=0)

    # Boolean frame of non-null cells; used to drop missing values per column
    # and per column pair.
    mask = notna(df)

    marker = _get_marker_compat(marker)

    hist_kwds = hist_kwds or {}
    density_kwds = density_kwds or {}

    # GH 14855
    kwds.setdefault("edgecolors", "none")

    # Per-column axis limits: the non-null data range, widened on each side by
    # half of ``range_padding`` times that range.
    # NOTE(review): np.min/np.max raise on an all-null column, and a constant
    # column produces identical low/high limits.
    boundaries_list = []
    for a in df.columns:
        values = df[a].values[mask[a].values]
        rmin_, rmax_ = np.min(values), np.max(values)
        rdelta_ext = (rmax_ - rmin_) * range_padding / 2
        boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))

    for i, a in enumerate(df.columns):
        for j, b in enumerate(df.columns):
            ax = axes[i, j]

            if i == j:
                values = df[a].values[mask[a].values]

                # Diagonal: draw the column's own distribution (histogram or
                # kernel density estimate) instead of a self-scatter.
                if diagonal == "hist":
                    ax.hist(values, **hist_kwds)

                elif diagonal in ("kde", "density"):
                    from scipy.stats import gaussian_kde

                    y = values
                    gkde = gaussian_kde(y)
                    ind = np.linspace(y.min(), y.max(), 1000)
                    ax.plot(ind, gkde.evaluate(ind), **density_kwds)

                ax.set_xlim(boundaries_list[i])

            else:
                # Only rows where both columns are non-null.
                common = (mask[a] & mask[b]).values

                ax.scatter(
                    df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds
                )

                ax.set_xlim(boundaries_list[j])
                ax.set_ylim(boundaries_list[i])

            ax.set_xlabel(b)
            ax.set_ylabel(a)

            # Only the left column keeps its y axis and only the bottom row
            # keeps its x axis visible.
            if j != 0:
                ax.yaxis.set_visible(False)
            if i != n - 1:
                ax.xaxis.set_visible(False)

    if len(df.columns) > 1:
        # The [0, 0] cell's y axis is in histogram/density units, but it is the
        # visible y axis for row 0. Take the data-valued y ticks of axes[0][1]
        # (whose y limits are boundaries_list[0]), map them linearly into
        # [0, 0]'s y limits, and label them with the data values.
        # NOTE(review): a zero-width lim1 makes the division below divide by 0.
        lim1 = boundaries_list[0]
        locs = axes[0][1].yaxis.get_majorticklocs()
        locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
        adj = (locs - lim1[0]) / (lim1[1] - lim1[0])

        lim0 = axes[0][0].get_ylim()
        adj = adj * (lim0[1] - lim0[0]) + lim0[0]
        axes[0][0].yaxis.set_ticks(adj)

        if np.all(locs == locs.astype(int)):
            # if all ticks are int
            locs = locs.astype(int)
        axes[0][0].yaxis.set_ticklabels(locs)

    set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)

    return axes
  100. def _get_marker_compat(marker):
  101. if marker not in mlines.lineMarkers:
  102. return "o"
  103. return marker
  104. def radviz(
  105. frame: DataFrame,
  106. class_column,
  107. ax: Axes | None = None,
  108. color=None,
  109. colormap=None,
  110. **kwds,
  111. ) -> Axes:
  112. import matplotlib.pyplot as plt
  113. def normalize(series):
  114. a = min(series)
  115. b = max(series)
  116. return (series - a) / (b - a)
  117. n = len(frame)
  118. classes = frame[class_column].drop_duplicates()
  119. class_col = frame[class_column]
  120. df = frame.drop(class_column, axis=1).apply(normalize)
  121. if ax is None:
  122. ax = plt.gca()
  123. ax.set_xlim(-1, 1)
  124. ax.set_ylim(-1, 1)
  125. to_plot: dict[Hashable, list[list]] = {}
  126. colors = get_standard_colors(
  127. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  128. )
  129. for kls in classes:
  130. to_plot[kls] = [[], []]
  131. m = len(frame.columns) - 1
  132. s = np.array(
  133. [(np.cos(t), np.sin(t)) for t in [2 * np.pi * (i / m) for i in range(m)]]
  134. )
  135. for i in range(n):
  136. row = df.iloc[i].values
  137. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  138. y = (s * row_).sum(axis=0) / row.sum()
  139. kls = class_col.iat[i]
  140. to_plot[kls][0].append(y[0])
  141. to_plot[kls][1].append(y[1])
  142. for i, kls in enumerate(classes):
  143. ax.scatter(
  144. to_plot[kls][0],
  145. to_plot[kls][1],
  146. color=colors[i],
  147. label=pprint_thing(kls),
  148. **kwds,
  149. )
  150. ax.legend()
  151. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
  152. for xy, name in zip(s, df.columns):
  153. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
  154. if xy[0] < 0.0 and xy[1] < 0.0:
  155. ax.text(
  156. xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small"
  157. )
  158. elif xy[0] < 0.0 <= xy[1]:
  159. ax.text(
  160. xy[0] - 0.025,
  161. xy[1] + 0.025,
  162. name,
  163. ha="right",
  164. va="bottom",
  165. size="small",
  166. )
  167. elif xy[1] < 0.0 <= xy[0]:
  168. ax.text(
  169. xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small"
  170. )
  171. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  172. ax.text(
  173. xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small"
  174. )
  175. ax.axis("equal")
  176. return ax
  177. def andrews_curves(
  178. frame: DataFrame,
  179. class_column,
  180. ax: Axes | None = None,
  181. samples: int = 200,
  182. color=None,
  183. colormap=None,
  184. **kwds,
  185. ) -> Axes:
  186. import matplotlib.pyplot as plt
  187. def function(amplitudes):
  188. def f(t):
  189. x1 = amplitudes[0]
  190. result = x1 / np.sqrt(2.0)
  191. # Take the rest of the coefficients and resize them
  192. # appropriately. Take a copy of amplitudes as otherwise numpy
  193. # deletes the element from amplitudes itself.
  194. coeffs = np.delete(np.copy(amplitudes), 0)
  195. coeffs = np.resize(coeffs, (int((coeffs.size + 1) / 2), 2))
  196. # Generate the harmonics and arguments for the sin and cos
  197. # functions.
  198. harmonics = np.arange(0, coeffs.shape[0]) + 1
  199. trig_args = np.outer(harmonics, t)
  200. result += np.sum(
  201. coeffs[:, 0, np.newaxis] * np.sin(trig_args)
  202. + coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  203. axis=0,
  204. )
  205. return result
  206. return f
  207. n = len(frame)
  208. class_col = frame[class_column]
  209. classes = frame[class_column].drop_duplicates()
  210. df = frame.drop(class_column, axis=1)
  211. t = np.linspace(-np.pi, np.pi, samples)
  212. used_legends: set[str] = set()
  213. color_values = get_standard_colors(
  214. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  215. )
  216. colors = dict(zip(classes, color_values))
  217. if ax is None:
  218. ax = plt.gca()
  219. ax.set_xlim(-np.pi, np.pi)
  220. for i in range(n):
  221. row = df.iloc[i].values
  222. f = function(row)
  223. y = f(t)
  224. kls = class_col.iat[i]
  225. label = pprint_thing(kls)
  226. if label not in used_legends:
  227. used_legends.add(label)
  228. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  229. else:
  230. ax.plot(t, y, color=colors[kls], **kwds)
  231. ax.legend(loc="upper right")
  232. ax.grid()
  233. return ax
  234. def bootstrap_plot(
  235. series: Series,
  236. fig: Figure | None = None,
  237. size: int = 50,
  238. samples: int = 500,
  239. **kwds,
  240. ) -> Figure:
  241. import matplotlib.pyplot as plt
  242. # TODO: is the failure mentioned below still relevant?
  243. # random.sample(ndarray, int) fails on python 3.3, sigh
  244. data = list(series.values)
  245. samplings = [random.sample(data, size) for _ in range(samples)]
  246. means = np.array([np.mean(sampling) for sampling in samplings])
  247. medians = np.array([np.median(sampling) for sampling in samplings])
  248. midranges = np.array(
  249. [(min(sampling) + max(sampling)) * 0.5 for sampling in samplings]
  250. )
  251. if fig is None:
  252. fig = plt.figure()
  253. x = list(range(samples))
  254. axes = []
  255. ax1 = fig.add_subplot(2, 3, 1)
  256. ax1.set_xlabel("Sample")
  257. axes.append(ax1)
  258. ax1.plot(x, means, **kwds)
  259. ax2 = fig.add_subplot(2, 3, 2)
  260. ax2.set_xlabel("Sample")
  261. axes.append(ax2)
  262. ax2.plot(x, medians, **kwds)
  263. ax3 = fig.add_subplot(2, 3, 3)
  264. ax3.set_xlabel("Sample")
  265. axes.append(ax3)
  266. ax3.plot(x, midranges, **kwds)
  267. ax4 = fig.add_subplot(2, 3, 4)
  268. ax4.set_xlabel("Mean")
  269. axes.append(ax4)
  270. ax4.hist(means, **kwds)
  271. ax5 = fig.add_subplot(2, 3, 5)
  272. ax5.set_xlabel("Median")
  273. axes.append(ax5)
  274. ax5.hist(medians, **kwds)
  275. ax6 = fig.add_subplot(2, 3, 6)
  276. ax6.set_xlabel("Midrange")
  277. axes.append(ax6)
  278. ax6.hist(midranges, **kwds)
  279. for axis in axes:
  280. plt.setp(axis.get_xticklabels(), fontsize=8)
  281. plt.setp(axis.get_yticklabels(), fontsize=8)
  282. if do_adjust_figure(fig):
  283. plt.tight_layout()
  284. return fig
  285. def parallel_coordinates(
  286. frame: DataFrame,
  287. class_column,
  288. cols=None,
  289. ax: Axes | None = None,
  290. color=None,
  291. use_columns: bool = False,
  292. xticks=None,
  293. colormap=None,
  294. axvlines: bool = True,
  295. axvlines_kwds=None,
  296. sort_labels: bool = False,
  297. **kwds,
  298. ) -> Axes:
  299. import matplotlib.pyplot as plt
  300. if axvlines_kwds is None:
  301. axvlines_kwds = {"linewidth": 1, "color": "black"}
  302. n = len(frame)
  303. classes = frame[class_column].drop_duplicates()
  304. class_col = frame[class_column]
  305. if cols is None:
  306. df = frame.drop(class_column, axis=1)
  307. else:
  308. df = frame[cols]
  309. used_legends: set[str] = set()
  310. ncols = len(df.columns)
  311. # determine values to use for xticks
  312. x: list[int] | Index
  313. if use_columns is True:
  314. if not np.all(np.isreal(list(df.columns))):
  315. raise ValueError("Columns must be numeric to be used as xticks")
  316. x = df.columns
  317. elif xticks is not None:
  318. if not np.all(np.isreal(xticks)):
  319. raise ValueError("xticks specified must be numeric")
  320. if len(xticks) != ncols:
  321. raise ValueError("Length of xticks must match number of columns")
  322. x = xticks
  323. else:
  324. x = list(range(ncols))
  325. if ax is None:
  326. ax = plt.gca()
  327. color_values = get_standard_colors(
  328. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  329. )
  330. if sort_labels:
  331. classes = sorted(classes)
  332. color_values = sorted(color_values)
  333. colors = dict(zip(classes, color_values))
  334. for i in range(n):
  335. y = df.iloc[i].values
  336. kls = class_col.iat[i]
  337. label = pprint_thing(kls)
  338. if label not in used_legends:
  339. used_legends.add(label)
  340. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  341. else:
  342. ax.plot(x, y, color=colors[kls], **kwds)
  343. if axvlines:
  344. for i in x:
  345. ax.axvline(i, **axvlines_kwds)
  346. ax.set_xticks(x)
  347. ax.set_xticklabels(df.columns)
  348. ax.set_xlim(x[0], x[-1])
  349. ax.legend(loc="upper right")
  350. ax.grid()
  351. return ax
  352. def lag_plot(series: Series, lag: int = 1, ax: Axes | None = None, **kwds) -> Axes:
  353. # workaround because `c='b'` is hardcoded in matplotlib's scatter method
  354. import matplotlib.pyplot as plt
  355. kwds.setdefault("c", plt.rcParams["patch.facecolor"])
  356. data = series.values
  357. y1 = data[:-lag]
  358. y2 = data[lag:]
  359. if ax is None:
  360. ax = plt.gca()
  361. ax.set_xlabel("y(t)")
  362. ax.set_ylabel(f"y(t + {lag})")
  363. ax.scatter(y1, y2, **kwds)
  364. return ax
  365. def autocorrelation_plot(series: Series, ax: Axes | None = None, **kwds) -> Axes:
  366. import matplotlib.pyplot as plt
  367. n = len(series)
  368. data = np.asarray(series)
  369. if ax is None:
  370. ax = plt.gca()
  371. ax.set_xlim(1, n)
  372. ax.set_ylim(-1.0, 1.0)
  373. mean = np.mean(data)
  374. c0 = np.sum((data - mean) ** 2) / n
  375. def r(h):
  376. return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / n / c0
  377. x = np.arange(n) + 1
  378. y = [r(loc) for loc in x]
  379. z95 = 1.959963984540054
  380. z99 = 2.5758293035489004
  381. ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey")
  382. ax.axhline(y=z95 / np.sqrt(n), color="grey")
  383. ax.axhline(y=0.0, color="black")
  384. ax.axhline(y=-z95 / np.sqrt(n), color="grey")
  385. ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey")
  386. ax.set_xlabel("Lag")
  387. ax.set_ylabel("Autocorrelation")
  388. ax.plot(x, y, **kwds)
  389. if "label" in kwds:
  390. ax.legend()
  391. ax.grid()
  392. return ax
  393. def unpack_single_str_list(keys):
  394. # GH 42795
  395. if isinstance(keys, list) and len(keys) == 1:
  396. keys = keys[0]
  397. return keys