moves.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import ClassVar, Callable, Optional, Union, cast
  4. import numpy as np
  5. from pandas import DataFrame
  6. from seaborn._core.groupby import GroupBy
  7. from seaborn._core.scales import Scale
  8. from seaborn._core.typing import Default
# Module-level sentinel meaning "argument not provided". Fields such as Jitter.width
# default to it and are checked with `is default`, so an explicit value (including 0)
# can be told apart from an omitted one.
default: Default = Default()
  10. @dataclass
  11. class Move:
  12. """Base class for objects that apply simple positional transforms."""
  13. group_by_orient: ClassVar[bool] = True
  14. def __call__(
  15. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  16. ) -> DataFrame:
  17. raise NotImplementedError
  18. @dataclass
  19. class Jitter(Move):
  20. """
  21. Random displacement along one or both axes to reduce overplotting.
  22. Parameters
  23. ----------
  24. width : float
  25. Magnitude of jitter, relative to mark width, along the orientation axis.
  26. If not provided, the default value will be 0 when `x` or `y` are set, otherwise
  27. there will be a small amount of jitter applied by default.
  28. x : float
  29. Magnitude of jitter, in data units, along the x axis.
  30. y : float
  31. Magnitude of jitter, in data units, along the y axis.
  32. Examples
  33. --------
  34. .. include:: ../docstrings/objects.Jitter.rst
  35. """
  36. width: float | Default = default
  37. x: float = 0
  38. y: float = 0
  39. seed: int | None = None
  40. def __call__(
  41. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  42. ) -> DataFrame:
  43. data = data.copy()
  44. rng = np.random.default_rng(self.seed)
  45. def jitter(data, col, scale):
  46. noise = rng.uniform(-.5, +.5, len(data))
  47. offsets = noise * scale
  48. return data[col] + offsets
  49. if self.width is default:
  50. width = 0.0 if self.x or self.y else 0.2
  51. else:
  52. width = cast(float, self.width)
  53. if self.width:
  54. data[orient] = jitter(data, orient, width * data["width"])
  55. if self.x:
  56. data["x"] = jitter(data, "x", self.x)
  57. if self.y:
  58. data["y"] = jitter(data, "y", self.y)
  59. return data
  60. @dataclass
  61. class Dodge(Move):
  62. """
  63. Displacement and narrowing of overlapping marks along orientation axis.
  64. Parameters
  65. ----------
  66. empty : {'keep', 'drop', 'fill'}
  67. gap : float
  68. Size of gap between dodged marks.
  69. by : list of variable names
  70. Variables to apply the movement to, otherwise use all.
  71. Examples
  72. --------
  73. .. include:: ../docstrings/objects.Dodge.rst
  74. """
  75. empty: str = "keep" # Options: keep, drop, fill
  76. gap: float = 0
  77. # TODO accept just a str here?
  78. # TODO should this always be present?
  79. # TODO should the default be an "all" singleton?
  80. by: Optional[list[str]] = None
  81. def __call__(
  82. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  83. ) -> DataFrame:
  84. grouping_vars = [v for v in groupby.order if v in data]
  85. groups = groupby.agg(data, {"width": "max"})
  86. if self.empty == "fill":
  87. groups = groups.dropna()
  88. def groupby_pos(s):
  89. grouper = [groups[v] for v in [orient, "col", "row"] if v in data]
  90. return s.groupby(grouper, sort=False, observed=True)
  91. def scale_widths(w):
  92. # TODO what value to fill missing widths??? Hard problem...
  93. # TODO short circuit this if outer widths has no variance?
  94. empty = 0 if self.empty == "fill" else w.mean()
  95. filled = w.fillna(empty)
  96. scale = filled.max()
  97. norm = filled.sum()
  98. if self.empty == "keep":
  99. w = filled
  100. return w / norm * scale
  101. def widths_to_offsets(w):
  102. return w.shift(1).fillna(0).cumsum() + (w - w.sum()) / 2
  103. new_widths = groupby_pos(groups["width"]).transform(scale_widths)
  104. offsets = groupby_pos(new_widths).transform(widths_to_offsets)
  105. if self.gap:
  106. new_widths *= 1 - self.gap
  107. groups["_dodged"] = groups[orient] + offsets
  108. groups["width"] = new_widths
  109. out = (
  110. data
  111. .drop("width", axis=1)
  112. .merge(groups, on=grouping_vars, how="left")
  113. .drop(orient, axis=1)
  114. .rename(columns={"_dodged": orient})
  115. )
  116. return out
  117. @dataclass
  118. class Stack(Move):
  119. """
  120. Displacement of overlapping bar or area marks along the value axis.
  121. Examples
  122. --------
  123. .. include:: ../docstrings/objects.Stack.rst
  124. """
  125. # TODO center? (or should this be a different move, eg. Stream())
  126. def _stack(self, df, orient):
  127. # TODO should stack do something with ymin/ymax style marks?
  128. # Should there be an upstream conversion to baseline/height parameterization?
  129. if df["baseline"].nunique() > 1:
  130. err = "Stack move cannot be used when baselines are already heterogeneous"
  131. raise RuntimeError(err)
  132. other = {"x": "y", "y": "x"}[orient]
  133. stacked_lengths = (df[other] - df["baseline"]).dropna().cumsum()
  134. offsets = stacked_lengths.shift(1).fillna(0)
  135. df[other] = stacked_lengths
  136. df["baseline"] = df["baseline"] + offsets
  137. return df
  138. def __call__(
  139. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  140. ) -> DataFrame:
  141. # TODO where to ensure that other semantic variables are sorted properly?
  142. # TODO why are we not using the passed in groupby here?
  143. groupers = ["col", "row", orient]
  144. return GroupBy(groupers).apply(data, self._stack, orient)
  145. @dataclass
  146. class Shift(Move):
  147. """
  148. Displacement of all marks with the same magnitude / direction.
  149. Parameters
  150. ----------
  151. x, y : float
  152. Magnitude of shift, in data units, along each axis.
  153. Examples
  154. --------
  155. .. include:: ../docstrings/objects.Shift.rst
  156. """
  157. x: float = 0
  158. y: float = 0
  159. def __call__(
  160. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  161. ) -> DataFrame:
  162. data = data.copy(deep=False)
  163. data["x"] = data["x"] + self.x
  164. data["y"] = data["y"] + self.y
  165. return data
  166. @dataclass
  167. class Norm(Move):
  168. """
  169. Divisive scaling on the value axis after aggregating within groups.
  170. Parameters
  171. ----------
  172. func : str or callable
  173. Function called on each group to define the comparison value.
  174. where : str
  175. Query string defining the subset used to define the comparison values.
  176. by : list of variables
  177. Variables used to define aggregation groups.
  178. percent : bool
  179. If True, multiply the result by 100.
  180. Examples
  181. --------
  182. .. include:: ../docstrings/objects.Norm.rst
  183. """
  184. func: Union[Callable, str] = "max"
  185. where: Optional[str] = None
  186. by: Optional[list[str]] = None
  187. percent: bool = False
  188. group_by_orient: ClassVar[bool] = False
  189. def _norm(self, df, var):
  190. if self.where is None:
  191. denom_data = df[var]
  192. else:
  193. denom_data = df.query(self.where)[var]
  194. df[var] = df[var] / denom_data.agg(self.func)
  195. if self.percent:
  196. df[var] = df[var] * 100
  197. return df
  198. def __call__(
  199. self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
  200. ) -> DataFrame:
  201. other = {"x": "y", "y": "x"}[orient]
  202. return groupby.apply(data, self._norm, other)
  203. # TODO
  204. # @dataclass
  205. # class Ridge(Move):
  206. # ...