test_hist_box_by.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
  1. import re
  2. import numpy as np
  3. import pytest
  4. import pandas.util._test_decorators as td
  5. from pandas import DataFrame
  6. import pandas._testing as tm
  7. from pandas.tests.plotting.common import (
  8. TestPlotBase,
  9. _check_plot_works,
  10. )
  11. @pytest.fixture
  12. def hist_df():
  13. np.random.seed(0)
  14. df = DataFrame(np.random.randn(30, 2), columns=["A", "B"])
  15. df["C"] = np.random.choice(["a", "b", "c"], 30)
  16. df["D"] = np.random.choice(["a", "b", "c"], 30)
  17. return df
  18. @td.skip_if_no_mpl
  19. class TestHistWithBy(TestPlotBase):
  20. @pytest.mark.slow
  21. @pytest.mark.parametrize(
  22. "by, column, titles, legends",
  23. [
  24. ("C", "A", ["a", "b", "c"], [["A"]] * 3),
  25. ("C", ["A", "B"], ["a", "b", "c"], [["A", "B"]] * 3),
  26. ("C", None, ["a", "b", "c"], [["A", "B"]] * 3),
  27. (
  28. ["C", "D"],
  29. "A",
  30. [
  31. "(a, a)",
  32. "(a, b)",
  33. "(a, c)",
  34. "(b, a)",
  35. "(b, b)",
  36. "(b, c)",
  37. "(c, a)",
  38. "(c, b)",
  39. "(c, c)",
  40. ],
  41. [["A"]] * 9,
  42. ),
  43. (
  44. ["C", "D"],
  45. ["A", "B"],
  46. [
  47. "(a, a)",
  48. "(a, b)",
  49. "(a, c)",
  50. "(b, a)",
  51. "(b, b)",
  52. "(b, c)",
  53. "(c, a)",
  54. "(c, b)",
  55. "(c, c)",
  56. ],
  57. [["A", "B"]] * 9,
  58. ),
  59. (
  60. ["C", "D"],
  61. None,
  62. [
  63. "(a, a)",
  64. "(a, b)",
  65. "(a, c)",
  66. "(b, a)",
  67. "(b, b)",
  68. "(b, c)",
  69. "(c, a)",
  70. "(c, b)",
  71. "(c, c)",
  72. ],
  73. [["A", "B"]] * 9,
  74. ),
  75. ],
  76. )
  77. def test_hist_plot_by_argument(self, by, column, titles, legends, hist_df):
  78. # GH 15079
  79. axes = _check_plot_works(
  80. hist_df.plot.hist, column=column, by=by, default_axes=True
  81. )
  82. result_titles = [ax.get_title() for ax in axes]
  83. result_legends = [
  84. [legend.get_text() for legend in ax.get_legend().texts] for ax in axes
  85. ]
  86. assert result_legends == legends
  87. assert result_titles == titles
  88. @pytest.mark.parametrize(
  89. "by, column, titles, legends",
  90. [
  91. (0, "A", ["a", "b", "c"], [["A"]] * 3),
  92. (0, None, ["a", "b", "c"], [["A", "B"]] * 3),
  93. (
  94. [0, "D"],
  95. "A",
  96. [
  97. "(a, a)",
  98. "(a, b)",
  99. "(a, c)",
  100. "(b, a)",
  101. "(b, b)",
  102. "(b, c)",
  103. "(c, a)",
  104. "(c, b)",
  105. "(c, c)",
  106. ],
  107. [["A"]] * 9,
  108. ),
  109. ],
  110. )
  111. def test_hist_plot_by_0(self, by, column, titles, legends, hist_df):
  112. # GH 15079
  113. df = hist_df.copy()
  114. df = df.rename(columns={"C": 0})
  115. axes = _check_plot_works(df.plot.hist, default_axes=True, column=column, by=by)
  116. result_titles = [ax.get_title() for ax in axes]
  117. result_legends = [
  118. [legend.get_text() for legend in ax.get_legend().texts] for ax in axes
  119. ]
  120. assert result_legends == legends
  121. assert result_titles == titles
  122. @pytest.mark.parametrize(
  123. "by, column",
  124. [
  125. ([], ["A"]),
  126. ([], ["A", "B"]),
  127. ((), None),
  128. ((), ["A", "B"]),
  129. ],
  130. )
  131. def test_hist_plot_empty_list_string_tuple_by(self, by, column, hist_df):
  132. # GH 15079
  133. msg = "No group keys passed"
  134. with pytest.raises(ValueError, match=msg):
  135. _check_plot_works(
  136. hist_df.plot.hist, default_axes=True, column=column, by=by
  137. )
  138. @pytest.mark.slow
  139. @pytest.mark.parametrize(
  140. "by, column, layout, axes_num",
  141. [
  142. (["C"], "A", (2, 2), 3),
  143. ("C", "A", (2, 2), 3),
  144. (["C"], ["A"], (1, 3), 3),
  145. ("C", None, (3, 1), 3),
  146. ("C", ["A", "B"], (3, 1), 3),
  147. (["C", "D"], "A", (9, 1), 9),
  148. (["C", "D"], "A", (3, 3), 9),
  149. (["C", "D"], ["A"], (5, 2), 9),
  150. (["C", "D"], ["A", "B"], (9, 1), 9),
  151. (["C", "D"], None, (9, 1), 9),
  152. (["C", "D"], ["A", "B"], (5, 2), 9),
  153. ],
  154. )
  155. def test_hist_plot_layout_with_by(self, by, column, layout, axes_num, hist_df):
  156. # GH 15079
  157. # _check_plot_works adds an ax so catch warning. see GH #13188
  158. with tm.assert_produces_warning(UserWarning, check_stacklevel=False):
  159. axes = _check_plot_works(
  160. hist_df.plot.hist, column=column, by=by, layout=layout
  161. )
  162. self._check_axes_shape(axes, axes_num=axes_num, layout=layout)
  163. @pytest.mark.parametrize(
  164. "msg, by, layout",
  165. [
  166. ("larger than required size", ["C", "D"], (1, 1)),
  167. (re.escape("Layout must be a tuple of (rows, columns)"), "C", (1,)),
  168. ("At least one dimension of layout must be positive", "C", (-1, -1)),
  169. ],
  170. )
  171. def test_hist_plot_invalid_layout_with_by_raises(self, msg, by, layout, hist_df):
  172. # GH 15079, test if error is raised when invalid layout is given
  173. with pytest.raises(ValueError, match=msg):
  174. hist_df.plot.hist(column=["A", "B"], by=by, layout=layout)
  175. @pytest.mark.slow
  176. def test_axis_share_x_with_by(self, hist_df):
  177. # GH 15079
  178. ax1, ax2, ax3 = hist_df.plot.hist(column="A", by="C", sharex=True)
  179. # share x
  180. assert self.get_x_axis(ax1).joined(ax1, ax2)
  181. assert self.get_x_axis(ax2).joined(ax1, ax2)
  182. assert self.get_x_axis(ax3).joined(ax1, ax3)
  183. assert self.get_x_axis(ax3).joined(ax2, ax3)
  184. # don't share y
  185. assert not self.get_y_axis(ax1).joined(ax1, ax2)
  186. assert not self.get_y_axis(ax2).joined(ax1, ax2)
  187. assert not self.get_y_axis(ax3).joined(ax1, ax3)
  188. assert not self.get_y_axis(ax3).joined(ax2, ax3)
  189. @pytest.mark.slow
  190. def test_axis_share_y_with_by(self, hist_df):
  191. # GH 15079
  192. ax1, ax2, ax3 = hist_df.plot.hist(column="A", by="C", sharey=True)
  193. # share y
  194. assert self.get_y_axis(ax1).joined(ax1, ax2)
  195. assert self.get_y_axis(ax2).joined(ax1, ax2)
  196. assert self.get_y_axis(ax3).joined(ax1, ax3)
  197. assert self.get_y_axis(ax3).joined(ax2, ax3)
  198. # don't share x
  199. assert not self.get_x_axis(ax1).joined(ax1, ax2)
  200. assert not self.get_x_axis(ax2).joined(ax1, ax2)
  201. assert not self.get_x_axis(ax3).joined(ax1, ax3)
  202. assert not self.get_x_axis(ax3).joined(ax2, ax3)
  203. @pytest.mark.parametrize("figsize", [(12, 8), (20, 10)])
  204. def test_figure_shape_hist_with_by(self, figsize, hist_df):
  205. # GH 15079
  206. axes = hist_df.plot.hist(column="A", by="C", figsize=figsize)
  207. self._check_axes_shape(axes, axes_num=3, figsize=figsize)
  208. @td.skip_if_no_mpl
  209. class TestBoxWithBy(TestPlotBase):
  210. @pytest.mark.parametrize(
  211. "by, column, titles, xticklabels",
  212. [
  213. ("C", "A", ["A"], [["a", "b", "c"]]),
  214. (
  215. ["C", "D"],
  216. "A",
  217. ["A"],
  218. [
  219. [
  220. "(a, a)",
  221. "(a, b)",
  222. "(a, c)",
  223. "(b, a)",
  224. "(b, b)",
  225. "(b, c)",
  226. "(c, a)",
  227. "(c, b)",
  228. "(c, c)",
  229. ]
  230. ],
  231. ),
  232. ("C", ["A", "B"], ["A", "B"], [["a", "b", "c"]] * 2),
  233. (
  234. ["C", "D"],
  235. ["A", "B"],
  236. ["A", "B"],
  237. [
  238. [
  239. "(a, a)",
  240. "(a, b)",
  241. "(a, c)",
  242. "(b, a)",
  243. "(b, b)",
  244. "(b, c)",
  245. "(c, a)",
  246. "(c, b)",
  247. "(c, c)",
  248. ]
  249. ]
  250. * 2,
  251. ),
  252. (["C"], None, ["A", "B"], [["a", "b", "c"]] * 2),
  253. ],
  254. )
  255. def test_box_plot_by_argument(self, by, column, titles, xticklabels, hist_df):
  256. # GH 15079
  257. axes = _check_plot_works(
  258. hist_df.plot.box, default_axes=True, column=column, by=by
  259. )
  260. result_titles = [ax.get_title() for ax in axes]
  261. result_xticklabels = [
  262. [label.get_text() for label in ax.get_xticklabels()] for ax in axes
  263. ]
  264. assert result_xticklabels == xticklabels
  265. assert result_titles == titles
  266. @pytest.mark.parametrize(
  267. "by, column, titles, xticklabels",
  268. [
  269. (0, "A", ["A"], [["a", "b", "c"]]),
  270. (
  271. [0, "D"],
  272. "A",
  273. ["A"],
  274. [
  275. [
  276. "(a, a)",
  277. "(a, b)",
  278. "(a, c)",
  279. "(b, a)",
  280. "(b, b)",
  281. "(b, c)",
  282. "(c, a)",
  283. "(c, b)",
  284. "(c, c)",
  285. ]
  286. ],
  287. ),
  288. (0, None, ["A", "B"], [["a", "b", "c"]] * 2),
  289. ],
  290. )
  291. def test_box_plot_by_0(self, by, column, titles, xticklabels, hist_df):
  292. # GH 15079
  293. df = hist_df.copy()
  294. df = df.rename(columns={"C": 0})
  295. axes = _check_plot_works(df.plot.box, default_axes=True, column=column, by=by)
  296. result_titles = [ax.get_title() for ax in axes]
  297. result_xticklabels = [
  298. [label.get_text() for label in ax.get_xticklabels()] for ax in axes
  299. ]
  300. assert result_xticklabels == xticklabels
  301. assert result_titles == titles
  302. @pytest.mark.parametrize(
  303. "by, column",
  304. [
  305. ([], ["A"]),
  306. ((), "A"),
  307. ([], None),
  308. ((), ["A", "B"]),
  309. ],
  310. )
  311. def test_box_plot_with_none_empty_list_by(self, by, column, hist_df):
  312. # GH 15079
  313. msg = "No group keys passed"
  314. with pytest.raises(ValueError, match=msg):
  315. _check_plot_works(hist_df.plot.box, default_axes=True, column=column, by=by)
  316. @pytest.mark.slow
  317. @pytest.mark.parametrize(
  318. "by, column, layout, axes_num",
  319. [
  320. (["C"], "A", (1, 1), 1),
  321. ("C", "A", (1, 1), 1),
  322. ("C", None, (2, 1), 2),
  323. ("C", ["A", "B"], (1, 2), 2),
  324. (["C", "D"], "A", (1, 1), 1),
  325. (["C", "D"], None, (1, 2), 2),
  326. ],
  327. )
  328. def test_box_plot_layout_with_by(self, by, column, layout, axes_num, hist_df):
  329. # GH 15079
  330. axes = _check_plot_works(
  331. hist_df.plot.box, default_axes=True, column=column, by=by, layout=layout
  332. )
  333. self._check_axes_shape(axes, axes_num=axes_num, layout=layout)
  334. @pytest.mark.parametrize(
  335. "msg, by, layout",
  336. [
  337. ("larger than required size", ["C", "D"], (1, 1)),
  338. (re.escape("Layout must be a tuple of (rows, columns)"), "C", (1,)),
  339. ("At least one dimension of layout must be positive", "C", (-1, -1)),
  340. ],
  341. )
  342. def test_box_plot_invalid_layout_with_by_raises(self, msg, by, layout, hist_df):
  343. # GH 15079, test if error is raised when invalid layout is given
  344. with pytest.raises(ValueError, match=msg):
  345. hist_df.plot.box(column=["A", "B"], by=by, layout=layout)
  346. @pytest.mark.parametrize("figsize", [(12, 8), (20, 10)])
  347. def test_figure_shape_hist_with_by(self, figsize, hist_df):
  348. # GH 15079
  349. axes = hist_df.plot.box(column="A", by="C", figsize=figsize)
  350. self._check_axes_shape(axes, axes_num=1, figsize=figsize)