# test_frame_legend.py (8.5 KB, 214 lines)
  1. import numpy as np
  2. import pytest
  3. import pandas.util._test_decorators as td
  4. from pandas import (
  5. DataFrame,
  6. date_range,
  7. )
  8. from pandas.tests.plotting.common import TestPlotBase
  9. from pandas.util.version import Version
  # DataFrame.plot legend tests: labels, handles, titles and markers.
class TestFrameLegend(TestPlotBase):
    """Legend behaviour of ``DataFrame.plot``.

    Covers legend labels across plot kinds, on secondary y-axes and on shared
    axes, legend titles derived from column names, and legend markers.
    """

    @staticmethod
    def _legend_handles_compat(ax):
        """Return the handles of the legend attached to ``ax``.

        matplotlib 3.7 introduced ``Legend.legend_handles`` and deprecated the
        older ``Legend.legendHandles`` spelling; read whichever attribute the
        installed version provides so the version check lives in one place.
        """
        import matplotlib as mpl

        legend = ax.get_legend()
        if Version(mpl.__version__) < Version("3.7"):
            return legend.legendHandles
        return legend.legend_handles

    @pytest.mark.xfail(
        reason=(
            "Open bug in matplotlib "
            "https://github.com/matplotlib/matplotlib/issues/11357"
        )
    )
    def test_mixed_yerr(self):
        # https://github.com/pandas-dev/pandas/issues/39522
        from matplotlib.collections import LineCollection
        from matplotlib.lines import Line2D

        df = DataFrame([{"x": 1, "a": 1, "b": 1}, {"x": 2, "a": 2, "b": 3}])

        ax = df.plot("x", "a", c="orange", yerr=0.1, label="orange")
        df.plot("x", "b", c="blue", yerr=None, ax=ax, label="blue")

        result_handles = self._legend_handles_compat(ax)
        # The series drawn with yerr is expected to get a LineCollection
        # handle, the series without error bars a plain Line2D.
        assert isinstance(result_handles[0], LineCollection)
        assert isinstance(result_handles[1], Line2D)

    def test_legend_false(self):
        # https://github.com/pandas-dev/pandas/issues/40044
        df = DataFrame({"a": [1, 1], "b": [2, 3]})
        df2 = DataFrame({"d": [2.5, 2.5]})

        ax = df.plot(legend=True, color={"a": "blue", "b": "green"}, secondary_y="b")
        df2.plot(legend=True, color={"d": "red"}, ax=ax)

        # Every legend handle keeps the colour requested for its column,
        # including the secondary_y column and the column from the 2nd call.
        handles = self._legend_handles_compat(ax)
        result = [handle.get_color() for handle in handles]
        expected = ["blue", "green", "red"]
        assert result == expected

    @td.skip_if_no_scipy
    def test_df_legend_labels(self):
        kinds = ["line", "bar", "barh", "kde", "area", "hist"]
        df = DataFrame(np.random.rand(3, 3), columns=["a", "b", "c"])
        df2 = DataFrame(np.random.rand(3, 3), columns=["d", "e", "f"])
        df3 = DataFrame(np.random.rand(3, 3), columns=["g", "h", "i"])
        df4 = DataFrame(np.random.rand(3, 3), columns=["j", "k", "l"])

        for kind in kinds:
            ax = df.plot(kind=kind, legend=True)
            self._check_legend_labels(ax, labels=df.columns)

            # legend=False on a shared axes leaves the existing legend as is.
            ax = df2.plot(kind=kind, legend=False, ax=ax)
            self._check_legend_labels(ax, labels=df.columns)

            ax = df3.plot(kind=kind, legend=True, ax=ax)
            self._check_legend_labels(ax, labels=df.columns.union(df3.columns))

            # legend="reverse" appends the new labels in reversed order.
            ax = df4.plot(kind=kind, legend="reverse", ax=ax)
            expected = list(df.columns.union(df3.columns)) + list(reversed(df4.columns))
            self._check_legend_labels(ax, labels=expected)

        # Secondary Y: columns on the right axis are suffixed with " (right)".
        ax = df.plot(legend=True, secondary_y="b")
        self._check_legend_labels(ax, labels=["a", "b (right)", "c"])
        ax = df2.plot(legend=False, ax=ax)
        self._check_legend_labels(ax, labels=["a", "b (right)", "c"])
        ax = df3.plot(kind="bar", legend=True, secondary_y="h", ax=ax)
        self._check_legend_labels(
            ax, labels=["a", "b (right)", "c", "g", "h (right)", "i"]
        )

        # Time Series
        ind = date_range("1/1/2014", periods=3)
        df = DataFrame(np.random.randn(3, 3), columns=["a", "b", "c"], index=ind)
        df2 = DataFrame(np.random.randn(3, 3), columns=["d", "e", "f"], index=ind)
        df3 = DataFrame(np.random.randn(3, 3), columns=["g", "h", "i"], index=ind)
        ax = df.plot(legend=True, secondary_y="b")
        self._check_legend_labels(ax, labels=["a", "b (right)", "c"])
        ax = df2.plot(legend=False, ax=ax)
        self._check_legend_labels(ax, labels=["a", "b (right)", "c"])
        ax = df3.plot(legend=True, ax=ax)
        self._check_legend_labels(ax, labels=["a", "b (right)", "c", "g", "h", "i"])

        # scatter
        ax = df.plot.scatter(x="a", y="b", label="data1")
        self._check_legend_labels(ax, labels=["data1"])
        ax = df2.plot.scatter(x="d", y="e", legend=False, label="data2", ax=ax)
        self._check_legend_labels(ax, labels=["data1"])
        ax = df3.plot.scatter(x="g", y="h", label="data3", ax=ax)
        self._check_legend_labels(ax, labels=["data1", "data3"])

        # ensure label args pass through and
        # index name does not mutate
        # column names don't mutate
        df5 = df.set_index("a")
        ax = df5.plot(y="b")
        self._check_legend_labels(ax, labels=["b"])
        ax = df5.plot(y="b", label="LABEL_b")
        self._check_legend_labels(ax, labels=["LABEL_b"])
        self._check_text_labels(ax.xaxis.get_label(), "a")
        ax = df5.plot(y="c", label="LABEL_c", ax=ax)
        self._check_legend_labels(ax, labels=["LABEL_b", "LABEL_c"])
        assert df5.columns.tolist() == ["b", "c"]

    def test_missing_marker_multi_plots_on_same_ax(self):
        # GH 18222
        df = DataFrame(data=[[1, 1, 1, 1], [2, 2, 4, 8]], columns=["x", "r", "g", "b"])
        _, ax = self.plt.subplots(nrows=1, ncols=3)

        # The same three series are plotted in a different order on each axes;
        # the legend markers must follow the plotting order every time.

        # Left plot
        df.plot(x="x", y="r", linewidth=0, marker="o", color="r", ax=ax[0])
        df.plot(x="x", y="g", linewidth=1, marker="x", color="g", ax=ax[0])
        df.plot(x="x", y="b", linewidth=1, marker="o", color="b", ax=ax[0])
        self._check_legend_labels(ax[0], labels=["r", "g", "b"])
        self._check_legend_marker(ax[0], expected_markers=["o", "x", "o"])

        # Center plot
        df.plot(x="x", y="b", linewidth=1, marker="o", color="b", ax=ax[1])
        df.plot(x="x", y="r", linewidth=0, marker="o", color="r", ax=ax[1])
        df.plot(x="x", y="g", linewidth=1, marker="x", color="g", ax=ax[1])
        self._check_legend_labels(ax[1], labels=["b", "r", "g"])
        self._check_legend_marker(ax[1], expected_markers=["o", "o", "x"])

        # Right plot
        df.plot(x="x", y="g", linewidth=1, marker="x", color="g", ax=ax[2])
        df.plot(x="x", y="b", linewidth=1, marker="o", color="b", ax=ax[2])
        df.plot(x="x", y="r", linewidth=0, marker="o", color="r", ax=ax[2])
        self._check_legend_labels(ax[2], labels=["g", "b", "r"])
        self._check_legend_marker(ax[2], expected_markers=["x", "o", "o"])

    def test_legend_name(self):
        multi = DataFrame(
            np.random.randn(4, 4),
            columns=[np.array(["a", "a", "b", "b"]), np.array(["x", "y", "x", "y"])],
        )
        multi.columns.names = ["group", "individual"]

        # The legend title is built from the column level names.
        ax = multi.plot()
        leg_title = ax.legend_.get_title()
        self._check_text_labels(leg_title, "group,individual")

        # Plotting unnamed columns onto the same axes keeps the existing title.
        df = DataFrame(np.random.randn(5, 5))
        ax = df.plot(legend=True, ax=ax)
        leg_title = ax.legend_.get_title()
        self._check_text_labels(leg_title, "group,individual")

        # legend=False does not touch the title even if the columns are named.
        df.columns.name = "new"
        ax = df.plot(legend=False, ax=ax)
        leg_title = ax.legend_.get_title()
        self._check_text_labels(leg_title, "group,individual")

        # legend=True with named columns replaces the title.
        ax = df.plot(legend=True, ax=ax)
        leg_title = ax.legend_.get_title()
        self._check_text_labels(leg_title, "new")

    @pytest.mark.parametrize(
        "kind",
        [
            "line",
            "bar",
            "barh",
            pytest.param("kde", marks=td.skip_if_no_scipy),
            "area",
            "hist",
        ],
    )
    def test_no_legend(self, kind):
        df = DataFrame(np.random.rand(3, 3), columns=["a", "b", "c"])
        ax = df.plot(kind=kind, legend=False)
        self._check_legend_labels(ax, visible=False)

    def test_missing_markers_legend(self):
        # 14958
        df = DataFrame(np.random.randn(8, 3), columns=["A", "B", "C"])
        ax = df.plot(y=["A"], marker="x", linestyle="solid")
        df.plot(y=["B"], marker="o", linestyle="dotted", ax=ax)
        df.plot(y=["C"], marker="<", linestyle="dotted", ax=ax)

        self._check_legend_labels(ax, labels=["A", "B", "C"])
        self._check_legend_marker(ax, expected_markers=["x", "o", "<"])

    def test_missing_markers_legend_using_style(self):
        # 14563
        df = DataFrame(
            {
                "A": [1, 2, 3, 4, 5, 6],
                "B": [2, 4, 1, 3, 2, 4],
                "C": [3, 3, 2, 6, 4, 2],
                "X": [1, 2, 3, 4, 5, 6],
            }
        )

        _, ax = self.plt.subplots()
        # Markers given only through the ``style`` string must reach the legend.
        for column in "ABC":
            df.plot("X", column, label=column, ax=ax, style=".")

        self._check_legend_labels(ax, labels=["A", "B", "C"])
        self._check_legend_marker(ax, expected_markers=[".", ".", "."])