__init__.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from pandas.plotting._matplotlib.boxplot import (
  4. BoxPlot,
  5. boxplot,
  6. boxplot_frame,
  7. boxplot_frame_groupby,
  8. )
  9. from pandas.plotting._matplotlib.converter import (
  10. deregister,
  11. register,
  12. )
  13. from pandas.plotting._matplotlib.core import (
  14. AreaPlot,
  15. BarhPlot,
  16. BarPlot,
  17. HexBinPlot,
  18. LinePlot,
  19. PiePlot,
  20. ScatterPlot,
  21. )
  22. from pandas.plotting._matplotlib.hist import (
  23. HistPlot,
  24. KdePlot,
  25. hist_frame,
  26. hist_series,
  27. )
  28. from pandas.plotting._matplotlib.misc import (
  29. andrews_curves,
  30. autocorrelation_plot,
  31. bootstrap_plot,
  32. lag_plot,
  33. parallel_coordinates,
  34. radviz,
  35. scatter_matrix,
  36. )
  37. from pandas.plotting._matplotlib.tools import table
  38. if TYPE_CHECKING:
  39. from pandas.plotting._matplotlib.core import MPLPlot
  40. PLOT_CLASSES: dict[str, type[MPLPlot]] = {
  41. "line": LinePlot,
  42. "bar": BarPlot,
  43. "barh": BarhPlot,
  44. "box": BoxPlot,
  45. "hist": HistPlot,
  46. "kde": KdePlot,
  47. "area": AreaPlot,
  48. "pie": PiePlot,
  49. "scatter": ScatterPlot,
  50. "hexbin": HexBinPlot,
  51. }
  52. def plot(data, kind, **kwargs):
  53. # Importing pyplot at the top of the file (before the converters are
  54. # registered) causes problems in matplotlib 2 (converters seem to not
  55. # work)
  56. import matplotlib.pyplot as plt
  57. if kwargs.pop("reuse_plot", False):
  58. ax = kwargs.get("ax")
  59. if ax is None and len(plt.get_fignums()) > 0:
  60. with plt.rc_context():
  61. ax = plt.gca()
  62. kwargs["ax"] = getattr(ax, "left_ax", ax)
  63. plot_obj = PLOT_CLASSES[kind](data, **kwargs)
  64. plot_obj.generate()
  65. plot_obj.draw()
  66. return plot_obj.result
  67. __all__ = [
  68. "plot",
  69. "hist_series",
  70. "hist_frame",
  71. "boxplot",
  72. "boxplot_frame",
  73. "boxplot_frame_groupby",
  74. "table",
  75. "andrews_curves",
  76. "autocorrelation_plot",
  77. "bootstrap_plot",
  78. "lag_plot",
  79. "parallel_coordinates",
  80. "radviz",
  81. "scatter_matrix",
  82. "register",
  83. "deregister",
  84. ]