test_frame_subplots.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684
  1. """ Test cases for DataFrame.plot """
  2. import string
  3. import warnings
  4. import numpy as np
  5. import pytest
  6. from pandas.compat import is_platform_linux
  7. from pandas.compat.numpy import np_version_gte1p24
  8. import pandas.util._test_decorators as td
  9. import pandas as pd
  10. from pandas import (
  11. DataFrame,
  12. Series,
  13. date_range,
  14. )
  15. import pandas._testing as tm
  16. from pandas.tests.plotting.common import TestPlotBase
  17. from pandas.io.formats.printing import pprint_thing
@td.skip_if_no_mpl
class TestDataFramePlotsSubplots(TestPlotBase):
    """Subplot, layout and bar-alignment tests for ``DataFrame.plot``."""
  20. @pytest.mark.slow
  21. def test_subplots(self):
  22. df = DataFrame(np.random.rand(10, 3), index=list(string.ascii_letters[:10]))
  23. for kind in ["bar", "barh", "line", "area"]:
  24. axes = df.plot(kind=kind, subplots=True, sharex=True, legend=True)
  25. self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
  26. assert axes.shape == (3,)
  27. for ax, column in zip(axes, df.columns):
  28. self._check_legend_labels(ax, labels=[pprint_thing(column)])
  29. for ax in axes[:-2]:
  30. self._check_visible(ax.xaxis) # xaxis must be visible for grid
  31. self._check_visible(ax.get_xticklabels(), visible=False)
  32. if kind != "bar":
  33. # change https://github.com/pandas-dev/pandas/issues/26714
  34. self._check_visible(ax.get_xticklabels(minor=True), visible=False)
  35. self._check_visible(ax.xaxis.get_label(), visible=False)
  36. self._check_visible(ax.get_yticklabels())
  37. self._check_visible(axes[-1].xaxis)
  38. self._check_visible(axes[-1].get_xticklabels())
  39. self._check_visible(axes[-1].get_xticklabels(minor=True))
  40. self._check_visible(axes[-1].xaxis.get_label())
  41. self._check_visible(axes[-1].get_yticklabels())
  42. axes = df.plot(kind=kind, subplots=True, sharex=False)
  43. for ax in axes:
  44. self._check_visible(ax.xaxis)
  45. self._check_visible(ax.get_xticklabels())
  46. self._check_visible(ax.get_xticklabels(minor=True))
  47. self._check_visible(ax.xaxis.get_label())
  48. self._check_visible(ax.get_yticklabels())
  49. axes = df.plot(kind=kind, subplots=True, legend=False)
  50. for ax in axes:
  51. assert ax.get_legend() is None
  52. def test_subplots_timeseries(self):
  53. idx = date_range(start="2014-07-01", freq="M", periods=10)
  54. df = DataFrame(np.random.rand(10, 3), index=idx)
  55. for kind in ["line", "area"]:
  56. axes = df.plot(kind=kind, subplots=True, sharex=True)
  57. self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
  58. for ax in axes[:-2]:
  59. # GH 7801
  60. self._check_visible(ax.xaxis) # xaxis must be visible for grid
  61. self._check_visible(ax.get_xticklabels(), visible=False)
  62. self._check_visible(ax.get_xticklabels(minor=True), visible=False)
  63. self._check_visible(ax.xaxis.get_label(), visible=False)
  64. self._check_visible(ax.get_yticklabels())
  65. self._check_visible(axes[-1].xaxis)
  66. self._check_visible(axes[-1].get_xticklabels())
  67. self._check_visible(axes[-1].get_xticklabels(minor=True))
  68. self._check_visible(axes[-1].xaxis.get_label())
  69. self._check_visible(axes[-1].get_yticklabels())
  70. self._check_ticks_props(axes, xrot=0)
  71. axes = df.plot(kind=kind, subplots=True, sharex=False, rot=45, fontsize=7)
  72. for ax in axes:
  73. self._check_visible(ax.xaxis)
  74. self._check_visible(ax.get_xticklabels())
  75. self._check_visible(ax.get_xticklabels(minor=True))
  76. self._check_visible(ax.xaxis.get_label())
  77. self._check_visible(ax.get_yticklabels())
  78. self._check_ticks_props(ax, xlabelsize=7, xrot=45, ylabelsize=7)
  79. def test_subplots_timeseries_y_axis(self):
  80. # GH16953
  81. data = {
  82. "numeric": np.array([1, 2, 5]),
  83. "timedelta": [
  84. pd.Timedelta(-10, unit="s"),
  85. pd.Timedelta(10, unit="m"),
  86. pd.Timedelta(10, unit="h"),
  87. ],
  88. "datetime_no_tz": [
  89. pd.to_datetime("2017-08-01 00:00:00"),
  90. pd.to_datetime("2017-08-01 02:00:00"),
  91. pd.to_datetime("2017-08-02 00:00:00"),
  92. ],
  93. "datetime_all_tz": [
  94. pd.to_datetime("2017-08-01 00:00:00", utc=True),
  95. pd.to_datetime("2017-08-01 02:00:00", utc=True),
  96. pd.to_datetime("2017-08-02 00:00:00", utc=True),
  97. ],
  98. "text": ["This", "should", "fail"],
  99. }
  100. testdata = DataFrame(data)
  101. y_cols = ["numeric", "timedelta", "datetime_no_tz", "datetime_all_tz"]
  102. for col in y_cols:
  103. ax = testdata.plot(y=col)
  104. result = ax.get_lines()[0].get_data()[1]
  105. expected = testdata[col].values
  106. assert (result == expected).all()
  107. msg = "no numeric data to plot"
  108. with pytest.raises(TypeError, match=msg):
  109. testdata.plot(y="text")
  110. @pytest.mark.xfail(reason="not support for period, categorical, datetime_mixed_tz")
  111. def test_subplots_timeseries_y_axis_not_supported(self):
  112. """
  113. This test will fail for:
  114. period:
  115. since period isn't yet implemented in ``select_dtypes``
  116. and because it will need a custom value converter +
  117. tick formatter (as was done for x-axis plots)
  118. categorical:
  119. because it will need a custom value converter +
  120. tick formatter (also doesn't work for x-axis, as of now)
  121. datetime_mixed_tz:
  122. because of the way how pandas handles ``Series`` of
  123. ``datetime`` objects with different timezone,
  124. generally converting ``datetime`` objects in a tz-aware
  125. form could help with this problem
  126. """
  127. data = {
  128. "numeric": np.array([1, 2, 5]),
  129. "period": [
  130. pd.Period("2017-08-01 00:00:00", freq="H"),
  131. pd.Period("2017-08-01 02:00", freq="H"),
  132. pd.Period("2017-08-02 00:00:00", freq="H"),
  133. ],
  134. "categorical": pd.Categorical(
  135. ["c", "b", "a"], categories=["a", "b", "c"], ordered=False
  136. ),
  137. "datetime_mixed_tz": [
  138. pd.to_datetime("2017-08-01 00:00:00", utc=True),
  139. pd.to_datetime("2017-08-01 02:00:00"),
  140. pd.to_datetime("2017-08-02 00:00:00"),
  141. ],
  142. }
  143. testdata = DataFrame(data)
  144. ax_period = testdata.plot(x="numeric", y="period")
  145. assert (
  146. ax_period.get_lines()[0].get_data()[1] == testdata["period"].values
  147. ).all()
  148. ax_categorical = testdata.plot(x="numeric", y="categorical")
  149. assert (
  150. ax_categorical.get_lines()[0].get_data()[1]
  151. == testdata["categorical"].values
  152. ).all()
  153. ax_datetime_mixed_tz = testdata.plot(x="numeric", y="datetime_mixed_tz")
  154. assert (
  155. ax_datetime_mixed_tz.get_lines()[0].get_data()[1]
  156. == testdata["datetime_mixed_tz"].values
  157. ).all()
  158. def test_subplots_layout_multi_column(self):
  159. # GH 6667
  160. df = DataFrame(np.random.rand(10, 3), index=list(string.ascii_letters[:10]))
  161. axes = df.plot(subplots=True, layout=(2, 2))
  162. self._check_axes_shape(axes, axes_num=3, layout=(2, 2))
  163. assert axes.shape == (2, 2)
  164. axes = df.plot(subplots=True, layout=(-1, 2))
  165. self._check_axes_shape(axes, axes_num=3, layout=(2, 2))
  166. assert axes.shape == (2, 2)
  167. axes = df.plot(subplots=True, layout=(2, -1))
  168. self._check_axes_shape(axes, axes_num=3, layout=(2, 2))
  169. assert axes.shape == (2, 2)
  170. axes = df.plot(subplots=True, layout=(1, 4))
  171. self._check_axes_shape(axes, axes_num=3, layout=(1, 4))
  172. assert axes.shape == (1, 4)
  173. axes = df.plot(subplots=True, layout=(-1, 4))
  174. self._check_axes_shape(axes, axes_num=3, layout=(1, 4))
  175. assert axes.shape == (1, 4)
  176. axes = df.plot(subplots=True, layout=(4, -1))
  177. self._check_axes_shape(axes, axes_num=3, layout=(4, 1))
  178. assert axes.shape == (4, 1)
  179. msg = "Layout of 1x1 must be larger than required size 3"
  180. with pytest.raises(ValueError, match=msg):
  181. df.plot(subplots=True, layout=(1, 1))
  182. msg = "At least one dimension of layout must be positive"
  183. with pytest.raises(ValueError, match=msg):
  184. df.plot(subplots=True, layout=(-1, -1))
  185. @pytest.mark.parametrize(
  186. "kwargs, expected_axes_num, expected_layout, expected_shape",
  187. [
  188. ({}, 1, (1, 1), (1,)),
  189. ({"layout": (3, 3)}, 1, (3, 3), (3, 3)),
  190. ],
  191. )
  192. def test_subplots_layout_single_column(
  193. self, kwargs, expected_axes_num, expected_layout, expected_shape
  194. ):
  195. # GH 6667
  196. df = DataFrame(np.random.rand(10, 1), index=list(string.ascii_letters[:10]))
  197. axes = df.plot(subplots=True, **kwargs)
  198. self._check_axes_shape(
  199. axes,
  200. axes_num=expected_axes_num,
  201. layout=expected_layout,
  202. )
  203. assert axes.shape == expected_shape
  204. @pytest.mark.slow
  205. def test_subplots_warnings(self):
  206. # GH 9464
  207. with tm.assert_produces_warning(None):
  208. df = DataFrame(np.random.randn(100, 4))
  209. df.plot(subplots=True, layout=(3, 2))
  210. df = DataFrame(
  211. np.random.randn(100, 4), index=date_range("1/1/2000", periods=100)
  212. )
  213. df.plot(subplots=True, layout=(3, 2))
  214. def test_subplots_multiple_axes(self):
  215. # GH 5353, 6970, GH 7069
  216. fig, axes = self.plt.subplots(2, 3)
  217. df = DataFrame(np.random.rand(10, 3), index=list(string.ascii_letters[:10]))
  218. returned = df.plot(subplots=True, ax=axes[0], sharex=False, sharey=False)
  219. self._check_axes_shape(returned, axes_num=3, layout=(1, 3))
  220. assert returned.shape == (3,)
  221. assert returned[0].figure is fig
  222. # draw on second row
  223. returned = df.plot(subplots=True, ax=axes[1], sharex=False, sharey=False)
  224. self._check_axes_shape(returned, axes_num=3, layout=(1, 3))
  225. assert returned.shape == (3,)
  226. assert returned[0].figure is fig
  227. self._check_axes_shape(axes, axes_num=6, layout=(2, 3))
  228. tm.close()
  229. msg = "The number of passed axes must be 3, the same as the output plot"
  230. with pytest.raises(ValueError, match=msg):
  231. fig, axes = self.plt.subplots(2, 3)
  232. # pass different number of axes from required
  233. df.plot(subplots=True, ax=axes)
  234. # pass 2-dim axes and invalid layout
  235. # invalid lauout should not affect to input and return value
  236. # (show warning is tested in
  237. # TestDataFrameGroupByPlots.test_grouped_box_multiple_axes
  238. fig, axes = self.plt.subplots(2, 2)
  239. with warnings.catch_warnings():
  240. warnings.simplefilter("ignore", UserWarning)
  241. df = DataFrame(np.random.rand(10, 4), index=list(string.ascii_letters[:10]))
  242. returned = df.plot(
  243. subplots=True, ax=axes, layout=(2, 1), sharex=False, sharey=False
  244. )
  245. self._check_axes_shape(returned, axes_num=4, layout=(2, 2))
  246. assert returned.shape == (4,)
  247. returned = df.plot(
  248. subplots=True, ax=axes, layout=(2, -1), sharex=False, sharey=False
  249. )
  250. self._check_axes_shape(returned, axes_num=4, layout=(2, 2))
  251. assert returned.shape == (4,)
  252. returned = df.plot(
  253. subplots=True, ax=axes, layout=(-1, 2), sharex=False, sharey=False
  254. )
  255. self._check_axes_shape(returned, axes_num=4, layout=(2, 2))
  256. assert returned.shape == (4,)
  257. # single column
  258. fig, axes = self.plt.subplots(1, 1)
  259. df = DataFrame(np.random.rand(10, 1), index=list(string.ascii_letters[:10]))
  260. axes = df.plot(subplots=True, ax=[axes], sharex=False, sharey=False)
  261. self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
  262. assert axes.shape == (1,)
  263. def test_subplots_ts_share_axes(self):
  264. # GH 3964
  265. fig, axes = self.plt.subplots(3, 3, sharex=True, sharey=True)
  266. self.plt.subplots_adjust(left=0.05, right=0.95, hspace=0.3, wspace=0.3)
  267. df = DataFrame(
  268. np.random.randn(10, 9),
  269. index=date_range(start="2014-07-01", freq="M", periods=10),
  270. )
  271. for i, ax in enumerate(axes.ravel()):
  272. df[i].plot(ax=ax, fontsize=5)
  273. # Rows other than bottom should not be visible
  274. for ax in axes[0:-1].ravel():
  275. self._check_visible(ax.get_xticklabels(), visible=False)
  276. # Bottom row should be visible
  277. for ax in axes[-1].ravel():
  278. self._check_visible(ax.get_xticklabels(), visible=True)
  279. # First column should be visible
  280. for ax in axes[[0, 1, 2], [0]].ravel():
  281. self._check_visible(ax.get_yticklabels(), visible=True)
  282. # Other columns should not be visible
  283. for ax in axes[[0, 1, 2], [1]].ravel():
  284. self._check_visible(ax.get_yticklabels(), visible=False)
  285. for ax in axes[[0, 1, 2], [2]].ravel():
  286. self._check_visible(ax.get_yticklabels(), visible=False)
  287. def test_subplots_sharex_axes_existing_axes(self):
  288. # GH 9158
  289. d = {"A": [1.0, 2.0, 3.0, 4.0], "B": [4.0, 3.0, 2.0, 1.0], "C": [5, 1, 3, 4]}
  290. df = DataFrame(d, index=date_range("2014 10 11", "2014 10 14"))
  291. axes = df[["A", "B"]].plot(subplots=True)
  292. df["C"].plot(ax=axes[0], secondary_y=True)
  293. self._check_visible(axes[0].get_xticklabels(), visible=False)
  294. self._check_visible(axes[1].get_xticklabels(), visible=True)
  295. for ax in axes.ravel():
  296. self._check_visible(ax.get_yticklabels(), visible=True)
  297. def test_subplots_dup_columns(self):
  298. # GH 10962
  299. df = DataFrame(np.random.rand(5, 5), columns=list("aaaaa"))
  300. axes = df.plot(subplots=True)
  301. for ax in axes:
  302. self._check_legend_labels(ax, labels=["a"])
  303. assert len(ax.lines) == 1
  304. tm.close()
  305. axes = df.plot(subplots=True, secondary_y="a")
  306. for ax in axes:
  307. # (right) is only attached when subplots=False
  308. self._check_legend_labels(ax, labels=["a"])
  309. assert len(ax.lines) == 1
  310. tm.close()
  311. ax = df.plot(secondary_y="a")
  312. self._check_legend_labels(ax, labels=["a (right)"] * 5)
  313. assert len(ax.lines) == 0
  314. assert len(ax.right_ax.lines) == 5
  315. @pytest.mark.xfail(
  316. np_version_gte1p24 and is_platform_linux(),
  317. reason="Weird rounding problems",
  318. strict=False,
  319. )
  320. def test_bar_log_no_subplots(self):
  321. # GH3254, GH3298 matplotlib/matplotlib#1882, #1892
  322. # regressions in 1.2.1
  323. expected = np.array([0.1, 1.0, 10.0, 100])
  324. # no subplots
  325. df = DataFrame({"A": [3] * 5, "B": list(range(1, 6))}, index=range(5))
  326. ax = df.plot.bar(grid=True, log=True)
  327. tm.assert_numpy_array_equal(ax.yaxis.get_ticklocs(), expected)
  328. @pytest.mark.xfail(
  329. np_version_gte1p24 and is_platform_linux(),
  330. reason="Weird rounding problems",
  331. strict=False,
  332. )
  333. def test_bar_log_subplots(self):
  334. expected = np.array([0.1, 1.0, 10.0, 100.0, 1000.0, 1e4])
  335. ax = DataFrame([Series([200, 300]), Series([300, 500])]).plot.bar(
  336. log=True, subplots=True
  337. )
  338. tm.assert_numpy_array_equal(ax[0].yaxis.get_ticklocs(), expected)
  339. tm.assert_numpy_array_equal(ax[1].yaxis.get_ticklocs(), expected)
  340. def test_boxplot_subplots_return_type(self, hist_df):
  341. df = hist_df
  342. # normal style: return_type=None
  343. result = df.plot.box(subplots=True)
  344. assert isinstance(result, Series)
  345. self._check_box_return_type(
  346. result, None, expected_keys=["height", "weight", "category"]
  347. )
  348. for t in ["dict", "axes", "both"]:
  349. returned = df.plot.box(return_type=t, subplots=True)
  350. self._check_box_return_type(
  351. returned,
  352. t,
  353. expected_keys=["height", "weight", "category"],
  354. check_ax_title=False,
  355. )
  356. def test_df_subplots_patterns_minorticks(self):
  357. # GH 10657
  358. import matplotlib.pyplot as plt
  359. df = DataFrame(
  360. np.random.randn(10, 2),
  361. index=date_range("1/1/2000", periods=10),
  362. columns=list("AB"),
  363. )
  364. # shared subplots
  365. fig, axes = plt.subplots(2, 1, sharex=True)
  366. axes = df.plot(subplots=True, ax=axes)
  367. for ax in axes:
  368. assert len(ax.lines) == 1
  369. self._check_visible(ax.get_yticklabels(), visible=True)
  370. # xaxis of 1st ax must be hidden
  371. self._check_visible(axes[0].get_xticklabels(), visible=False)
  372. self._check_visible(axes[0].get_xticklabels(minor=True), visible=False)
  373. self._check_visible(axes[1].get_xticklabels(), visible=True)
  374. self._check_visible(axes[1].get_xticklabels(minor=True), visible=True)
  375. tm.close()
  376. fig, axes = plt.subplots(2, 1)
  377. with tm.assert_produces_warning(UserWarning):
  378. axes = df.plot(subplots=True, ax=axes, sharex=True)
  379. for ax in axes:
  380. assert len(ax.lines) == 1
  381. self._check_visible(ax.get_yticklabels(), visible=True)
  382. # xaxis of 1st ax must be hidden
  383. self._check_visible(axes[0].get_xticklabels(), visible=False)
  384. self._check_visible(axes[0].get_xticklabels(minor=True), visible=False)
  385. self._check_visible(axes[1].get_xticklabels(), visible=True)
  386. self._check_visible(axes[1].get_xticklabels(minor=True), visible=True)
  387. tm.close()
  388. # not shared
  389. fig, axes = plt.subplots(2, 1)
  390. axes = df.plot(subplots=True, ax=axes)
  391. for ax in axes:
  392. assert len(ax.lines) == 1
  393. self._check_visible(ax.get_yticklabels(), visible=True)
  394. self._check_visible(ax.get_xticklabels(), visible=True)
  395. self._check_visible(ax.get_xticklabels(minor=True), visible=True)
  396. tm.close()
  397. def test_subplots_sharex_false(self):
  398. # test when sharex is set to False, two plots should have different
  399. # labels, GH 25160
  400. df = DataFrame(np.random.rand(10, 2))
  401. df.iloc[5:, 1] = np.nan
  402. df.iloc[:5, 0] = np.nan
  403. figs, axs = self.plt.subplots(2, 1)
  404. df.plot.line(ax=axs, subplots=True, sharex=False)
  405. expected_ax1 = np.arange(4.5, 10, 0.5)
  406. expected_ax2 = np.arange(-0.5, 5, 0.5)
  407. tm.assert_numpy_array_equal(axs[0].get_xticks(), expected_ax1)
  408. tm.assert_numpy_array_equal(axs[1].get_xticks(), expected_ax2)
  409. def test_subplots_constrained_layout(self):
  410. # GH 25261
  411. idx = date_range(start="now", periods=10)
  412. df = DataFrame(np.random.rand(10, 3), index=idx)
  413. kwargs = {}
  414. if hasattr(self.plt.Figure, "get_constrained_layout"):
  415. kwargs["constrained_layout"] = True
  416. fig, axes = self.plt.subplots(2, **kwargs)
  417. with tm.assert_produces_warning(None):
  418. df.plot(ax=axes[0])
  419. with tm.ensure_clean(return_filelike=True) as path:
  420. self.plt.savefig(path)
  421. @pytest.mark.parametrize(
  422. "index_name, old_label, new_label",
  423. [
  424. (None, "", "new"),
  425. ("old", "old", "new"),
  426. (None, "", ""),
  427. (None, "", 1),
  428. (None, "", [1, 2]),
  429. ],
  430. )
  431. @pytest.mark.parametrize("kind", ["line", "area", "bar"])
  432. def test_xlabel_ylabel_dataframe_subplots(
  433. self, kind, index_name, old_label, new_label
  434. ):
  435. # GH 9093
  436. df = DataFrame([[1, 2], [2, 5]], columns=["Type A", "Type B"])
  437. df.index.name = index_name
  438. # default is the ylabel is not shown and xlabel is index name
  439. axes = df.plot(kind=kind, subplots=True)
  440. assert all(ax.get_ylabel() == "" for ax in axes)
  441. assert all(ax.get_xlabel() == old_label for ax in axes)
  442. # old xlabel will be overridden and assigned ylabel will be used as ylabel
  443. axes = df.plot(kind=kind, ylabel=new_label, xlabel=new_label, subplots=True)
  444. assert all(ax.get_ylabel() == str(new_label) for ax in axes)
  445. assert all(ax.get_xlabel() == str(new_label) for ax in axes)
  446. @pytest.mark.parametrize(
  447. "kwargs",
  448. [
  449. # stacked center
  450. {"kind": "bar", "stacked": True},
  451. {"kind": "bar", "stacked": True, "width": 0.9},
  452. {"kind": "barh", "stacked": True},
  453. {"kind": "barh", "stacked": True, "width": 0.9},
  454. # center
  455. {"kind": "bar", "stacked": False},
  456. {"kind": "bar", "stacked": False, "width": 0.9},
  457. {"kind": "barh", "stacked": False},
  458. {"kind": "barh", "stacked": False, "width": 0.9},
  459. # subplots center
  460. {"kind": "bar", "subplots": True},
  461. {"kind": "bar", "subplots": True, "width": 0.9},
  462. {"kind": "barh", "subplots": True},
  463. {"kind": "barh", "subplots": True, "width": 0.9},
  464. # align edge
  465. {"kind": "bar", "stacked": True, "align": "edge"},
  466. {"kind": "bar", "stacked": True, "width": 0.9, "align": "edge"},
  467. {"kind": "barh", "stacked": True, "align": "edge"},
  468. {"kind": "barh", "stacked": True, "width": 0.9, "align": "edge"},
  469. {"kind": "bar", "stacked": False, "align": "edge"},
  470. {"kind": "bar", "stacked": False, "width": 0.9, "align": "edge"},
  471. {"kind": "barh", "stacked": False, "align": "edge"},
  472. {"kind": "barh", "stacked": False, "width": 0.9, "align": "edge"},
  473. {"kind": "bar", "subplots": True, "align": "edge"},
  474. {"kind": "bar", "subplots": True, "width": 0.9, "align": "edge"},
  475. {"kind": "barh", "subplots": True, "align": "edge"},
  476. {"kind": "barh", "subplots": True, "width": 0.9, "align": "edge"},
  477. ],
  478. )
  479. def test_bar_align_multiple_columns(self, kwargs):
  480. # GH2157
  481. df = DataFrame({"A": [3] * 5, "B": list(range(5))}, index=range(5))
  482. self._check_bar_alignment(df, **kwargs)
  483. @pytest.mark.parametrize(
  484. "kwargs",
  485. [
  486. {"kind": "bar", "stacked": False},
  487. {"kind": "bar", "stacked": True},
  488. {"kind": "barh", "stacked": False},
  489. {"kind": "barh", "stacked": True},
  490. {"kind": "bar", "subplots": True},
  491. {"kind": "barh", "subplots": True},
  492. ],
  493. )
  494. def test_bar_align_single_column(self, kwargs):
  495. df = DataFrame(np.random.randn(5))
  496. self._check_bar_alignment(df, **kwargs)
  497. @pytest.mark.parametrize(
  498. "kwargs",
  499. [
  500. {"kind": "bar", "stacked": False},
  501. {"kind": "bar", "stacked": True},
  502. {"kind": "barh", "stacked": False},
  503. {"kind": "barh", "stacked": True},
  504. {"kind": "bar", "subplots": True},
  505. {"kind": "barh", "subplots": True},
  506. ],
  507. )
  508. def test_bar_barwidth_position(self, kwargs):
  509. df = DataFrame(np.random.randn(5, 5))
  510. self._check_bar_alignment(df, width=0.9, position=0.2, **kwargs)
  511. @pytest.mark.parametrize("w", [1, 1.0])
  512. def test_bar_barwidth_position_int(self, w):
  513. # GH 12979
  514. df = DataFrame(np.random.randn(5, 5))
  515. ax = df.plot.bar(stacked=True, width=w)
  516. ticks = ax.xaxis.get_ticklocs()
  517. tm.assert_numpy_array_equal(ticks, np.array([0, 1, 2, 3, 4]))
  518. assert ax.get_xlim() == (-0.75, 4.75)
  519. # check left-edge of bars
  520. assert ax.patches[0].get_x() == -0.5
  521. assert ax.patches[-1].get_x() == 3.5
  522. def test_bar_barwidth_position_int_width_1(self):
  523. # GH 12979
  524. df = DataFrame(np.random.randn(5, 5))
  525. self._check_bar_alignment(df, kind="bar", stacked=True, width=1)
  526. self._check_bar_alignment(df, kind="barh", stacked=False, width=1)
  527. self._check_bar_alignment(df, kind="barh", stacked=True, width=1)
  528. self._check_bar_alignment(df, kind="bar", subplots=True, width=1)
  529. self._check_bar_alignment(df, kind="barh", subplots=True, width=1)
  530. def _check_bar_alignment(
  531. self,
  532. df,
  533. kind="bar",
  534. stacked=False,
  535. subplots=False,
  536. align="center",
  537. width=0.5,
  538. position=0.5,
  539. ):
  540. axes = df.plot(
  541. kind=kind,
  542. stacked=stacked,
  543. subplots=subplots,
  544. align=align,
  545. width=width,
  546. position=position,
  547. grid=True,
  548. )
  549. axes = self._flatten_visible(axes)
  550. for ax in axes:
  551. if kind == "bar":
  552. axis = ax.xaxis
  553. ax_min, ax_max = ax.get_xlim()
  554. min_edge = min(p.get_x() for p in ax.patches)
  555. max_edge = max(p.get_x() + p.get_width() for p in ax.patches)
  556. elif kind == "barh":
  557. axis = ax.yaxis
  558. ax_min, ax_max = ax.get_ylim()
  559. min_edge = min(p.get_y() for p in ax.patches)
  560. max_edge = max(p.get_y() + p.get_height() for p in ax.patches)
  561. else:
  562. raise ValueError
  563. # GH 7498
  564. # compare margins between lim and bar edges
  565. tm.assert_almost_equal(ax_min, min_edge - 0.25)
  566. tm.assert_almost_equal(ax_max, max_edge + 0.25)
  567. p = ax.patches[0]
  568. if kind == "bar" and (stacked is True or subplots is True):
  569. edge = p.get_x()
  570. center = edge + p.get_width() * position
  571. elif kind == "bar" and stacked is False:
  572. center = p.get_x() + p.get_width() * len(df.columns) * position
  573. edge = p.get_x()
  574. elif kind == "barh" and (stacked is True or subplots is True):
  575. center = p.get_y() + p.get_height() * position
  576. edge = p.get_y()
  577. elif kind == "barh" and stacked is False:
  578. center = p.get_y() + p.get_height() * len(df.columns) * position
  579. edge = p.get_y()
  580. else:
  581. raise ValueError
  582. # Check the ticks locates on integer
  583. assert (axis.get_ticklocs() == np.arange(len(df))).all()
  584. if align == "center":
  585. # Check whether the bar locates on center
  586. tm.assert_almost_equal(axis.get_ticklocs()[0], center)
  587. elif align == "edge":
  588. # Check whether the bar's edge starts from the tick
  589. tm.assert_almost_equal(axis.get_ticklocs()[0], edge)
  590. else:
  591. raise ValueError
  592. return axes