123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- from __future__ import annotations
- from dataclasses import dataclass
- from typing import ClassVar
- import numpy as np
- import matplotlib as mpl
- from seaborn._marks.base import (
- Mark,
- Mappable,
- MappableFloat,
- MappableString,
- MappableColor,
- resolve_properties,
- resolve_color,
- document_properties,
- )
@document_properties
@dataclass
class Path(Mark):
    """
    A mark connecting data points in the order they appear.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.
    Paths : A faster but less-flexible mark for drawing many paths.

    Examples
    --------
    .. include:: ../docstrings/objects.Path.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")
    marker: MappableString = Mappable(rc="lines.marker")
    pointsize: MappableFloat = Mappable(rc="lines.markersize")
    fillcolor: MappableColor = Mappable(depend="color")
    edgecolor: MappableColor = Mappable(depend="color")
    edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

    # When True, each group is stably sorted along the orient axis before it is
    # drawn, and missing values are not kept by split_gen. Subclasses flip it.
    _sort: ClassVar[bool] = False

    def _plot(self, split_gen, scales, orient):
        # Add one Line2D artist per group yielded by split_gen.
        for keys, data, ax in split_gen(keep_na=not self._sort):
            if self._sort:
                data = data.sort_values(orient, kind="mergesort")
            xs = data["x"].to_numpy()
            ys = data["y"].to_numpy()
            ax.add_line(self._make_line2d(xs, ys, keys, scales))

    def _legend_artist(self, variables, value, scales):
        # An empty Line2D styled like a plotted line for this legend value.
        legend_keys = {var: value for var in variables}
        return self._make_line2d([], [], legend_keys, scales)

    def _make_line2d(self, x, y, keys, scales):
        """Return a Line2D over (x, y) styled by the properties resolved for keys."""
        props = resolve_properties(self, keys, scales)
        props["color"] = resolve_color(self, keys, scales=scales)
        props["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
        props["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

        # Work on a copy so the capstyle adjustment never touches self.artist_kws.
        extra_kws = self.artist_kws.copy()
        self._handle_capstyle(extra_kws, props)

        return mpl.lines.Line2D(
            x,
            y,
            color=props["color"],
            linewidth=props["linewidth"],
            linestyle=props["linestyle"],
            marker=props["marker"],
            markersize=props["pointsize"],
            markerfacecolor=props["fillcolor"],
            markeredgecolor=props["edgecolor"],
            markeredgewidth=props["edgewidth"],
            **extra_kws,
        )

    def _handle_capstyle(self, kws, vals):
        # Workaround for https://github.com/matplotlib/matplotlib/issues/23437:
        # when the resolved linestyle carries no dash sequence (its second element
        # is None), mirror the solid capstyle onto dash_capstyle so both agree.
        if vals["linestyle"][1] is not None:
            return
        solid_capstyle = kws.get("solid_capstyle", mpl.rcParams["lines.solid_capstyle"])
        kws["dash_capstyle"] = solid_capstyle
@document_properties
@dataclass
class Line(Path):
    """
    A mark connecting data points with sorting along the orientation axis.

    See also
    --------
    Path : A mark connecting data points in the order they appear.
    Lines : A faster but less-flexible mark for drawing many lines.

    Examples
    --------
    .. include:: ../docstrings/objects.Line.rst

    """
    # All drawing logic is inherited from Path; enabling this flag makes Path
    # stably sort each group along the orient axis (and drop missing values)
    # before building its Line2D.
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Paths(Mark):
    """
    A faster but less-flexible mark for drawing many paths.

    See also
    --------
    Path : A mark connecting data points in the order they appear.

    Examples
    --------
    .. include:: ../docstrings/objects.Paths.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")

    # When True, each group's segments are stably sorted along the orient axis
    # (see _setup_segments) and missing values are not kept by split_gen.
    _sort: ClassVar[bool] = False

    def __post_init__(self):
        # LineCollection artists have a capstyle property but don't source its value
        # from the rc, so we do that manually here. Unfortunately, because we add
        # only one LineCollection per axes, we have to use the same capstyle for all
        # lines even when they are dashed. It's a slight inconsistency, but looks fine.
        # Copy first so a dict supplied by the caller is never mutated: "capstyle"
        # is not a Line2D argument (_legend_artist has to translate it), so leaking
        # it into a shared dict could break other marks that reuse that dict.
        self.artist_kws = dict(self.artist_kws)
        self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"])

    def _plot(self, split_gen, scales, orient):
        # Accumulate segments plus per-segment styles for each axes, then add a
        # single LineCollection per axes.
        line_data = {}
        for keys, data, ax in split_gen(keep_na=not self._sort):
            ax_data = line_data.setdefault(ax, {
                "segments": [],
                "colors": [],
                "linewidths": [],
                "linestyles": [],
            })

            segments = self._setup_segments(data, orient)
            ax_data["segments"].extend(segments)
            n = len(segments)

            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            # Every segment produced from this group shares the group's style.
            ax_data["colors"].extend([vals["color"]] * n)
            ax_data["linewidths"].extend([vals["linewidth"]] * n)
            ax_data["linestyles"].extend([vals["linestyle"]] * n)

        for ax, ax_data in line_data.items():
            lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
            # Handle datalim update manually
            # https://github.com/matplotlib/matplotlib/issues/23129
            ax.add_collection(lines, autolim=False)
            if ax_data["segments"]:
                xy = np.concatenate(ax_data["segments"])
                ax.update_datalim(xy)

    def _legend_artist(self, variables, value, scales):
        keys = {v: value for v in variables}
        vals = resolve_properties(self, keys, scales)
        # Resolve the color exactly as _plot does (resolve_color, which also
        # resolves alpha) so the legend entry matches the drawn lines.
        vals["color"] = resolve_color(self, keys, scales=scales)

        # Line2D takes separate solid/dash capstyles rather than "capstyle".
        # Fall back to the rc default if the key was removed after construction.
        artist_kws = self.artist_kws.copy()
        capstyle = artist_kws.pop("capstyle", mpl.rcParams["lines.solid_capstyle"])
        artist_kws["solid_capstyle"] = capstyle
        artist_kws["dash_capstyle"] = capstyle

        return mpl.lines.Line2D(
            [], [],
            color=vals["color"],
            linewidth=vals["linewidth"],
            linestyle=vals["linestyle"],
            **artist_kws,
        )

    def _setup_segments(self, data, orient):
        # Return a list holding a single (n, 2) array of x/y points for this group.
        if self._sort:
            data = data.sort_values(orient, kind="mergesort")

        # Column stack to avoid block consolidation
        xy = np.column_stack([data["x"], data["y"]])
        return [xy]
@document_properties
@dataclass
class Lines(Paths):
    """
    A faster but less-flexible mark for drawing many lines.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.

    Examples
    --------
    .. include:: ../docstrings/objects.Lines.rst

    """
    # Everything is inherited from Paths; this flag makes Paths stably sort each
    # group along the orient axis when building segments, and drop missing values.
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Range(Paths):
    """
    An oriented line mark drawn between min/max values.

    Examples
    --------
    .. include:: ../docstrings/objects.Range.rst

    """
    def _setup_segments(self, data, orient):
        # Build one segment per distinct orient value, running from the
        # "<other>min" value to the "<other>max" value on the other axis.
        # TODO: check the available variables more thoroughly.
        # TODO: decide how to handle data with only one of min/max present
        # (the column selection below would fail in that case).
        other = {"x": "y", "y": "x"}[orient]
        lo, hi = f"{other}min", f"{other}max"

        # With neither bound supplied, derive both from the raw values by
        # aggregating within each orient value.
        if lo not in data.columns and hi not in data.columns:
            bounds = {lo: (other, "min"), hi: (other, "max")}
            data = data.groupby(orient).agg(**bounds).reset_index()

        # Long form: one row per (orient value, bound), then keep x/y only.
        long_form = data[[orient, lo, hi]].melt(orient, value_name=other)
        long_form = long_form[["x", "y"]]
        return [group.to_numpy() for _, group in long_form.groupby(orient)]
@document_properties
@dataclass
class Dash(Paths):
    """
    A line mark drawn as an oriented segment for each datapoint.

    Examples
    --------
    .. include:: ../docstrings/objects.Dash.rst

    """
    width: MappableFloat = Mappable(.8, grouping=False)

    def _setup_segments(self, data, orient):
        # Each row becomes a two-point segment centered on its (x, y) and
        # extending width / 2 in both directions along the orient axis.
        axis = ["x", "y"].index(orient)
        centers = data[["x", "y"]].to_numpy().astype(float)

        # Shape (n, 2, 2): [row, endpoint, coordinate]; both endpoints start at
        # the center point.
        segments = np.repeat(centers[:, np.newaxis, :], 2, axis=1)

        half_width = data["width"] / 2
        segments[:, 0, axis] -= half_width
        segments[:, 1, axis] += half_width
        return segments
|