123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200 |
- from __future__ import annotations
- from dataclasses import dataclass
- import numpy as np
- import matplotlib as mpl
- from seaborn._marks.base import (
- Mark,
- Mappable,
- MappableBool,
- MappableFloat,
- MappableString,
- MappableColor,
- MappableStyle,
- resolve_properties,
- resolve_color,
- document_properties,
- )
- from typing import TYPE_CHECKING
- if TYPE_CHECKING:
- from typing import Any
- from matplotlib.artist import Artist
- from seaborn._core.scales import Scale
class DotBase(Mark):
    """Shared property resolution and drawing logic for point-like marks."""

    def _resolve_paths(self, data):
        """
        Return the transformed marker path(s) for the resolved ``marker`` value.

        A single ``MarkerStyle`` yields a single path. An iterable of marker
        styles yields a list with one path per element; each distinct marker
        object is transformed only once (results are cached by marker object).
        """
        marker = data["marker"]

        def get_transformed_path(m):
            return m.get_path().transformed(m.get_transform())

        if isinstance(marker, mpl.markers.MarkerStyle):
            return get_transformed_path(marker)

        # The same marker object can appear for many points, so reuse its
        # transformed path instead of recomputing it for every element.
        path_cache = {}
        paths = []
        for m in marker:
            if m not in path_cache:
                path_cache[m] = get_transformed_path(m)
            paths.append(path_cache[m])
        return paths

    def _resolve_properties(self, data, scales):
        """
        Resolve mark properties for one dot (``data`` is a dict) or many dots.

        On top of ``resolve_properties``, this adds ``path`` (marker path(s))
        and ``size`` (the squared ``pointsize``). It also restricts ``fill`` so
        that it is only true where the marker itself can be filled.
        """
        resolved = resolve_properties(self, data, scales)
        resolved["path"] = self._resolve_paths(resolved)
        # matplotlib collection sizes are areas (points**2), so square the diameter
        resolved["size"] = resolved["pointsize"] ** 2

        if isinstance(data, dict):  # Properties for single dot
            filled_marker = resolved["marker"].is_filled()
        else:
            filled_marker = [m.is_filled() for m in resolved["marker"]]

        # Elementwise AND instead of `*`: if `fill` were ever a plain Python
        # bool while `filled_marker` is a list, `*` would perform list
        # repetition (e.g. `False * [True]` -> `[]`) rather than masking.
        # logical_and also keeps the result boolean, even for empty input.
        resolved["fill"] = np.logical_and(resolved["fill"], filled_marker)

        return resolved

    def _plot(self, split_gen, scales, orient):
        """
        Add one ``PathCollection`` per data subset yielded by ``split_gen``.

        ``orient`` is accepted for interface compatibility but is not used here.
        """
        # TODO Not backcompat with allowed (but nonfunctional) univariate plots
        # (That should be solved upstream by defaulting to "" for unset x/y?)
        # (Be mindful of xmin/xmax, etc!)

        for _, data, ax in split_gen():

            # Capture positions before `data` is rebound to resolved properties
            offsets = np.column_stack([data["x"], data["y"]])
            data = self._resolve_properties(data, scales)

            points = mpl.collections.PathCollection(
                offsets=offsets,
                paths=data["path"],
                sizes=data["size"],
                facecolors=data["facecolor"],
                edgecolors=data["edgecolor"],
                linewidths=data["linewidth"],
                linestyles=data["edgestyle"],
                transOffset=ax.transData,
                transform=mpl.transforms.IdentityTransform(),
                **self.artist_kws,
            )
            ax.add_collection(points)

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """
        Build a single-dot ``PathCollection`` representing ``value``.

        Every variable in ``variables`` is set to ``value``. Properties are then
        resolved for that single dot.
        """
        key = dict.fromkeys(variables, value)
        res = self._resolve_properties(key, scales)

        # Properties resolved from a dict describe one dot, so wrap each
        # in a one-element list for the collection constructor.
        return mpl.collections.PathCollection(
            paths=[res["path"]],
            sizes=[res["size"]],
            facecolors=[res["facecolor"]],
            edgecolors=[res["edgecolor"]],
            linewidths=[res["linewidth"]],
            linestyles=[res["edgestyle"]],
            transform=mpl.transforms.IdentityTransform(),
            **self.artist_kws,
        )
@document_properties
@dataclass
class Dot(DotBase):
    """
    A mark suitable for dot plots or less-dense scatterplots.

    See also
    --------
    Dots : A dot mark defined by strokes to better handle overplotting.

    Examples
    --------
    .. include:: ../docstrings/objects.Dot.rst

    """
    marker: MappableString = Mappable("o", grouping=False)
    pointsize: MappableFloat = Mappable(6, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False)
    edgewidth: MappableFloat = Mappable(.5, grouping=False)  # TODO rcParam?
    edgestyle: MappableStyle = Mappable("-", grouping=False)

    def _resolve_properties(self, data, scales):
        """
        Resolve Dot properties and derive line width, edge color and face color.

        A filled dot is outlined with ``edgewidth`` and the ``edge`` color, and
        keeps its main color as the face. An unfilled dot is drawn with
        ``stroke`` and the main color, and its face alpha is scaled to zero.
        """
        resolved = super()._resolve_properties(data, scales)
        is_filled = resolved["fill"]

        resolved["linewidth"] = np.where(
            is_filled, resolved["edgewidth"], resolved["stroke"],
        )

        face_rgba = resolve_color(self, data, "", scales)
        edge_rgba = resolve_color(self, data, "edge", scales)

        # For many dots, add a trailing axis so the mask broadcasts across
        # the RGBA columns inside np.where.
        if np.isscalar(is_filled):
            mask = is_filled
        else:
            mask = is_filled[:, None]
        resolved["edgecolor"] = np.where(mask, edge_rgba, face_rgba)

        # The fill flag scales the face alpha (unfilled -> transparent face)
        alpha_scale = np.squeeze(mask)
        if isinstance(face_rgba, tuple):
            # TODO handle this in resolve_color
            resolved["facecolor"] = (*face_rgba[:3], face_rgba[3] * alpha_scale)
        else:
            rgb = face_rgba[:, :3]
            scaled_alpha = face_rgba[:, 3] * alpha_scale
            resolved["facecolor"] = np.c_[rgb, scaled_alpha]

        return resolved
@document_properties
@dataclass
class Dots(DotBase):
    """
    A dot mark defined by strokes to better handle overplotting.

    See also
    --------
    Dot : A mark suitable for dot plots or less-dense scatterplots.

    Examples
    --------
    .. include:: ../docstrings/objects.Dots.rst

    """
    # TODO retype marker as MappableMarker
    marker: MappableString = Mappable(rc="scatter.marker", grouping=False)
    pointsize: MappableFloat = Mappable(4, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)  # TODO auto alpha?
    fill: MappableBool = Mappable(True, grouping=False)
    fillcolor: MappableColor = Mappable(depend="color", grouping=False)
    fillalpha: MappableFloat = Mappable(.2, grouping=False)

    def _resolve_properties(self, data, scales):
        """
        Resolve Dots properties for drawing.

        ``stroke`` becomes the line width, the main color becomes the edge
        color, and the ``fill`` color becomes the face color. The face alpha
        is scaled by the resolved ``fill`` flag.
        """
        resolved = super()._resolve_properties(data, scales)
        resolved["linewidth"] = resolved.pop("stroke")
        resolved["facecolor"] = resolve_color(self, data, "fill", scales)
        resolved["edgecolor"] = resolve_color(self, data, "", scales)
        # This mark defines no edgestyle property; fall back to matplotlib's
        # (offset, dashes) form with no dash pattern, i.e. a solid line.
        resolved.setdefault("edgestyle", (0, None))

        fc = resolved["facecolor"]
        if isinstance(fc, tuple):
            resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"]
        else:
            # Build a new array instead of writing into the one returned by
            # resolve_color, so this method never mutates an array it does
            # not own (mirrors how Dot constructs its facecolor).
            resolved["facecolor"] = np.c_[fc[:, :3], fc[:, 3] * resolved["fill"]]
        return resolved
|