dot.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. import numpy as np
  4. import matplotlib as mpl
  5. from seaborn._marks.base import (
  6. Mark,
  7. Mappable,
  8. MappableBool,
  9. MappableFloat,
  10. MappableString,
  11. MappableColor,
  12. MappableStyle,
  13. resolve_properties,
  14. resolve_color,
  15. document_properties,
  16. )
  17. from typing import TYPE_CHECKING
  18. if TYPE_CHECKING:
  19. from typing import Any
  20. from matplotlib.artist import Artist
  21. from seaborn._core.scales import Scale
  22. class DotBase(Mark):
  23. def _resolve_paths(self, data):
  24. paths = []
  25. path_cache = {}
  26. marker = data["marker"]
  27. def get_transformed_path(m):
  28. return m.get_path().transformed(m.get_transform())
  29. if isinstance(marker, mpl.markers.MarkerStyle):
  30. return get_transformed_path(marker)
  31. for m in marker:
  32. if m not in path_cache:
  33. path_cache[m] = get_transformed_path(m)
  34. paths.append(path_cache[m])
  35. return paths
  36. def _resolve_properties(self, data, scales):
  37. resolved = resolve_properties(self, data, scales)
  38. resolved["path"] = self._resolve_paths(resolved)
  39. resolved["size"] = resolved["pointsize"] ** 2
  40. if isinstance(data, dict): # Properties for single dot
  41. filled_marker = resolved["marker"].is_filled()
  42. else:
  43. filled_marker = [m.is_filled() for m in resolved["marker"]]
  44. resolved["fill"] = resolved["fill"] * filled_marker
  45. return resolved
  46. def _plot(self, split_gen, scales, orient):
  47. # TODO Not backcompat with allowed (but nonfunctional) univariate plots
  48. # (That should be solved upstream by defaulting to "" for unset x/y?)
  49. # (Be mindful of xmin/xmax, etc!)
  50. for _, data, ax in split_gen():
  51. offsets = np.column_stack([data["x"], data["y"]])
  52. data = self._resolve_properties(data, scales)
  53. points = mpl.collections.PathCollection(
  54. offsets=offsets,
  55. paths=data["path"],
  56. sizes=data["size"],
  57. facecolors=data["facecolor"],
  58. edgecolors=data["edgecolor"],
  59. linewidths=data["linewidth"],
  60. linestyles=data["edgestyle"],
  61. transOffset=ax.transData,
  62. transform=mpl.transforms.IdentityTransform(),
  63. **self.artist_kws,
  64. )
  65. ax.add_collection(points)
  66. def _legend_artist(
  67. self, variables: list[str], value: Any, scales: dict[str, Scale],
  68. ) -> Artist:
  69. key = {v: value for v in variables}
  70. res = self._resolve_properties(key, scales)
  71. return mpl.collections.PathCollection(
  72. paths=[res["path"]],
  73. sizes=[res["size"]],
  74. facecolors=[res["facecolor"]],
  75. edgecolors=[res["edgecolor"]],
  76. linewidths=[res["linewidth"]],
  77. linestyles=[res["edgestyle"]],
  78. transform=mpl.transforms.IdentityTransform(),
  79. **self.artist_kws,
  80. )
  81. @document_properties
  82. @dataclass
  83. class Dot(DotBase):
  84. """
  85. A mark suitable for dot plots or less-dense scatterplots.
  86. See also
  87. --------
  88. Dots : A dot mark defined by strokes to better handle overplotting.
  89. Examples
  90. --------
  91. .. include:: ../docstrings/objects.Dot.rst
  92. """
  93. marker: MappableString = Mappable("o", grouping=False)
  94. pointsize: MappableFloat = Mappable(6, grouping=False) # TODO rcParam?
  95. stroke: MappableFloat = Mappable(.75, grouping=False) # TODO rcParam?
  96. color: MappableColor = Mappable("C0", grouping=False)
  97. alpha: MappableFloat = Mappable(1, grouping=False)
  98. fill: MappableBool = Mappable(True, grouping=False)
  99. edgecolor: MappableColor = Mappable(depend="color", grouping=False)
  100. edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False)
  101. edgewidth: MappableFloat = Mappable(.5, grouping=False) # TODO rcParam?
  102. edgestyle: MappableStyle = Mappable("-", grouping=False)
  103. def _resolve_properties(self, data, scales):
  104. resolved = super()._resolve_properties(data, scales)
  105. filled = resolved["fill"]
  106. main_stroke = resolved["stroke"]
  107. edge_stroke = resolved["edgewidth"]
  108. resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke)
  109. main_color = resolve_color(self, data, "", scales)
  110. edge_color = resolve_color(self, data, "edge", scales)
  111. if not np.isscalar(filled):
  112. # Expand dims to use in np.where with rgba arrays
  113. filled = filled[:, None]
  114. resolved["edgecolor"] = np.where(filled, edge_color, main_color)
  115. filled = np.squeeze(filled)
  116. if isinstance(main_color, tuple):
  117. # TODO handle this in resolve_color
  118. main_color = tuple([*main_color[:3], main_color[3] * filled])
  119. else:
  120. main_color = np.c_[main_color[:, :3], main_color[:, 3] * filled]
  121. resolved["facecolor"] = main_color
  122. return resolved
  123. @document_properties
  124. @dataclass
  125. class Dots(DotBase):
  126. """
  127. A dot mark defined by strokes to better handle overplotting.
  128. See also
  129. --------
  130. Dot : A mark suitable for dot plots or less-dense scatterplots.
  131. Examples
  132. --------
  133. .. include:: ../docstrings/objects.Dots.rst
  134. """
  135. # TODO retype marker as MappableMarker
  136. marker: MappableString = Mappable(rc="scatter.marker", grouping=False)
  137. pointsize: MappableFloat = Mappable(4, grouping=False) # TODO rcParam?
  138. stroke: MappableFloat = Mappable(.75, grouping=False) # TODO rcParam?
  139. color: MappableColor = Mappable("C0", grouping=False)
  140. alpha: MappableFloat = Mappable(1, grouping=False) # TODO auto alpha?
  141. fill: MappableBool = Mappable(True, grouping=False)
  142. fillcolor: MappableColor = Mappable(depend="color", grouping=False)
  143. fillalpha: MappableFloat = Mappable(.2, grouping=False)
  144. def _resolve_properties(self, data, scales):
  145. resolved = super()._resolve_properties(data, scales)
  146. resolved["linewidth"] = resolved.pop("stroke")
  147. resolved["facecolor"] = resolve_color(self, data, "fill", scales)
  148. resolved["edgecolor"] = resolve_color(self, data, "", scales)
  149. resolved.setdefault("edgestyle", (0, None))
  150. fc = resolved["facecolor"]
  151. if isinstance(fc, tuple):
  152. resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"]
  153. else:
  154. fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem?
  155. resolved["facecolor"] = fc
  156. return resolved