scales.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092
  1. from __future__ import annotations
  2. import re
  3. from copy import copy
  4. from collections.abc import Sequence
  5. from dataclasses import dataclass
  6. from functools import partial
  7. from typing import Any, Callable, Tuple, Optional, ClassVar
  8. import numpy as np
  9. import matplotlib as mpl
  10. from matplotlib.ticker import (
  11. Locator,
  12. Formatter,
  13. AutoLocator,
  14. AutoMinorLocator,
  15. FixedLocator,
  16. LinearLocator,
  17. LogLocator,
  18. SymmetricalLogLocator,
  19. MaxNLocator,
  20. MultipleLocator,
  21. EngFormatter,
  22. FuncFormatter,
  23. LogFormatterSciNotation,
  24. ScalarFormatter,
  25. StrMethodFormatter,
  26. )
  27. from matplotlib.dates import (
  28. AutoDateLocator,
  29. AutoDateFormatter,
  30. ConciseDateFormatter,
  31. )
  32. from matplotlib.axis import Axis
  33. from matplotlib.scale import ScaleBase
  34. from pandas import Series
  35. from seaborn._core.rules import categorical_order
  36. from seaborn._core.typing import Default, default
  37. from typing import TYPE_CHECKING
  38. if TYPE_CHECKING:
  39. from seaborn._core.plot import Plot
  40. from seaborn._core.properties import Property
  41. from numpy.typing import ArrayLike, NDArray
  42. TransFuncs = Tuple[
  43. Callable[[ArrayLike], ArrayLike], Callable[[ArrayLike], ArrayLike]
  44. ]
  45. # TODO Reverting typing to Any as it was proving too complicated to
  46. # work out the right way to communicate the types to mypy. Revisit!
  47. Pipeline = Sequence[Optional[Callable[[Any], Any]]]
  48. class Scale:
  49. """Base class for objects that map data values to visual properties."""
  50. values: tuple | str | list | dict | None
  51. _priority: ClassVar[int]
  52. _pipeline: Pipeline
  53. _matplotlib_scale: ScaleBase
  54. _spacer: staticmethod
  55. _legend: tuple[list[Any], list[str]] | None
  56. def __post_init__(self):
  57. self._tick_params = None
  58. self._label_params = None
  59. self._legend = None
  60. def tick(self):
  61. raise NotImplementedError()
  62. def label(self):
  63. raise NotImplementedError()
  64. def _get_locators(self):
  65. raise NotImplementedError()
  66. def _get_formatter(self, locator: Locator | None = None):
  67. raise NotImplementedError()
  68. def _get_scale(self, name: str, forward: Callable, inverse: Callable):
  69. major_locator, minor_locator = self._get_locators(**self._tick_params)
  70. major_formatter = self._get_formatter(major_locator, **self._label_params)
  71. class InternalScale(mpl.scale.FuncScale):
  72. def set_default_locators_and_formatters(self, axis):
  73. axis.set_major_locator(major_locator)
  74. if minor_locator is not None:
  75. axis.set_minor_locator(minor_locator)
  76. axis.set_major_formatter(major_formatter)
  77. return InternalScale(name, (forward, inverse))
  78. def _spacing(self, x: Series) -> float:
  79. space = self._spacer(x)
  80. if np.isnan(space):
  81. # This happens when there is no variance in the orient coordinate data
  82. # Not exactly clear what the right default is, but 1 seems reasonable?
  83. return 1
  84. return space
  85. def _setup(
  86. self, data: Series, prop: Property, axis: Axis | None = None,
  87. ) -> Scale:
  88. raise NotImplementedError()
  89. def _finalize(self, p: Plot, axis: Axis) -> None:
  90. """Perform scale-specific axis tweaks after adding artists."""
  91. pass
  92. def __call__(self, data: Series) -> ArrayLike:
  93. trans_data: Series | NDArray | list
  94. # TODO sometimes we need to handle scalars (e.g. for Line)
  95. # but what is the best way to do that?
  96. scalar_data = np.isscalar(data)
  97. if scalar_data:
  98. trans_data = np.array([data])
  99. else:
  100. trans_data = data
  101. for func in self._pipeline:
  102. if func is not None:
  103. trans_data = func(trans_data)
  104. if scalar_data:
  105. return trans_data[0]
  106. else:
  107. return trans_data
  108. @staticmethod
  109. def _identity():
  110. class Identity(Scale):
  111. _pipeline = []
  112. _spacer = None
  113. _legend = None
  114. _matplotlib_scale = None
  115. return Identity()
  116. @dataclass
  117. class Boolean(Scale):
  118. """
  119. A scale with a discrete domain of True and False values.
  120. The behavior is similar to the :class:`Nominal` scale, but property
  121. mappings and legends will use a [True, False] ordering rather than
  122. a sort using numeric rules. Coordinate variables accomplish this by
  123. inverting axis limits so as to maintain underlying numeric positioning.
  124. Input data are cast to boolean values, respecting missing data.
  125. """
  126. values: tuple | list | dict | None = None
  127. _priority: ClassVar[int] = 3
  128. def _setup(
  129. self, data: Series, prop: Property, axis: Axis | None = None,
  130. ) -> Scale:
  131. new = copy(self)
  132. if new._tick_params is None:
  133. new = new.tick()
  134. if new._label_params is None:
  135. new = new.label()
  136. def na_safe_cast(x):
  137. # TODO this doesn't actually need to be a closure
  138. if np.isscalar(x):
  139. return float(bool(x))
  140. else:
  141. if hasattr(x, "notna"):
  142. # Handle pd.NA; np<>pd interop with NA is tricky
  143. use = x.notna().to_numpy()
  144. else:
  145. use = np.isfinite(x)
  146. out = np.full(len(x), np.nan, dtype=float)
  147. out[use] = x[use].astype(bool).astype(float)
  148. return out
  149. new._pipeline = [na_safe_cast, prop.get_mapping(new, data)]
  150. new._spacer = _default_spacer
  151. if prop.legend:
  152. new._legend = [True, False], ["True", "False"]
  153. forward, inverse = _make_identity_transforms()
  154. mpl_scale = new._get_scale(str(data.name), forward, inverse)
  155. axis = PseudoAxis(mpl_scale) if axis is None else axis
  156. mpl_scale.set_default_locators_and_formatters(axis)
  157. new._matplotlib_scale = mpl_scale
  158. return new
  159. def _finalize(self, p: Plot, axis: Axis) -> None:
  160. # We want values to appear in a True, False order but also want
  161. # True/False to be drawn at 1/0 positions respectively to avoid nasty
  162. # surprises if additional artists are added through the matplotlib API.
  163. # We accomplish this using axis inversion akin to what we do in Nominal.
  164. ax = axis.axes
  165. name = axis.axis_name
  166. axis.grid(False, which="both")
  167. if name not in p._limits:
  168. nticks = len(axis.get_major_ticks())
  169. lo, hi = -.5, nticks - .5
  170. if name == "x":
  171. lo, hi = hi, lo
  172. set_lim = getattr(ax, f"set_{name}lim")
  173. set_lim(lo, hi, auto=None)
  174. def tick(self, locator: Locator | None = None):
  175. new = copy(self)
  176. new._tick_params = {"locator": locator}
  177. return new
  178. def label(self, formatter: Formatter | None = None):
  179. new = copy(self)
  180. new._label_params = {"formatter": formatter}
  181. return new
  182. def _get_locators(self, locator):
  183. if locator is not None:
  184. return locator
  185. return FixedLocator([0, 1]), None
  186. def _get_formatter(self, locator, formatter):
  187. if formatter is not None:
  188. return formatter
  189. return FuncFormatter(lambda x, _: str(bool(x)))
  190. @dataclass
  191. class Nominal(Scale):
  192. """
  193. A categorical scale without relative importance / magnitude.
  194. """
  195. # Categorical (convert to strings), un-sortable
  196. values: tuple | str | list | dict | None = None
  197. order: list | None = None
  198. _priority: ClassVar[int] = 4
  199. def _setup(
  200. self, data: Series, prop: Property, axis: Axis | None = None,
  201. ) -> Scale:
  202. new = copy(self)
  203. if new._tick_params is None:
  204. new = new.tick()
  205. if new._label_params is None:
  206. new = new.label()
  207. # TODO flexibility over format() which isn't great for numbers / dates
  208. stringify = np.vectorize(format, otypes=["object"])
  209. units_seed = categorical_order(data, new.order)
  210. # TODO move to Nominal._get_scale?
  211. # TODO this needs some more complicated rethinking about how to pass
  212. # a unit dictionary down to these methods, along with how much we want
  213. # to invest in their API. What is it useful for tick() to do here?
  214. # (Ordinal may be different if we draw that contrast).
  215. # Any customization we do to allow, e.g., label wrapping will probably
  216. # require defining our own Formatter subclass.
  217. # We could also potentially implement auto-wrapping in an Axis subclass
  218. # (see Axis.draw ... it already is computing the bboxes).
  219. # major_locator, minor_locator = new._get_locators(**new._tick_params)
  220. # major_formatter = new._get_formatter(major_locator, **new._label_params)
  221. class CatScale(mpl.scale.LinearScale):
  222. name = None # To work around mpl<3.4 compat issues
  223. def set_default_locators_and_formatters(self, axis):
  224. ...
  225. # axis.set_major_locator(major_locator)
  226. # if minor_locator is not None:
  227. # axis.set_minor_locator(minor_locator)
  228. # axis.set_major_formatter(major_formatter)
  229. mpl_scale = CatScale(data.name)
  230. if axis is None:
  231. axis = PseudoAxis(mpl_scale)
  232. # TODO Currently just used in non-Coordinate contexts, but should
  233. # we use this to (A) set the padding we want for categorial plots
  234. # and (B) allow the values parameter for a Coordinate to set xlim/ylim
  235. axis.set_view_interval(0, len(units_seed) - 1)
  236. new._matplotlib_scale = mpl_scale
  237. # TODO array cast necessary to handle float/int mixture, which we need
  238. # to solve in a more systematic way probably
  239. # (i.e. if we have [1, 2.5], do we want [1.0, 2.5]? Unclear)
  240. axis.update_units(stringify(np.array(units_seed)))
  241. # TODO define this more centrally
  242. def convert_units(x):
  243. # TODO only do this with explicit order?
  244. # (But also category dtype?)
  245. # TODO isin fails when units_seed mixes numbers and strings (numpy error?)
  246. # but np.isin also does not seem any faster? (Maybe not broadcasting in C)
  247. # keep = x.isin(units_seed)
  248. keep = np.array([x_ in units_seed for x_ in x], bool)
  249. out = np.full(len(x), np.nan)
  250. out[keep] = axis.convert_units(stringify(x[keep]))
  251. return out
  252. new._pipeline = [convert_units, prop.get_mapping(new, data)]
  253. new._spacer = _default_spacer
  254. if prop.legend:
  255. new._legend = units_seed, list(stringify(units_seed))
  256. return new
  257. def _finalize(self, p: Plot, axis: Axis) -> None:
  258. ax = axis.axes
  259. name = axis.axis_name
  260. axis.grid(False, which="both")
  261. if name not in p._limits:
  262. nticks = len(axis.get_major_ticks())
  263. lo, hi = -.5, nticks - .5
  264. if name == "y":
  265. lo, hi = hi, lo
  266. set_lim = getattr(ax, f"set_{name}lim")
  267. set_lim(lo, hi, auto=None)
  268. def tick(self, locator: Locator | None = None) -> Nominal:
  269. """
  270. Configure the selection of ticks for the scale's axis or legend.
  271. .. note::
  272. This API is under construction and will be enhanced over time.
  273. At the moment, it is probably not very useful.
  274. Parameters
  275. ----------
  276. locator : :class:`matplotlib.ticker.Locator` subclass
  277. Pre-configured matplotlib locator; other parameters will not be used.
  278. Returns
  279. -------
  280. Copy of self with new tick configuration.
  281. """
  282. new = copy(self)
  283. new._tick_params = {"locator": locator}
  284. return new
  285. def label(self, formatter: Formatter | None = None) -> Nominal:
  286. """
  287. Configure the selection of labels for the scale's axis or legend.
  288. .. note::
  289. This API is under construction and will be enhanced over time.
  290. At the moment, it is probably not very useful.
  291. Parameters
  292. ----------
  293. formatter : :class:`matplotlib.ticker.Formatter` subclass
  294. Pre-configured matplotlib formatter; other parameters will not be used.
  295. Returns
  296. -------
  297. scale
  298. Copy of self with new tick configuration.
  299. """
  300. new = copy(self)
  301. new._label_params = {"formatter": formatter}
  302. return new
  303. def _get_locators(self, locator):
  304. if locator is not None:
  305. return locator, None
  306. locator = mpl.category.StrCategoryLocator({})
  307. return locator, None
  308. def _get_formatter(self, locator, formatter):
  309. if formatter is not None:
  310. return formatter
  311. formatter = mpl.category.StrCategoryFormatter({})
  312. return formatter
  313. @dataclass
  314. class Ordinal(Scale):
  315. # Categorical (convert to strings), sortable, can skip ticklabels
  316. ...
  317. @dataclass
  318. class Discrete(Scale):
  319. # Numeric, integral, can skip ticks/ticklabels
  320. ...
  321. @dataclass
  322. class ContinuousBase(Scale):
  323. values: tuple | str | None = None
  324. norm: tuple | None = None
  325. def _setup(
  326. self, data: Series, prop: Property, axis: Axis | None = None,
  327. ) -> Scale:
  328. new = copy(self)
  329. if new._tick_params is None:
  330. new = new.tick()
  331. if new._label_params is None:
  332. new = new.label()
  333. forward, inverse = new._get_transform()
  334. mpl_scale = new._get_scale(str(data.name), forward, inverse)
  335. if axis is None:
  336. axis = PseudoAxis(mpl_scale)
  337. axis.update_units(data)
  338. mpl_scale.set_default_locators_and_formatters(axis)
  339. new._matplotlib_scale = mpl_scale
  340. normalize: Optional[Callable[[ArrayLike], ArrayLike]]
  341. if prop.normed:
  342. if new.norm is None:
  343. vmin, vmax = data.min(), data.max()
  344. else:
  345. vmin, vmax = new.norm
  346. vmin, vmax = map(float, axis.convert_units((vmin, vmax)))
  347. a = forward(vmin)
  348. b = forward(vmax) - forward(vmin)
  349. def normalize(x):
  350. return (x - a) / b
  351. else:
  352. normalize = vmin = vmax = None
  353. new._pipeline = [
  354. axis.convert_units,
  355. forward,
  356. normalize,
  357. prop.get_mapping(new, data)
  358. ]
  359. def spacer(x):
  360. x = x.dropna().unique()
  361. if len(x) < 2:
  362. return np.nan
  363. return np.min(np.diff(np.sort(x)))
  364. new._spacer = spacer
  365. # TODO How to allow disabling of legend for all uses of property?
  366. # Could add a Scale parameter, or perhaps Scale.suppress()?
  367. # Are there other useful parameters that would be in Scale.legend()
  368. # besides allowing Scale.legend(False)?
  369. if prop.legend:
  370. axis.set_view_interval(vmin, vmax)
  371. locs = axis.major.locator()
  372. locs = locs[(vmin <= locs) & (locs <= vmax)]
  373. # Avoid having an offset / scientific notation in a legend
  374. # as we don't represent that anywhere so it ends up incorrect.
  375. # This could become an option (e.g. Continuous.label(offset=True))
  376. # in which case we would need to figure out how to show it.
  377. if hasattr(axis.major.formatter, "set_useOffset"):
  378. axis.major.formatter.set_useOffset(False)
  379. if hasattr(axis.major.formatter, "set_scientific"):
  380. axis.major.formatter.set_scientific(False)
  381. labels = axis.major.formatter.format_ticks(locs)
  382. new._legend = list(locs), list(labels)
  383. return new
  384. def _get_transform(self):
  385. arg = self.trans
  386. def get_param(method, default):
  387. if arg == method:
  388. return default
  389. return float(arg[len(method):])
  390. if arg is None:
  391. return _make_identity_transforms()
  392. elif isinstance(arg, tuple):
  393. return arg
  394. elif isinstance(arg, str):
  395. if arg == "ln":
  396. return _make_log_transforms()
  397. elif arg == "logit":
  398. base = get_param("logit", 10)
  399. return _make_logit_transforms(base)
  400. elif arg.startswith("log"):
  401. base = get_param("log", 10)
  402. return _make_log_transforms(base)
  403. elif arg.startswith("symlog"):
  404. c = get_param("symlog", 1)
  405. return _make_symlog_transforms(c)
  406. elif arg.startswith("pow"):
  407. exp = get_param("pow", 2)
  408. return _make_power_transforms(exp)
  409. elif arg == "sqrt":
  410. return _make_sqrt_transforms()
  411. else:
  412. raise ValueError(f"Unknown value provided for trans: {arg!r}")
  413. @dataclass
  414. class Continuous(ContinuousBase):
  415. """
  416. A numeric scale supporting norms and functional transforms.
  417. """
  418. values: tuple | str | None = None
  419. trans: str | TransFuncs | None = None
  420. # TODO Add this to deal with outliers?
  421. # outside: Literal["keep", "drop", "clip"] = "keep"
  422. _priority: ClassVar[int] = 1
  423. def tick(
  424. self,
  425. locator: Locator | None = None, *,
  426. at: Sequence[float] | None = None,
  427. upto: int | None = None,
  428. count: int | None = None,
  429. every: float | None = None,
  430. between: tuple[float, float] | None = None,
  431. minor: int | None = None,
  432. ) -> Continuous:
  433. """
  434. Configure the selection of ticks for the scale's axis or legend.
  435. Parameters
  436. ----------
  437. locator : :class:`matplotlib.ticker.Locator` subclass
  438. Pre-configured matplotlib locator; other parameters will not be used.
  439. at : sequence of floats
  440. Place ticks at these specific locations (in data units).
  441. upto : int
  442. Choose "nice" locations for ticks, but do not exceed this number.
  443. count : int
  444. Choose exactly this number of ticks, bounded by `between` or axis limits.
  445. every : float
  446. Choose locations at this interval of separation (in data units).
  447. between : pair of floats
  448. Bound upper / lower ticks when using `every` or `count`.
  449. minor : int
  450. Number of unlabeled ticks to draw between labeled "major" ticks.
  451. Returns
  452. -------
  453. scale
  454. Copy of self with new tick configuration.
  455. """
  456. # Input checks
  457. if locator is not None and not isinstance(locator, Locator):
  458. raise TypeError(
  459. f"Tick locator must be an instance of {Locator!r}, "
  460. f"not {type(locator)!r}."
  461. )
  462. log_base, symlog_thresh = self._parse_for_log_params(self.trans)
  463. if log_base or symlog_thresh:
  464. if count is not None and between is None:
  465. raise RuntimeError("`count` requires `between` with log transform.")
  466. if every is not None:
  467. raise RuntimeError("`every` not supported with log transform.")
  468. new = copy(self)
  469. new._tick_params = {
  470. "locator": locator,
  471. "at": at,
  472. "upto": upto,
  473. "count": count,
  474. "every": every,
  475. "between": between,
  476. "minor": minor,
  477. }
  478. return new
  479. def label(
  480. self,
  481. formatter: Formatter | None = None, *,
  482. like: str | Callable | None = None,
  483. base: int | None | Default = default,
  484. unit: str | None = None,
  485. ) -> Continuous:
  486. """
  487. Configure the appearance of tick labels for the scale's axis or legend.
  488. Parameters
  489. ----------
  490. formatter : :class:`matplotlib.ticker.Formatter` subclass
  491. Pre-configured formatter to use; other parameters will be ignored.
  492. like : str or callable
  493. Either a format pattern (e.g., `".2f"`), a format string with fields named
  494. `x` and/or `pos` (e.g., `"${x:.2f}"`), or a callable with a signature like
  495. `f(x: float, pos: int) -> str`. In the latter variants, `x` is passed as the
  496. tick value and `pos` is passed as the tick index.
  497. base : number
  498. Use log formatter (with scientific notation) having this value as the base.
  499. Set to `None` to override the default formatter with a log transform.
  500. unit : str or (str, str) tuple
  501. Use SI prefixes with these units (e.g., with `unit="g"`, a tick value
  502. of 5000 will appear as `5 kg`). When a tuple, the first element gives the
  503. separator between the number and unit.
  504. Returns
  505. -------
  506. scale
  507. Copy of self with new label configuration.
  508. """
  509. # Input checks
  510. if formatter is not None and not isinstance(formatter, Formatter):
  511. raise TypeError(
  512. f"Label formatter must be an instance of {Formatter!r}, "
  513. f"not {type(formatter)!r}"
  514. )
  515. if like is not None and not (isinstance(like, str) or callable(like)):
  516. msg = f"`like` must be a string or callable, not {type(like).__name__}."
  517. raise TypeError(msg)
  518. new = copy(self)
  519. new._label_params = {
  520. "formatter": formatter,
  521. "like": like,
  522. "base": base,
  523. "unit": unit,
  524. }
  525. return new
  526. def _parse_for_log_params(
  527. self, trans: str | TransFuncs | None
  528. ) -> tuple[float | None, float | None]:
  529. log_base = symlog_thresh = None
  530. if isinstance(trans, str):
  531. m = re.match(r"^log(\d*)", trans)
  532. if m is not None:
  533. log_base = float(m[1] or 10)
  534. m = re.match(r"symlog(\d*)", trans)
  535. if m is not None:
  536. symlog_thresh = float(m[1] or 1)
  537. return log_base, symlog_thresh
  538. def _get_locators(self, locator, at, upto, count, every, between, minor):
  539. log_base, symlog_thresh = self._parse_for_log_params(self.trans)
  540. if locator is not None:
  541. major_locator = locator
  542. elif upto is not None:
  543. if log_base:
  544. major_locator = LogLocator(base=log_base, numticks=upto)
  545. else:
  546. major_locator = MaxNLocator(upto, steps=[1, 1.5, 2, 2.5, 3, 5, 10])
  547. elif count is not None:
  548. if between is None:
  549. # This is rarely useful (unless you are setting limits)
  550. major_locator = LinearLocator(count)
  551. else:
  552. if log_base or symlog_thresh:
  553. forward, inverse = self._get_transform()
  554. lo, hi = forward(between)
  555. ticks = inverse(np.linspace(lo, hi, num=count))
  556. else:
  557. ticks = np.linspace(*between, num=count)
  558. major_locator = FixedLocator(ticks)
  559. elif every is not None:
  560. if between is None:
  561. major_locator = MultipleLocator(every)
  562. else:
  563. lo, hi = between
  564. ticks = np.arange(lo, hi + every, every)
  565. major_locator = FixedLocator(ticks)
  566. elif at is not None:
  567. major_locator = FixedLocator(at)
  568. else:
  569. if log_base:
  570. major_locator = LogLocator(log_base)
  571. elif symlog_thresh:
  572. major_locator = SymmetricalLogLocator(linthresh=symlog_thresh, base=10)
  573. else:
  574. major_locator = AutoLocator()
  575. if minor is None:
  576. minor_locator = LogLocator(log_base, subs=None) if log_base else None
  577. else:
  578. if log_base:
  579. subs = np.linspace(0, log_base, minor + 2)[1:-1]
  580. minor_locator = LogLocator(log_base, subs=subs)
  581. else:
  582. minor_locator = AutoMinorLocator(minor + 1)
  583. return major_locator, minor_locator
  584. def _get_formatter(self, locator, formatter, like, base, unit):
  585. log_base, symlog_thresh = self._parse_for_log_params(self.trans)
  586. if base is default:
  587. if symlog_thresh:
  588. log_base = 10
  589. base = log_base
  590. if formatter is not None:
  591. return formatter
  592. if like is not None:
  593. if isinstance(like, str):
  594. if "{x" in like or "{pos" in like:
  595. fmt = like
  596. else:
  597. fmt = f"{{x:{like}}}"
  598. formatter = StrMethodFormatter(fmt)
  599. else:
  600. formatter = FuncFormatter(like)
  601. elif base is not None:
  602. # We could add other log options if necessary
  603. formatter = LogFormatterSciNotation(base)
  604. elif unit is not None:
  605. if isinstance(unit, tuple):
  606. sep, unit = unit
  607. elif not unit:
  608. sep = ""
  609. else:
  610. sep = " "
  611. formatter = EngFormatter(unit, sep=sep)
  612. else:
  613. formatter = ScalarFormatter()
  614. return formatter
  615. @dataclass
  616. class Temporal(ContinuousBase):
  617. """
  618. A scale for date/time data.
  619. """
  620. # TODO date: bool?
  621. # For when we only care about the time component, would affect
  622. # default formatter and norm conversion. Should also happen in
  623. # Property.default_scale. The alternative was having distinct
  624. # Calendric / Temporal scales, but that feels a bit fussy, and it
  625. # would get in the way of using first-letter shorthands because
  626. # Calendric and Continuous would collide. Still, we haven't implemented
  627. # those yet, and having a clear distinction betewen date(time) / time
  628. # may be more useful.
  629. trans = None
  630. _priority: ClassVar[int] = 2
  631. def tick(
  632. self, locator: Locator | None = None, *,
  633. upto: int | None = None,
  634. ) -> Temporal:
  635. """
  636. Configure the selection of ticks for the scale's axis or legend.
  637. .. note::
  638. This API is under construction and will be enhanced over time.
  639. Parameters
  640. ----------
  641. locator : :class:`matplotlib.ticker.Locator` subclass
  642. Pre-configured matplotlib locator; other parameters will not be used.
  643. upto : int
  644. Choose "nice" locations for ticks, but do not exceed this number.
  645. Returns
  646. -------
  647. scale
  648. Copy of self with new tick configuration.
  649. """
  650. if locator is not None and not isinstance(locator, Locator):
  651. err = (
  652. f"Tick locator must be an instance of {Locator!r}, "
  653. f"not {type(locator)!r}."
  654. )
  655. raise TypeError(err)
  656. new = copy(self)
  657. new._tick_params = {"locator": locator, "upto": upto}
  658. return new
  659. def label(
  660. self,
  661. formatter: Formatter | None = None, *,
  662. concise: bool = False,
  663. ) -> Temporal:
  664. """
  665. Configure the appearance of tick labels for the scale's axis or legend.
  666. .. note::
  667. This API is under construction and will be enhanced over time.
  668. Parameters
  669. ----------
  670. formatter : :class:`matplotlib.ticker.Formatter` subclass
  671. Pre-configured formatter to use; other parameters will be ignored.
  672. concise : bool
  673. If True, use :class:`matplotlib.dates.ConciseDateFormatter` to make
  674. the tick labels as compact as possible.
  675. Returns
  676. -------
  677. scale
  678. Copy of self with new label configuration.
  679. """
  680. new = copy(self)
  681. new._label_params = {"formatter": formatter, "concise": concise}
  682. return new
  683. def _get_locators(self, locator, upto):
  684. if locator is not None:
  685. major_locator = locator
  686. elif upto is not None:
  687. major_locator = AutoDateLocator(minticks=2, maxticks=upto)
  688. else:
  689. major_locator = AutoDateLocator(minticks=2, maxticks=6)
  690. minor_locator = None
  691. return major_locator, minor_locator
  692. def _get_formatter(self, locator, formatter, concise):
  693. if formatter is not None:
  694. return formatter
  695. if concise:
  696. # TODO ideally we would have concise coordinate ticks,
  697. # but full semantic ticks. Is that possible?
  698. formatter = ConciseDateFormatter(locator)
  699. else:
  700. formatter = AutoDateFormatter(locator)
  701. return formatter
  702. # ----------------------------------------------------------------------------------- #
  703. # TODO Have this separate from Temporal or have Temporal(date=True) or similar?
  704. # class Calendric(Scale):
  705. # TODO Needed? Or handle this at layer (in stat or as param, eg binning=)
  706. # class Binned(Scale):
  707. # TODO any need for color-specific scales?
  708. # class Sequential(Continuous):
  709. # class Diverging(Continuous):
  710. # class Qualitative(Nominal):
  711. # ----------------------------------------------------------------------------------- #
  712. class PseudoAxis:
  713. """
  714. Internal class implementing minimal interface equivalent to matplotlib Axis.
  715. Coordinate variables are typically scaled by attaching the Axis object from
  716. the figure where the plot will end up. Matplotlib has no similar concept of
  717. and axis for the other mappable variables (color, etc.), but to simplify the
  718. code, this object acts like an Axis and can be used to scale other variables.
  719. """
  720. axis_name = "" # Matplotlib requirement but not actually used
  721. def __init__(self, scale):
  722. self.converter = None
  723. self.units = None
  724. self.scale = scale
  725. self.major = mpl.axis.Ticker()
  726. self.minor = mpl.axis.Ticker()
  727. # It appears that this needs to be initialized this way on matplotlib 3.1,
  728. # but not later versions. It is unclear whether there are any issues with it.
  729. self._data_interval = None, None
  730. scale.set_default_locators_and_formatters(self)
  731. # self.set_default_intervals() Is this ever needed?
  732. def set_view_interval(self, vmin, vmax):
  733. self._view_interval = vmin, vmax
  734. def get_view_interval(self):
  735. return self._view_interval
  736. # TODO do we want to distinguish view/data intervals? e.g. for a legend
  737. # we probably want to represent the full range of the data values, but
  738. # still norm the colormap. If so, we'll need to track data range separately
  739. # from the norm, which we currently don't do.
  740. def set_data_interval(self, vmin, vmax):
  741. self._data_interval = vmin, vmax
  742. def get_data_interval(self):
  743. return self._data_interval
  744. def get_tick_space(self):
  745. # TODO how to do this in a configurable / auto way?
  746. # Would be cool to have legend density adapt to figure size, etc.
  747. return 5
  748. def set_major_locator(self, locator):
  749. self.major.locator = locator
  750. locator.set_axis(self)
  751. def set_major_formatter(self, formatter):
  752. self.major.formatter = formatter
  753. formatter.set_axis(self)
  754. def set_minor_locator(self, locator):
  755. self.minor.locator = locator
  756. locator.set_axis(self)
  757. def set_minor_formatter(self, formatter):
  758. self.minor.formatter = formatter
  759. formatter.set_axis(self)
  760. def set_units(self, units):
  761. self.units = units
  762. def update_units(self, x):
  763. """Pass units to the internal converter, potentially updating its mapping."""
  764. self.converter = mpl.units.registry.get_converter(x)
  765. if self.converter is not None:
  766. self.converter.default_units(x, self)
  767. info = self.converter.axisinfo(self.units, self)
  768. if info is None:
  769. return
  770. if info.majloc is not None:
  771. self.set_major_locator(info.majloc)
  772. if info.majfmt is not None:
  773. self.set_major_formatter(info.majfmt)
  774. # This is in matplotlib method; do we need this?
  775. # self.set_default_intervals()
  776. def convert_units(self, x):
  777. """Return a numeric representation of the input data."""
  778. if np.issubdtype(np.asarray(x).dtype, np.number):
  779. return x
  780. elif self.converter is None:
  781. return x
  782. return self.converter.convert(x, self.units, self)
  783. def get_scale(self):
  784. # Note that matplotlib actually returns a string here!
  785. # (e.g., with a log scale, axis.get_scale() returns "log")
  786. # Currently we just hit it with minor ticks where it checks for
  787. # scale == "log". I'm not sure how you'd actually use log-scale
  788. # minor "ticks" in a legend context, so this is fine....
  789. return self.scale
  790. def get_majorticklocs(self):
  791. return self.major.locator()
  792. # ------------------------------------------------------------------------------------ #
  793. # Transform function creation
  794. def _make_identity_transforms() -> TransFuncs:
  795. def identity(x):
  796. return x
  797. return identity, identity
  798. def _make_logit_transforms(base: float | None = None) -> TransFuncs:
  799. log, exp = _make_log_transforms(base)
  800. def logit(x):
  801. with np.errstate(invalid="ignore", divide="ignore"):
  802. return log(x) - log(1 - x)
  803. def expit(x):
  804. with np.errstate(invalid="ignore", divide="ignore"):
  805. return exp(x) / (1 + exp(x))
  806. return logit, expit
  807. def _make_log_transforms(base: float | None = None) -> TransFuncs:
  808. fs: TransFuncs
  809. if base is None:
  810. fs = np.log, np.exp
  811. elif base == 2:
  812. fs = np.log2, partial(np.power, 2)
  813. elif base == 10:
  814. fs = np.log10, partial(np.power, 10)
  815. else:
  816. def forward(x):
  817. return np.log(x) / np.log(base)
  818. fs = forward, partial(np.power, base)
  819. def log(x: ArrayLike) -> ArrayLike:
  820. with np.errstate(invalid="ignore", divide="ignore"):
  821. return fs[0](x)
  822. def exp(x: ArrayLike) -> ArrayLike:
  823. with np.errstate(invalid="ignore", divide="ignore"):
  824. return fs[1](x)
  825. return log, exp
  826. def _make_symlog_transforms(c: float = 1, base: float = 10) -> TransFuncs:
  827. # From https://iopscience.iop.org/article/10.1088/0957-0233/24/2/027001
  828. # Note: currently not using base because we only get
  829. # one parameter from the string, and are using c (this is consistent with d3)
  830. log, exp = _make_log_transforms(base)
  831. def symlog(x):
  832. with np.errstate(invalid="ignore", divide="ignore"):
  833. return np.sign(x) * log(1 + np.abs(np.divide(x, c)))
  834. def symexp(x):
  835. with np.errstate(invalid="ignore", divide="ignore"):
  836. return np.sign(x) * c * (exp(np.abs(x)) - 1)
  837. return symlog, symexp
  838. def _make_sqrt_transforms() -> TransFuncs:
  839. def sqrt(x):
  840. return np.sign(x) * np.sqrt(np.abs(x))
  841. def square(x):
  842. return np.sign(x) * np.square(x)
  843. return sqrt, square
  844. def _make_power_transforms(exp: float) -> TransFuncs:
  845. def forward(x):
  846. return np.sign(x) * np.power(np.abs(x), exp)
  847. def inverse(x):
  848. return np.sign(x) * np.power(np.abs(x), 1 / exp)
  849. return forward, inverse
  850. def _default_spacer(x: Series) -> float:
  851. return 1