line.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import ClassVar
  4. import numpy as np
  5. import matplotlib as mpl
  6. from seaborn._marks.base import (
  7. Mark,
  8. Mappable,
  9. MappableFloat,
  10. MappableString,
  11. MappableColor,
  12. resolve_properties,
  13. resolve_color,
  14. document_properties,
  15. )
  @document_properties
  @dataclass
  class Path(Mark):
      """
      A mark connecting data points in the order they appear.

      See also
      --------
      Line : A mark connecting data points with sorting along the orientation axis.
      Paths : A faster but less-flexible mark for drawing many paths.

      Examples
      --------
      .. include:: ../docstrings/objects.Path.rst

      """
      color: MappableColor = Mappable("C0")
      alpha: MappableFloat = Mappable(1)
      linewidth: MappableFloat = Mappable(rc="lines.linewidth")
      linestyle: MappableString = Mappable(rc="lines.linestyle")
      marker: MappableString = Mappable(rc="lines.marker")
      pointsize: MappableFloat = Mappable(rc="lines.markersize")
      fillcolor: MappableColor = Mappable(depend="color")
      edgecolor: MappableColor = Mappable(depend="color")
      edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

      # When True, each group is sorted along the orientation axis (stable
      # mergesort) before it is drawn; subclasses such as ``Line`` opt in.
      _sort: ClassVar[bool] = False

      def _plot(self, split_gen, scales, orient):
          """Add one Line2D per data group to its target axes."""
          # keep_na is requested only when the path is drawn unsorted.
          groups = split_gen(keep_na=not self._sort)
          for keys, data, ax in groups:
              props = self._resolve_line_props(keys, scales)
              if self._sort:
                  data = data.sort_values(orient, kind="mergesort")
              xs = data["x"].to_numpy()
              ys = data["y"].to_numpy()
              ax.add_line(self._build_line2d(xs, ys, props))

      def _legend_artist(self, variables, value, scales):
          """Return an empty Line2D styled like the mark for a legend entry."""
          keys = dict.fromkeys(variables, value)
          props = self._resolve_line_props(keys, scales)
          return self._build_line2d([], [], props)

      def _resolve_line_props(self, keys, scales):
          """Resolve mark properties for ``keys``; colors go through resolve_color."""
          props = resolve_properties(self, keys, scales)
          props["color"] = resolve_color(self, keys, scales=scales)
          for prefix in ("fill", "edge"):
              props[f"{prefix}color"] = resolve_color(
                  self, keys, prefix=prefix, scales=scales
              )
          return props

      def _build_line2d(self, x, y, props):
          """Create a Line2D at x/y from resolved properties plus ``artist_kws``."""
          # Work on a copy so the capstyle workaround never leaks into artist_kws.
          kws = self.artist_kws.copy()
          self._handle_capstyle(kws, props)
          return mpl.lines.Line2D(
              x,
              y,
              color=props["color"],
              linewidth=props["linewidth"],
              linestyle=props["linestyle"],
              marker=props["marker"],
              markersize=props["pointsize"],
              markerfacecolor=props["fillcolor"],
              markeredgecolor=props["edgecolor"],
              markeredgewidth=props["edgewidth"],
              **kws,
          )

      def _handle_capstyle(self, kws, vals):
          """Mirror the solid capstyle into ``dash_capstyle`` for undashed lines."""
          # Work around for this matplotlib issue:
          # https://github.com/matplotlib/matplotlib/issues/23437
          # The resolved linestyle is indexed as a pair; a None second element
          # means "no dash pattern", i.e. a line that should look solid.
          if vals["linestyle"][1] is not None:
              return
          kws["dash_capstyle"] = kws.get(
              "solid_capstyle", mpl.rcParams["lines.solid_capstyle"]
          )
  @document_properties
  @dataclass
  class Line(Path):
      """
      A mark connecting data points with sorting along the orientation axis.

      See also
      --------
      Path : A mark connecting data points in the order they appear.
      Lines : A faster but less-flexible mark for drawing many lines.

      Examples
      --------
      .. include:: ../docstrings/objects.Line.rst

      """
      # Opt in to the sorting step of the inherited ``_plot``: each group is
      # ordered along the orientation axis (stable mergesort) before drawing.
      _sort: ClassVar[bool] = True
  @document_properties
  @dataclass
  class Paths(Mark):
      """
      A faster but less-flexible mark for drawing many paths.

      See also
      --------
      Path : A mark connecting data points in the order they appear.

      Examples
      --------
      .. include:: ../docstrings/objects.Paths.rst

      """
      color: MappableColor = Mappable("C0")
      alpha: MappableFloat = Mappable(1)
      linewidth: MappableFloat = Mappable(rc="lines.linewidth")
      linestyle: MappableString = Mappable(rc="lines.linestyle")

      # When True, each path's points are sorted along the orientation axis
      # (see ``_setup_segments``); subclasses such as ``Lines`` opt in.
      _sort: ClassVar[bool] = False

      def __post_init__(self):
          # LineCollection artists have a capstyle property but don't source its value
          # from the rc, so we do that manually here. Unfortunately, because we add
          # only one LineCollection, we have to use the same capstyle for all lines
          # even when they are dashed. It's a slight inconsistency, but looks fine IMO.
          # Build a new dict instead of calling setdefault in place, so a dict the
          # caller passed as ``artist_kws`` is not mutated. A user-supplied
          # capstyle still wins because it is unpacked last.
          self.artist_kws = {
              "capstyle": mpl.rcParams["lines.solid_capstyle"],
              **self.artist_kws,
          }

      def _plot(self, split_gen, scales, orient):
          """
          Draw all groups as a single LineCollection per axes.

          Segments and their per-segment styles are accumulated per axes first.
          Each axes then gets one collection, and its data limits are updated
          manually.
          """
          line_data = {}
          for keys, data, ax in split_gen(keep_na=not self._sort):

              if ax not in line_data:
                  line_data[ax] = {
                      "segments": [],
                      "colors": [],
                      "linewidths": [],
                      "linestyles": [],
                  }

              segments = self._setup_segments(data, orient)
              line_data[ax]["segments"].extend(segments)
              n = len(segments)

              vals = resolve_properties(self, keys, scales)
              vals["color"] = resolve_color(self, keys, scales=scales)

              # One style entry per segment, since a group may yield several.
              line_data[ax]["colors"].extend([vals["color"]] * n)
              line_data[ax]["linewidths"].extend([vals["linewidth"]] * n)
              line_data[ax]["linestyles"].extend([vals["linestyle"]] * n)

          for ax, ax_data in line_data.items():
              lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
              # Handle datalim update manually
              # https://github.com/matplotlib/matplotlib/issues/23129
              ax.add_collection(lines, autolim=False)
              if ax_data["segments"]:
                  xy = np.concatenate(ax_data["segments"])
                  ax.update_datalim(xy)

      def _legend_artist(self, variables, value, scales):
          """Return an empty Line2D styled like this mark's lines, for the legend."""
          keys = {v: value for v in variables}
          vals = resolve_properties(self, keys, scales)
          # Resolve the color exactly as ``_plot`` does, so the legend entry
          # matches the drawn lines (e.g. with respect to ``alpha``).
          vals["color"] = resolve_color(self, keys, scales=scales)

          artist_kws = self.artist_kws.copy()
          # Line2D has separate solid/dash capstyles while the collection has a
          # single ``capstyle``; apply the collection's value to both.
          capstyle = artist_kws.pop("capstyle", mpl.rcParams["lines.solid_capstyle"])
          artist_kws["solid_capstyle"] = capstyle
          artist_kws["dash_capstyle"] = capstyle

          return mpl.lines.Line2D(
              [], [],
              color=vals["color"],
              linewidth=vals["linewidth"],
              linestyle=vals["linestyle"],
              **artist_kws,
          )

      def _setup_segments(self, data, orient):
          """Return the group's points as a single (n, 2) x/y segment in a list."""
          if self._sort:
              data = data.sort_values(orient, kind="mergesort")

          # Column stack to avoid block consolidation
          xy = np.column_stack([data["x"], data["y"]])
          return [xy]
  @document_properties
  @dataclass
  class Lines(Paths):
      """
      A faster but less-flexible mark for drawing many lines.

      See also
      --------
      Line : A mark connecting data points with sorting along the orientation axis.

      Examples
      --------
      .. include:: ../docstrings/objects.Lines.rst

      """
      # Opt in to the inherited ``_setup_segments`` sorting: each path's points
      # are ordered along the orientation axis (stable mergesort).
      _sort: ClassVar[bool] = True
  @document_properties
  @dataclass
  class Range(Paths):
      """
      An oriented line mark drawn between min/max values.

      Examples
      --------
      .. include:: ../docstrings/objects.Range.rst

      """
      def _setup_segments(self, data, orient):
          """Return one x/y point array per position along the orientation axis."""
          # TODO better checks on what variables we have
          # TODO what if only one exist?
          # The value variable is the axis perpendicular to the orientation.
          other = {"x": "y", "y": "x"}[orient]
          lo, hi = f"{other}min", f"{other}max"

          # With neither bound supplied, derive both from the min/max of the
          # values observed at each position.
          if {lo, hi}.isdisjoint(data.columns):
              data = (
                  data
                  .groupby(orient)
                  .agg(**{lo: (other, "min"), hi: (other, "max")})
                  .reset_index()
              )

          # Melt to long form (one row per position/bound) holding the bound in
          # the ``other`` column, then keep just the x/y coordinates.
          long_form = data[[orient, lo, hi]].melt(orient, value_name=other)
          points = long_form[["x", "y"]]
          return [group.to_numpy() for _, group in points.groupby(orient)]
  @document_properties
  @dataclass
  class Dash(Paths):
      """
      A line mark drawn as an oriented segment for each datapoint.

      Examples
      --------
      .. include:: ../docstrings/objects.Dash.rst

      """
      width: MappableFloat = Mappable(.8, grouping=False)

      def _setup_segments(self, data, orient):
          """Return an (n, 2, 2) array: one two-endpoint segment per datapoint."""
          axis = ["x", "y"].index(orient)
          centers = data[["x", "y"]].to_numpy().astype(float)
          half_width = data["width"] / 2

          # Axis 0 indexes datapoints, axis 1 the (start, end) endpoints and
          # axis 2 the x/y coordinate. Both endpoints start at the datapoint and
          # are then pushed apart along the orientation axis.
          segments = np.stack([centers, centers], axis=1)
          segments[:, 0, axis] -= half_width
          segments[:, 1, axis] += half_width
          return segments