123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256 |
- from __future__ import annotations
- from collections import defaultdict
- from dataclasses import dataclass
- import numpy as np
- import matplotlib as mpl
- from seaborn._marks.base import (
- Mark,
- Mappable,
- MappableBool,
- MappableColor,
- MappableFloat,
- MappableStyle,
- resolve_properties,
- resolve_color,
- document_properties
- )
- from seaborn.utils import _version_predates
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from typing import Any
- from matplotlib.artist import Artist
- from seaborn._core.scales import Scale
class BarBase(Mark):
    """Patch-building logic shared by the ``Bar`` and ``Bars`` marks."""

    def _make_patches(self, data, scales, orient):
        """
        Build one ``Rectangle`` per row of ``data``.

        Each bar is centered on its ``orient`` value. Its width is applied
        symmetrically in the transformed space of the orient-axis scale and
        then mapped back to data coordinates. It spans from ``baseline`` to
        the value of the other axis.

        Returns
        -------
        bars : list of matplotlib.patches.Rectangle
            Bars whose value (other-axis value minus baseline) is non-zero and
            not NaN.
        vals : list
            The value-axis extent of each kept bar, parallel to ``bars``.
        """
        scale_transform = scales[orient]._matplotlib_scale.get_transform()
        forward = scale_transform.transform
        reverse = scale_transform.inverted().transform
        # Axis holding the bar values, and the Rectangle dimension it drives.
        other, value_key = {"x": ("y", "h"), "y": ("x", "w")}[orient]

        center = forward(data[orient])
        half_width = data["width"] / 2
        pos = reverse(center - half_width)
        width = reverse(center + half_width) - pos

        baseline = data["baseline"]
        val = (data[other] - baseline).to_numpy()
        base = baseline.to_numpy()

        kws = self._resolve_properties(data, scales)
        if orient == "x":
            kws.update(x=pos, y=base, w=width, h=val)
        else:
            kws.update(x=base, y=pos, w=val, h=width)
        # Width and baseline are now folded into the geometry above.
        for consumed in ("width", "baseline"):
            kws.pop(consumed, None)

        bars, vals = [], []
        for i in range(len(data)):
            row = {key: column[i] for key, column in kws.items()}
            value = row[value_key]
            # Bars with a zero or NaN value get no artist. This could become
            # an option (e.g. to keep an artist for animating or annotating),
            # but it is kept simple for now.
            if not np.nan_to_num(value):
                continue
            bars.append(
                mpl.patches.Rectangle(
                    xy=(row["x"], row["y"]),
                    width=row["w"],
                    height=row["h"],
                    facecolor=row["facecolor"],
                    edgecolor=row["edgecolor"],
                    linestyle=row["edgestyle"],
                    linewidth=row["edgewidth"],
                    **self.artist_kws,
                )
            )
            vals.append(value)
        return bars, vals

    def _resolve_properties(self, data, scales):
        """
        Resolve mark properties, adding ``facecolor`` and ``edgecolor``.

        The face color's alpha channel is multiplied by ``fill``, so a false
        ``fill`` yields a face alpha of zero.
        """
        props = resolve_properties(self, data, scales)
        props["facecolor"] = resolve_color(self, data, "", scales)
        props["edgecolor"] = resolve_color(self, data, "edge", scales)

        face = props["facecolor"]
        fill = props["fill"]
        if isinstance(face, tuple):
            # Single RGBA color (e.g. when resolving a legend entry).
            props["facecolor"] = (*face[:3], face[3] * fill)
        else:
            # Array of RGBA rows; the alpha column is overwritten in place.
            # TODO Is inplace mod a problem?
            face[:, 3] = face[:, 3] * fill
            props["facecolor"] = face
        return props

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Return a ``Patch`` styled for ``value`` of the given legend variables."""
        # TODO return some sensible default?
        key = self._resolve_properties(dict.fromkeys(variables, value), scales)
        return mpl.patches.Patch(
            facecolor=key["facecolor"],
            edgecolor=key["edgecolor"],
            linewidth=key["edgewidth"],
            linestyle=key["edgestyle"],
        )
@document_properties
@dataclass
class Bar(BarBase):
    """
    A bar mark drawn between baseline and data values.

    See also
    --------
    Bars : A faster bar mark with defaults more suitable for histograms.

    Examples
    --------
    .. include:: ../docstrings/objects.Bar.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(.8, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """
        Draw every bar as its own ``Rectangle`` artist.

        For each group from ``split_gen``, the patches are adjusted (see
        ``_prepare_bar``) and added to the group's axes. They are then
        registered as one ``BarContainer``, which is useful for, e.g.,
        ``Axes.bar_label``.
        """
        val_idx = ["y", "x"].index(orient)

        for _, data, ax in split_gen():
            bars, vals = self._make_patches(data, scales, orient)

            for bar in bars:
                self._prepare_bar(bar, ax, val_idx)
                ax.add_patch(bar)

            # Only pass datavalues/orientation on matplotlib >= 3.4.
            if _version_predates(mpl, "3.4"):
                container_kws = {}
            else:
                container_kws = {
                    "datavalues": vals,
                    "orientation": {"x": "vertical", "y": "horizontal"}[orient],
                }
            ax.add_container(mpl.container.BarContainer(bars, **container_kws))

    def _prepare_bar(self, bar, ax, val_idx):
        """Adjust one bar so its edge is drawn only inside its own outline."""
        # The bar is clipped to its own path (below), so only half of the
        # edge line stays visible. Doubling the width compensates. Dash
        # lengths are halved, presumably because matplotlib scales dash
        # patterns with line width. This clumsy workaround may surprise
        # anyone working with the artists directly; revisit after feedback.
        bar.set_linewidth(bar.get_linewidth() * 2)
        style = bar.get_linestyle()
        if style[1]:
            style = (style[0], tuple(seg / 2 for seg in style[1]))
        bar.set_linestyle(style)

        # Edge lines are centered on the bar's extent and would overlap when
        # bars are stacked or dodged. Clipping to the bar's own path keeps
        # each edge inside its bar. A bbox clip should be faster, but it is
        # unclear how to intersect it with the axes bbox.
        bar.set_clip_path(bar.get_path(), bar.get_transform() + ax.transData)
        if self.artist_kws.get("clip_on", True):
            # The clip path above appears to undo the default clipping to the
            # axes, so restore it with the axes bbox.
            bar.set_clip_box(ax.bbox)
        bar.sticky_edges[val_idx][:] = (0, np.inf)
@document_properties
@dataclass
class Bars(BarBase):
    """
    A faster bar mark with defaults more suitable for histograms.

    See also
    --------
    Bar : A bar mark drawn between baseline and data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Bars.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(1, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """
        Draw the bars on each axes as a single ``PatchCollection``.

        Parameters
        ----------
        split_gen : callable
            Called with no arguments. It yields three-item tuples whose second
            and third items are a data subset and the axes to draw it on.
        scales : dict
            Scales keyed by variable name.
        orient : {"x", "y"}
            Axis along which the bars are positioned.
        """
        ori_idx = ["x", "y"].index(orient)
        val_idx = ["y", "x"].index(orient)

        patches = defaultdict(list)
        for _, data, ax in split_gen():
            bars, _ = self._make_patches(data, scales, orient)
            patches[ax].extend(bars)

        collections = {}
        for ax, ax_patches in patches.items():
            if not ax_patches:
                # _make_patches drops bars with zero or NaN values, so an axes
                # can end up with nothing to draw. np.vstack below (and the
                # width measurement later) would raise on an empty sequence.
                continue

            col = mpl.collections.PatchCollection(ax_patches, match_original=True)
            col.sticky_edges[val_idx][:] = (0, np.inf)
            ax.add_collection(col, autolim=False)
            collections[ax] = col

            # Workaround for matplotlib autoscaling bug
            # https://github.com/matplotlib/matplotlib/issues/11898
            # https://github.com/matplotlib/matplotlib/issues/23129
            xys = np.vstack([path.vertices for path in col.get_paths()])
            ax.update_datalim(xys)

        if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):
            # edgewidth was neither mapped nor set directly. Derive it from the
            # narrowest bar in display space: 10% of that bar's width in
            # points, capped at rcParams["patch.linewidth"].

            # Refresh view limits so transData reflects the new data limits.
            for ax in collections:
                ax.autoscale_view()

            def get_dimensions(collection):
                """Return (lower edges, widths) of each path along the orient axis."""
                coords = [path.vertices[:, ori_idx] for path in collection.get_paths()]
                edges = np.array([c.min() for c in coords])
                widths = np.array([np.ptp(c) for c in coords])
                return edges, widths

            min_width = np.inf
            for ax, col in collections.items():
                edges, widths = get_dimensions(col)
                # Transform each coordinate v as the point (v, v), so that column
                # ori_idx holds the display position along the orient axis for
                # every bar. Passing `[arr] * 2` directly gives a (2, n) array
                # that matplotlib reshapes row-major to (n, 2). That mispairs the
                # values and, for even n, measures only half of the bars.
                lo = ax.transData.transform(np.column_stack([edges, edges]))
                hi = ax.transData.transform(np.column_stack([edges + widths] * 2))
                points = 72 / ax.figure.dpi * np.abs(hi - lo)
                min_width = min(min_width, points[:, ori_idx].min())

            linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
            for col in collections.values():
                col.set_linewidth(linewidth)
|