bar.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. from __future__ import annotations
  2. from collections import defaultdict
  3. from dataclasses import dataclass
  4. import numpy as np
  5. import matplotlib as mpl
  6. from seaborn._marks.base import (
  7. Mark,
  8. Mappable,
  9. MappableBool,
  10. MappableColor,
  11. MappableFloat,
  12. MappableStyle,
  13. resolve_properties,
  14. resolve_color,
  15. document_properties
  16. )
  17. from seaborn.utils import _version_predates
  18. from typing import TYPE_CHECKING
  19. if TYPE_CHECKING:
  20. from typing import Any
  21. from matplotlib.artist import Artist
  22. from seaborn._core.scales import Scale
  23. class BarBase(Mark):
  24. def _make_patches(self, data, scales, orient):
  25. transform = scales[orient]._matplotlib_scale.get_transform()
  26. forward = transform.transform
  27. reverse = transform.inverted().transform
  28. other = {"x": "y", "y": "x"}[orient]
  29. pos = reverse(forward(data[orient]) - data["width"] / 2)
  30. width = reverse(forward(data[orient]) + data["width"] / 2) - pos
  31. val = (data[other] - data["baseline"]).to_numpy()
  32. base = data["baseline"].to_numpy()
  33. kws = self._resolve_properties(data, scales)
  34. if orient == "x":
  35. kws.update(x=pos, y=base, w=width, h=val)
  36. else:
  37. kws.update(x=base, y=pos, w=val, h=width)
  38. kws.pop("width", None)
  39. kws.pop("baseline", None)
  40. val_dim = {"x": "h", "y": "w"}[orient]
  41. bars, vals = [], []
  42. for i in range(len(data)):
  43. row = {k: v[i] for k, v in kws.items()}
  44. # Skip bars with no value. It's possible we'll want to make this
  45. # an option (i.e so you have an artist for animating or annotating),
  46. # but let's keep things simple for now.
  47. if not np.nan_to_num(row[val_dim]):
  48. continue
  49. bar = mpl.patches.Rectangle(
  50. xy=(row["x"], row["y"]),
  51. width=row["w"],
  52. height=row["h"],
  53. facecolor=row["facecolor"],
  54. edgecolor=row["edgecolor"],
  55. linestyle=row["edgestyle"],
  56. linewidth=row["edgewidth"],
  57. **self.artist_kws,
  58. )
  59. bars.append(bar)
  60. vals.append(row[val_dim])
  61. return bars, vals
  62. def _resolve_properties(self, data, scales):
  63. resolved = resolve_properties(self, data, scales)
  64. resolved["facecolor"] = resolve_color(self, data, "", scales)
  65. resolved["edgecolor"] = resolve_color(self, data, "edge", scales)
  66. fc = resolved["facecolor"]
  67. if isinstance(fc, tuple):
  68. resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"]
  69. else:
  70. fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem?
  71. resolved["facecolor"] = fc
  72. return resolved
  73. def _legend_artist(
  74. self, variables: list[str], value: Any, scales: dict[str, Scale],
  75. ) -> Artist:
  76. # TODO return some sensible default?
  77. key = {v: value for v in variables}
  78. key = self._resolve_properties(key, scales)
  79. artist = mpl.patches.Patch(
  80. facecolor=key["facecolor"],
  81. edgecolor=key["edgecolor"],
  82. linewidth=key["edgewidth"],
  83. linestyle=key["edgestyle"],
  84. )
  85. return artist
  86. @document_properties
  87. @dataclass
  88. class Bar(BarBase):
  89. """
  90. A bar mark drawn between baseline and data values.
  91. See also
  92. --------
  93. Bars : A faster bar mark with defaults more suitable for histograms.
  94. Examples
  95. --------
  96. .. include:: ../docstrings/objects.Bar.rst
  97. """
  98. color: MappableColor = Mappable("C0", grouping=False)
  99. alpha: MappableFloat = Mappable(.7, grouping=False)
  100. fill: MappableBool = Mappable(True, grouping=False)
  101. edgecolor: MappableColor = Mappable(depend="color", grouping=False)
  102. edgealpha: MappableFloat = Mappable(1, grouping=False)
  103. edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
  104. edgestyle: MappableStyle = Mappable("-", grouping=False)
  105. # pattern: MappableString = Mappable(None) # TODO no Property yet
  106. width: MappableFloat = Mappable(.8, grouping=False)
  107. baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
  108. def _plot(self, split_gen, scales, orient):
  109. val_idx = ["y", "x"].index(orient)
  110. for _, data, ax in split_gen():
  111. bars, vals = self._make_patches(data, scales, orient)
  112. for bar in bars:
  113. # Because we are clipping the artist (see below), the edges end up
  114. # looking half as wide as they actually are. I don't love this clumsy
  115. # workaround, which is going to cause surprises if you work with the
  116. # artists directly. We may need to revisit after feedback.
  117. bar.set_linewidth(bar.get_linewidth() * 2)
  118. linestyle = bar.get_linestyle()
  119. if linestyle[1]:
  120. linestyle = (linestyle[0], tuple(x / 2 for x in linestyle[1]))
  121. bar.set_linestyle(linestyle)
  122. # This is a bit of a hack to handle the fact that the edge lines are
  123. # centered on the actual extents of the bar, and overlap when bars are
  124. # stacked or dodged. We may discover that this causes problems and needs
  125. # to be revisited at some point. Also it should be faster to clip with
  126. # a bbox than a path, but I cant't work out how to get the intersection
  127. # with the axes bbox.
  128. bar.set_clip_path(bar.get_path(), bar.get_transform() + ax.transData)
  129. if self.artist_kws.get("clip_on", True):
  130. # It seems the above hack undoes the default axes clipping
  131. bar.set_clip_box(ax.bbox)
  132. bar.sticky_edges[val_idx][:] = (0, np.inf)
  133. ax.add_patch(bar)
  134. # Add a container which is useful for, e.g. Axes.bar_label
  135. if _version_predates(mpl, "3.4"):
  136. container_kws = {}
  137. else:
  138. orientation = {"x": "vertical", "y": "horizontal"}[orient]
  139. container_kws = dict(datavalues=vals, orientation=orientation)
  140. container = mpl.container.BarContainer(bars, **container_kws)
  141. ax.add_container(container)
  142. @document_properties
  143. @dataclass
  144. class Bars(BarBase):
  145. """
  146. A faster bar mark with defaults more suitable for histograms.
  147. See also
  148. --------
  149. Bar : A bar mark drawn between baseline and data values.
  150. Examples
  151. --------
  152. .. include:: ../docstrings/objects.Bars.rst
  153. """
  154. color: MappableColor = Mappable("C0", grouping=False)
  155. alpha: MappableFloat = Mappable(.7, grouping=False)
  156. fill: MappableBool = Mappable(True, grouping=False)
  157. edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
  158. edgealpha: MappableFloat = Mappable(1, grouping=False)
  159. edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
  160. edgestyle: MappableStyle = Mappable("-", grouping=False)
  161. # pattern: MappableString = Mappable(None) # TODO no Property yet
  162. width: MappableFloat = Mappable(1, grouping=False)
  163. baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable?
  164. def _plot(self, split_gen, scales, orient):
  165. ori_idx = ["x", "y"].index(orient)
  166. val_idx = ["y", "x"].index(orient)
  167. patches = defaultdict(list)
  168. for _, data, ax in split_gen():
  169. bars, _ = self._make_patches(data, scales, orient)
  170. patches[ax].extend(bars)
  171. collections = {}
  172. for ax, ax_patches in patches.items():
  173. col = mpl.collections.PatchCollection(ax_patches, match_original=True)
  174. col.sticky_edges[val_idx][:] = (0, np.inf)
  175. ax.add_collection(col, autolim=False)
  176. collections[ax] = col
  177. # Workaround for matplotlib autoscaling bug
  178. # https://github.com/matplotlib/matplotlib/issues/11898
  179. # https://github.com/matplotlib/matplotlib/issues/23129
  180. xys = np.vstack([path.vertices for path in col.get_paths()])
  181. ax.update_datalim(xys)
  182. if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):
  183. for ax in collections:
  184. ax.autoscale_view()
  185. def get_dimensions(collection):
  186. edges, widths = [], []
  187. for verts in (path.vertices for path in collection.get_paths()):
  188. edges.append(min(verts[:, ori_idx]))
  189. widths.append(np.ptp(verts[:, ori_idx]))
  190. return np.array(edges), np.array(widths)
  191. min_width = np.inf
  192. for ax, col in collections.items():
  193. edges, widths = get_dimensions(col)
  194. points = 72 / ax.figure.dpi * abs(
  195. ax.transData.transform([edges + widths] * 2)
  196. - ax.transData.transform([edges] * 2)
  197. )
  198. min_width = min(min_width, min(points[:, ori_idx]))
  199. linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
  200. for _, col in collections.items():
  201. col.set_linewidth(linewidth)