from __future__ import annotations
import warnings
import itertools
from copy import copy
from collections import UserString
from collections.abc import Iterable, Sequence, Mapping
from numbers import Number
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib as mpl

from seaborn._core.data import PlotData
from seaborn.palettes import (
    QUAL_PALETTES,
    color_palette,
)
from seaborn.utils import (
    _check_argument,
    _version_predates,
    desaturate,
    locator_to_legend_entries,
    get_color_cycle,
    remove_na,
)


class SemanticMapping:
    """Base class for mapping data values to plot attributes."""

    # -- Default attributes that all SemanticMapping subclasses must set

    # Whether the mapping is numeric, categorical, or datetime
    map_type: str | None = None

    # Ordered list of unique values in the input data
    levels = None

    # A mapping from the data values to corresponding plot attributes
    lookup_table = None

    def __init__(self, plotter):

        # TODO Putting this here so we can continue to use a lot of the
        # logic that's built into the library, but the idea of this class
        # is to move towards semantic mappings that are agnostic about the
        # kind of plot they're going to be used to draw.
        # Fully achieving that is going to take some thinking.
        self.plotter = plotter

    def _check_list_length(self, levels, values, variable):
        """Input check when values are provided as a list."""
        # Copied from _core/properties; eventually will be replaced by that.
        message = ""
        if len(levels) > len(values):
            message = " ".join([
                f"\nThe {variable} list has fewer values ({len(values)})",
                f"than needed ({len(levels)}) and will cycle, which may",
                "produce an uninterpretable plot."
            ])
            values = [x for _, x in zip(levels, itertools.cycle(values))]

        elif len(values) > len(levels):
            message = " ".join([
                f"The {variable} list has more values ({len(values)})",
                f"than needed ({len(levels)}), which may not be intended.",
            ])
            values = values[:len(levels)]

        if message:
            warnings.warn(message, UserWarning, stacklevel=6)

        return values

    def _lookup_single(self, key):
        """Apply the mapping to a single data value."""
        return self.lookup_table[key]

    def __call__(self, key, *args, **kwargs):
        """Get the attribute value(s) for the data key."""
        if isinstance(key, (list, np.ndarray, pd.Series)):
            return [self._lookup_single(k, *args, **kwargs) for k in key]
        else:
            return self._lookup_single(key, *args, **kwargs)
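
# Illustrative sketch, not library code: how SemanticMapping.__call__
# dispatches on scalar vs. vector keys, assuming a hypothetical mapping
# whose lookup_table has already been populated:
#
#     mapping.lookup_table = {"a": "red", "b": "blue"}
#     mapping("a")           # -> "red"
#     mapping(["a", "b"])    # -> ["red", "blue"]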


class HueMapping(SemanticMapping):
    """Mapping that sets artist colors according to data values."""

    # A specification of the colors that should appear in the plot
    palette = None

    # An object that normalizes data values to [0, 1] range for color mapping
    norm = None

    # A continuous colormap object for interpolating in a numeric context
    cmap = None

    def __init__(
        self, plotter, palette=None, order=None, norm=None, saturation=1,
    ):
        """Map the levels of the `hue` variable to distinct colors.

        Parameters
        ----------
        # TODO add generic parameters

        """
        super().__init__(plotter)

        data = plotter.plot_data.get("hue", pd.Series(dtype=float))

        if isinstance(palette, np.ndarray):
            msg = (
                "Numpy array is not a supported type for `palette`. "
                "Please convert your palette to a list. "
                "This will become an error in v0.14"
            )
            warnings.warn(msg, stacklevel=4)
            palette = palette.tolist()

        if data.isna().all():
            if palette is not None:
                msg = "Ignoring `palette` because no `hue` variable has been assigned."
                warnings.warn(msg, stacklevel=4)
        else:

            map_type = self.infer_map_type(
                palette, norm, plotter.input_format, plotter.var_types["hue"]
            )

            # Our goal is to end up with a dictionary mapping every unique
            # value in `data` to a color. We will also keep track of the
            # metadata about this mapping we will need for, e.g., a legend

            # --- Option 1: numeric mapping with a matplotlib colormap

            if map_type == "numeric":

                data = pd.to_numeric(data)
                levels, lookup_table, norm, cmap = self.numeric_mapping(
                    data, palette, norm,
                )

            # --- Option 2: categorical mapping using seaborn palette

            elif map_type == "categorical":

                cmap = norm = None
                levels, lookup_table = self.categorical_mapping(
                    data, palette, order,
                )

            # --- Option 3: datetime mapping

            else:
                # TODO this needs actual implementation
                cmap = norm = None
                levels, lookup_table = self.categorical_mapping(
                    # Casting data to list to handle differences in the way
                    # pandas and numpy represent datetime64 data
                    list(data), palette, order,
                )

            self.saturation = saturation
            self.map_type = map_type
            self.lookup_table = lookup_table
            self.palette = palette
            self.levels = levels
            self.norm = norm
            self.cmap = cmap

    def _lookup_single(self, key):
        """Get the color for a single value, using colormap to interpolate."""
        try:
            # Use a value that's in the original data vector
            value = self.lookup_table[key]
        except KeyError:

            if self.norm is None:
                # Currently we only get here in scatterplot with hue_order,
                # because scatterplot does not consider hue a grouping variable
                # So unused hue levels are in the data, but not the lookup table
                return (0, 0, 0, 0)

            # Use the colormap to interpolate between existing datapoints
            # (e.g. in the context of making a continuous legend)
            try:
                normed = self.norm(key)
            except TypeError as err:
                if np.isnan(key):
                    value = (0, 0, 0, 0)
                else:
                    raise err
            else:
                if np.ma.is_masked(normed):
                    normed = np.nan
                value = self.cmap(normed)

        if self.saturation < 1:
            value = desaturate(value, self.saturation)

        return value

    def infer_map_type(self, palette, norm, input_format, var_type):
        """Determine how to implement the mapping."""
        if palette in QUAL_PALETTES:
            map_type = "categorical"
        elif norm is not None:
            map_type = "numeric"
        elif isinstance(palette, (dict, list)):
            map_type = "categorical"
        elif input_format == "wide":
            map_type = "categorical"
        else:
            map_type = var_type

        return map_type

    def categorical_mapping(self, data, palette, order):
        """Determine colors when the hue mapping is categorical."""
        # -- Identify the order and name of the levels

        levels = categorical_order(data, order)
        n_colors = len(levels)

        # -- Identify the set of colors to use

        if isinstance(palette, dict):

            missing = set(levels) - set(palette)
            if any(missing):
                err = "The palette dictionary is missing keys: {}"
                raise ValueError(err.format(missing))

            lookup_table = palette

        else:

            if palette is None:
                if n_colors <= len(get_color_cycle()):
                    colors = color_palette(None, n_colors)
                else:
                    colors = color_palette("husl", n_colors)
            elif isinstance(palette, list):
                colors = self._check_list_length(levels, palette, "palette")
            else:
                colors = color_palette(palette, n_colors)

            lookup_table = dict(zip(levels, colors))

        return levels, lookup_table

    def numeric_mapping(self, data, palette, norm):
        """Determine colors when the hue variable is quantitative."""
        if isinstance(palette, dict):

            # The presence of a norm object overrides a dictionary of hues
            # in specifying a numeric mapping, so we need to process it here.
            levels = list(sorted(palette))
            colors = [palette[k] for k in sorted(palette)]
            cmap = mpl.colors.ListedColormap(colors)
            lookup_table = palette.copy()

        else:

            # The levels are the sorted unique values in the data
            levels = list(np.sort(remove_na(data.unique())))

            # --- Sort out the colormap to use from the palette argument

            # Default numeric palette is our default cubehelix palette
            # TODO do we want to do something complicated to ensure contrast?
            palette = "ch:" if palette is None else palette

            if isinstance(palette, mpl.colors.Colormap):
                cmap = palette
            else:
                cmap = color_palette(palette, as_cmap=True)

            # Now sort out the data normalization
            if norm is None:
                norm = mpl.colors.Normalize()
            elif isinstance(norm, tuple):
                norm = mpl.colors.Normalize(*norm)
            elif not isinstance(norm, mpl.colors.Normalize):
                err = "``hue_norm`` must be None, tuple, or Normalize object."
                raise ValueError(err)

            if not norm.scaled():
                norm(np.asarray(data.dropna()))

            lookup_table = dict(zip(levels, cmap(norm(levels))))

        return levels, lookup_table, norm, cmap
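
# Illustrative sketch, not library code: the norm/cmap pipeline used by
# numeric_mapping above, with hypothetical values:
#
#     norm = mpl.colors.Normalize(vmin=0, vmax=10)
#     norm(5)                                    # -> 0.5
#     cmap = color_palette("ch:", as_cmap=True)
#     cmap(norm(5))                              # -> RGBA tuple at the midpoint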


class SizeMapping(SemanticMapping):
    """Mapping that sets artist sizes according to data values."""

    # An object that normalizes data values to [0, 1] range
    norm = None

    def __init__(
        self, plotter, sizes=None, order=None, norm=None,
    ):
        """Map the levels of the `size` variable to distinct values.

        Parameters
        ----------
        # TODO add generic parameters

        """
        super().__init__(plotter)

        data = plotter.plot_data.get("size", pd.Series(dtype=float))

        if data.notna().any():

            map_type = self.infer_map_type(
                norm, sizes, plotter.var_types["size"]
            )

            # --- Option 1: numeric mapping

            if map_type == "numeric":

                levels, lookup_table, norm, size_range = self.numeric_mapping(
                    data, sizes, norm,
                )

            # --- Option 2: categorical mapping

            elif map_type == "categorical":

                levels, lookup_table = self.categorical_mapping(
                    data, sizes, order,
                )
                size_range = None

            # --- Option 3: datetime mapping

            # TODO this needs an actual implementation
            else:

                levels, lookup_table = self.categorical_mapping(
                    # Casting data to list to handle differences in the way
                    # pandas and numpy represent datetime64 data
                    list(data), sizes, order,
                )
                size_range = None

            self.map_type = map_type
            self.levels = levels
            self.norm = norm
            self.sizes = sizes
            self.size_range = size_range
            self.lookup_table = lookup_table

    def infer_map_type(self, norm, sizes, var_type):

        if norm is not None:
            map_type = "numeric"
        elif isinstance(sizes, (dict, list)):
            map_type = "categorical"
        else:
            map_type = var_type

        return map_type

    def _lookup_single(self, key):

        try:
            value = self.lookup_table[key]
        except KeyError:
            normed = self.norm(key)
            if np.ma.is_masked(normed):
                normed = np.nan
            value = self.size_range[0] + normed * np.ptp(self.size_range)
        return value

    def categorical_mapping(self, data, sizes, order):

        levels = categorical_order(data, order)

        if isinstance(sizes, dict):

            # Dict inputs map existing data values to the size attribute
            missing = set(levels) - set(sizes)
            if any(missing):
                err = f"Missing sizes for the following levels: {missing}"
                raise ValueError(err)
            lookup_table = sizes.copy()

        elif isinstance(sizes, list):

            # List inputs give size values in the same order as the levels
            sizes = self._check_list_length(levels, sizes, "sizes")
            lookup_table = dict(zip(levels, sizes))

        else:

            if isinstance(sizes, tuple):

                # Tuple input sets the min, max size values
                if len(sizes) != 2:
                    err = "A `sizes` tuple must have only 2 values"
                    raise ValueError(err)

            elif sizes is not None:

                err = f"Value for `sizes` not understood: {sizes}"
                raise ValueError(err)

            else:

                # Otherwise, we need to get the min, max size values from
                # the plotter object we are attached to.

                # TODO this is going to cause us trouble later, because we
                # want to restructure things so that the plotter is generic
                # across the visual representation of the data. But at this
                # point, we don't know the visual representation. Likely we
                # want to change the logic of this Mapping so that it gives
                # points on a normalized range that then gets un-normalized
                # when we know what we're drawing. But given the way the
                # package works now, this way is cleanest.
                sizes = self.plotter._default_size_range

            # For categorical sizes, use regularly-spaced linear steps
            # between the minimum and maximum sizes. Then reverse the
            # ramp so that the largest value is used for the first entry
            # in size_order, etc. This is because "ordered" categories
            # are often thought to go in decreasing priority.
            sizes = np.linspace(*sizes, len(levels))[::-1]
            lookup_table = dict(zip(levels, sizes))

        return levels, lookup_table
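
    # Illustrative sketch, not library code: the reversed linear ramp built
    # above for three levels and a (1, 2) size range:
    #
    #     np.linspace(1, 2, 3)[::-1]   # -> array([2. , 1.5, 1. ])
    #     dict(zip(["a", "b", "c"], np.linspace(1, 2, 3)[::-1]))
    #     # -> {"a": 2.0, "b": 1.5, "c": 1.0}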

    def numeric_mapping(self, data, sizes, norm):

        if isinstance(sizes, dict):
            # The presence of a norm object overrides a dictionary of sizes
            # in specifying a numeric mapping, so we need to process the
            # dictionary here
            levels = list(np.sort(list(sizes)))
            size_values = sizes.values()
            size_range = min(size_values), max(size_values)

        else:

            # The levels here will be the unique values in the data
            levels = list(np.sort(remove_na(data.unique())))

            if isinstance(sizes, tuple):

                # For numeric inputs, the size can be parametrized by
                # the minimum and maximum artist values to map to. The
                # norm object that gets set up next specifies how to
                # do the mapping.

                if len(sizes) != 2:
                    err = "A `sizes` tuple must have only 2 values"
                    raise ValueError(err)

                size_range = sizes

            elif sizes is not None:

                err = f"Value for `sizes` not understood: {sizes}"
                raise ValueError(err)

            else:

                # When not provided, we get the size range from the plotter
                # object we are attached to. See the note in the categorical
                # method about how this is suboptimal for future development.
                size_range = self.plotter._default_size_range

        # Now that we know the minimum and maximum sizes that will get drawn,
        # we need to map the data values that we have into that range. We will
        # use a matplotlib Normalize class, which is typically used for numeric
        # color mapping but works fine here too. It takes data values and maps
        # them into a [0, 1] interval, potentially nonlinear-ly.

        if norm is None:
            # Default is a linear function between the min and max data values
            norm = mpl.colors.Normalize()
        elif isinstance(norm, tuple):
            # It is also possible to give different limits in data space
            norm = mpl.colors.Normalize(*norm)
        elif not isinstance(norm, mpl.colors.Normalize):
            err = f"Value for size `norm` parameter not understood: {norm}"
            raise ValueError(err)
        else:
            # If provided with Normalize object, copy it so we can modify
            norm = copy(norm)

        # Set the mapping so all output values are in [0, 1]
        norm.clip = True

        # If the input range is not set, use the full range of the data
        if not norm.scaled():
            norm(levels)

        # Map from data values to [0, 1] range
        sizes_scaled = norm(levels)

        # Now map from the scaled range into the artist units
        if isinstance(sizes, dict):
            lookup_table = sizes
        else:
            lo, hi = size_range
            sizes = lo + sizes_scaled * (hi - lo)
            lookup_table = dict(zip(levels, sizes))

        return levels, lookup_table, norm, size_range
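
# Illustrative sketch, not library code: the un-normalization step at the end
# of numeric_mapping above, with hypothetical values:
#
#     norm = mpl.colors.Normalize(vmin=0, vmax=100, clip=True)
#     lo, hi = 1, 2                  # size_range
#     lo + norm(50) * (hi - lo)      # -> 1.5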


class StyleMapping(SemanticMapping):
    """Mapping that sets artist style according to data values."""

    # Style mapping is always treated as categorical
    map_type = "categorical"

    def __init__(self, plotter, markers=None, dashes=None, order=None):
        """Map the levels of the `style` variable to distinct values.

        Parameters
        ----------
        # TODO add generic parameters

        """
        super().__init__(plotter)

        data = plotter.plot_data.get("style", pd.Series(dtype=float))

        if data.notna().any():

            # Cast to list to handle numpy/pandas datetime quirks
            if variable_type(data) == "datetime":
                data = list(data)

            # Find ordered unique values
            levels = categorical_order(data, order)

            markers = self._map_attributes(
                markers, levels, unique_markers(len(levels)), "markers",
            )
            dashes = self._map_attributes(
                dashes, levels, unique_dashes(len(levels)), "dashes",
            )

            # Build the paths matplotlib will use to draw the markers
            paths = {}
            filled_markers = []
            for k, m in markers.items():
                if not isinstance(m, mpl.markers.MarkerStyle):
                    m = mpl.markers.MarkerStyle(m)
                paths[k] = m.get_path().transformed(m.get_transform())
                filled_markers.append(m.is_filled())

            # Mixture of filled and unfilled markers will show line art markers
            # in the edge color, which defaults to white. This can be handled,
            # but there would be additional complexity with specifying the
            # weight of the line art markers without overwhelming the filled
            # ones with the edges. So for now, we will disallow mixtures.
            if any(filled_markers) and not all(filled_markers):
                err = "Filled and line art markers cannot be mixed"
                raise ValueError(err)

            lookup_table = {}
            for key in levels:
                lookup_table[key] = {}
                if markers:
                    lookup_table[key]["marker"] = markers[key]
                    lookup_table[key]["path"] = paths[key]
                if dashes:
                    lookup_table[key]["dashes"] = dashes[key]

            self.levels = levels
            self.lookup_table = lookup_table

    def _lookup_single(self, key, attr=None):
        """Get attribute(s) for a given data point."""
        if attr is None:
            value = self.lookup_table[key]
        else:
            value = self.lookup_table[key][attr]
        return value

    def _map_attributes(self, arg, levels, defaults, attr):
        """Handle the specification for a given style attribute."""
        if arg is True:
            lookup_table = dict(zip(levels, defaults))
        elif isinstance(arg, dict):
            missing = set(levels) - set(arg)
            if missing:
                err = f"These `{attr}` levels are missing values: {missing}"
                raise ValueError(err)
            lookup_table = arg
        elif isinstance(arg, Sequence):
            arg = self._check_list_length(levels, arg, attr)
            lookup_table = dict(zip(levels, arg))
        elif arg:
            err = f"This `{attr}` argument was not understood: {arg}"
            raise ValueError(err)
        else:
            lookup_table = {}

        return lookup_table
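
# Illustrative sketch, not library code: how marker paths are derived in
# StyleMapping.__init__ above, using the real matplotlib MarkerStyle API:
#
#     m = mpl.markers.MarkerStyle("o")
#     m.is_filled()                                 # -> True
#     m.get_path().transformed(m.get_transform())   # -> Path in marker coords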


# =========================================================================== #


class VectorPlotter:
    """Base class for objects underlying *plot functions."""

    wide_structure = {
        "x": "@index", "y": "@values", "hue": "@columns", "style": "@columns",
    }
    flat_structure = {"x": "@index", "y": "@values"}

    _default_size_range = 1, 2  # Unused but needed in tests, ugh

    def __init__(self, data=None, variables={}):

        self._var_levels = {}
        # var_ordered is relevant only for categorical axis variables, and may
        # be better handled by an internal axis information object that tracks
        # such information and is set up by the scale_* methods. The analogous
        # information for numeric axes would be information about log scales.
        self._var_ordered = {"x": False, "y": False}  # alt., used DefaultDict
        self.assign_variables(data, variables)

        # TODO Lots of tests assume that these are called to initialize the
        # mappings to default values on class initialization. I'd prefer to
        # move away from that and only have a mapping when explicitly called.
        for var in ["hue", "size", "style"]:
            if var in variables:
                getattr(self, f"map_{var}")()

    @property
    def has_xy_data(self):
        """Return True if at least one of x or y is defined."""
        return bool({"x", "y"} & set(self.variables))

    @property
    def var_levels(self):
        """Property interface to ordered list of variable levels.

        Each time it's accessed, it updates the var_levels dictionary with the
        list of levels in the current semantic mappers. But it also allows the
        dictionary to persist, so it can be used to set levels by a key. This is
        used to track the list of col/row levels using an attached FacetGrid
        object, but it's kind of messy and ideally fixed by improving the
        faceting logic so it interfaces better with the modern approach to
        tracking plot variables.

        """
        for var in self.variables:
            if (map_obj := getattr(self, f"_{var}_map", None)) is not None:
                self._var_levels[var] = map_obj.levels
        return self._var_levels

    def assign_variables(self, data=None, variables={}):
        """Define plot variables, optionally using lookup from `data`."""
        x = variables.get("x", None)
        y = variables.get("y", None)

        if x is None and y is None:
            self.input_format = "wide"
            frame, names = self._assign_variables_wideform(data, **variables)
        else:
            # When dealing with long-form input, use the newer PlotData
            # object (internal but introduced for the objects interface)
            # to centralize / standardize data consumption logic.
            self.input_format = "long"
            plot_data = PlotData(data, variables)
            frame = plot_data.frame
            names = plot_data.names

        self.plot_data = frame
        self.variables = names
        self.var_types = {
            v: variable_type(
                frame[v],
                boolean_type="numeric" if v in "xy" else "categorical"
            )
            for v in names
        }

        return self

    def _assign_variables_wideform(self, data=None, **kwargs):
        """Define plot variables given wide-form data.

        Parameters
        ----------
        data : flat vector or collection of vectors
            Data can be a vector or mapping that is coercible to a Series
            or a sequence- or mapping-based collection of such vectors, or a
            rectangular numpy array, or a Pandas DataFrame.
        kwargs : variable -> data mappings
            Behavior with keyword arguments is currently undefined.

        Returns
        -------
        plot_data : :class:`pandas.DataFrame`
            Long-form data object mapping seaborn variables (x, y, hue, ...)
            to data vectors.
        variables : dict
            Keys are defined seaborn variables; values are names inferred from
            the inputs (or None when no name can be determined).

        """
        # Raise if semantic or other variables are assigned in wide-form mode
        assigned = [k for k, v in kwargs.items() if v is not None]
        if any(assigned):
            s = "s" if len(assigned) > 1 else ""
            err = f"The following variable{s} cannot be assigned with wide-form data: "
            err += ", ".join(f"`{v}`" for v in assigned)
            raise ValueError(err)

        # Determine if the data object actually has any data in it
        empty = data is None or not len(data)

        # Then, determine if we have "flat" data (a single vector)
        if isinstance(data, dict):
            values = data.values()
        else:
            values = np.atleast_1d(np.asarray(data, dtype=object))
        flat = not any(
            isinstance(v, Iterable) and not isinstance(v, (str, bytes))
            for v in values
        )

        if empty:

            # Make an object with the structure of plot_data, but empty
            plot_data = pd.DataFrame()
            variables = {}

        elif flat:

            # Handle flat data by converting to pandas Series and using the
            # index and/or values to define x and/or y
            # (Could be accomplished with a more general to_series() interface)
            flat_data = pd.Series(data).copy()
            names = {
                "@values": flat_data.name,
                "@index": flat_data.index.name
            }

            plot_data = {}
            variables = {}

            for var in ["x", "y"]:
                if var in self.flat_structure:
                    attr = self.flat_structure[var]
                    plot_data[var] = getattr(flat_data, attr[1:])
                    variables[var] = names[self.flat_structure[var]]

            plot_data = pd.DataFrame(plot_data)

        else:

            # Otherwise assume we have some collection of vectors.

            # Handle Python sequences such that entries end up in the columns,
            # not in the rows, of the intermediate wide DataFrame.
            # One way to accomplish this is to convert to a dict of Series.
            if isinstance(data, Sequence):
                data_dict = {}
                for i, var in enumerate(data):
                    key = getattr(var, "name", i)
                    # TODO is there a safer/more generic way to ensure Series?
                    # sort of like np.asarray, but for pandas?
                    data_dict[key] = pd.Series(var)

                data = data_dict

            # Pandas requires that dict values either be Series objects
            # or all have the same length, but we want to allow "ragged" inputs
            if isinstance(data, Mapping):
                data = {key: pd.Series(val) for key, val in data.items()}

            # Otherwise, delegate to the pandas DataFrame constructor
            # This is where we'd prefer to use a general interface that says
            # "give me this data as a pandas DataFrame", so we can accept
            # DataFrame objects from other libraries
            wide_data = pd.DataFrame(data, copy=True)

            # At this point we should reduce the dataframe to numeric cols
            numeric_cols = [
                k for k, v in wide_data.items() if variable_type(v) == "numeric"
            ]
            wide_data = wide_data[numeric_cols]

            # Now melt the data to long form
            melt_kws = {"var_name": "@columns", "value_name": "@values"}
            use_index = "@index" in self.wide_structure.values()
            if use_index:
                melt_kws["id_vars"] = "@index"
                try:
                    orig_categories = wide_data.columns.categories
                    orig_ordered = wide_data.columns.ordered
                    wide_data.columns = wide_data.columns.add_categories("@index")
                except AttributeError:
                    category_columns = False
                else:
                    category_columns = True
                wide_data["@index"] = wide_data.index.to_series()

            plot_data = wide_data.melt(**melt_kws)

            if use_index and category_columns:
                plot_data["@columns"] = pd.Categorical(plot_data["@columns"],
                                                       orig_categories,
                                                       orig_ordered)

            # Assign names corresponding to plot semantics
            for var, attr in self.wide_structure.items():
                plot_data[var] = plot_data[attr]

            # Define the variable names
            variables = {}
            for var, attr in self.wide_structure.items():
                obj = getattr(wide_data, attr[1:])
                variables[var] = getattr(obj, "name", None)

            # Remove redundant columns from plot_data
            plot_data = plot_data[list(variables)]

        return plot_data, variables
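
    # Illustrative sketch, not library code: the melt step above applied to a
    # small wide-form DataFrame:
    #
    #     df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    #     df["@index"] = df.index.to_series()
    #     df.melt(id_vars="@index", var_name="@columns", value_name="@values")
    #     # -> four rows; per wide_structure, @index maps to x, @values to y,
    #     #    and @columns to hue and style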

    def map_hue(self, palette=None, order=None, norm=None, saturation=1):
        mapping = HueMapping(self, palette, order, norm, saturation)
        self._hue_map = mapping

    def map_size(self, sizes=None, order=None, norm=None):
        mapping = SizeMapping(self, sizes, order, norm)
        self._size_map = mapping

    def map_style(self, markers=None, dashes=None, order=None):
        mapping = StyleMapping(self, markers, dashes, order)
        self._style_map = mapping

    def iter_data(
        self, grouping_vars=None, *,
        reverse=False, from_comp_data=False,
        by_facet=True, allow_empty=False, dropna=True,
    ):
        """Generator for getting subsets of data defined by semantic variables.

        Also injects "col" and "row" into grouping semantics.

        Parameters
        ----------
        grouping_vars : string or list of strings
            Semantic variables that define the subsets of data.
        reverse : bool
            If True, reverse the order of iteration.
        from_comp_data : bool
            If True, use self.comp_data rather than self.plot_data
        by_facet : bool
            If True, add faceting variables to the set of grouping variables.
        allow_empty : bool
            If True, yield an empty dataframe when no observations exist for
            combinations of grouping variables.
        dropna : bool
            If True, remove rows with missing data.

        Yields
        ------
        sub_vars : dict
            Keys are semantic names, values are the level of that semantic.
        sub_data : :class:`pandas.DataFrame`
            Subset of ``plot_data`` for this combination of semantic values.

        """
        # TODO should this default to using all (non x/y?) semantics?
        # or define grouping vars somewhere?
        if grouping_vars is None:
            grouping_vars = []
        elif isinstance(grouping_vars, str):
            grouping_vars = [grouping_vars]
        elif isinstance(grouping_vars, tuple):
            grouping_vars = list(grouping_vars)

        # Always insert faceting variables
        if by_facet:
            facet_vars = {"col", "row"}
            grouping_vars.extend(
                facet_vars & set(self.variables) - set(grouping_vars)
            )

        # Reduce to the semantics used in this plot
        grouping_vars = [var for var in grouping_vars if var in self.variables]

        if from_comp_data:
            data = self.comp_data
        else:
            data = self.plot_data

        if dropna:
            data = data.dropna()

        levels = self.var_levels.copy()
        if from_comp_data:
            for axis in {"x", "y"} & set(grouping_vars):
                converter = self.converters[axis].iloc[0]
                if self.var_types[axis] == "categorical":
                    if self._var_ordered[axis]:
                        # If the axis is ordered, then the axes in a possible
                        # facet grid are by definition "shared", or there is a
                        # single axis with a unique cat -> idx mapping.
                        # So we can just take the first converter object.
                        levels[axis] = converter.convert_units(levels[axis])
                    else:
                        # Otherwise, the mappings may not be unique, but we can
                        # use the unique set of index values in comp_data.
                        levels[axis] = np.sort(data[axis].unique())
                else:
                    transform = converter.get_transform().transform
                    levels[axis] = transform(converter.convert_units(levels[axis]))

        if grouping_vars:

            grouped_data = data.groupby(
                grouping_vars, sort=False, as_index=False, observed=False,
            )

            grouping_keys = []
            for var in grouping_vars:
                grouping_keys.append(levels.get(var, []))

            iter_keys = itertools.product(*grouping_keys)
            if reverse:
                iter_keys = reversed(list(iter_keys))

            for key in iter_keys:

                # Pandas fails with singleton tuple inputs
                pd_key = key[0] if len(key) == 1 else key

                try:
                    data_subset = grouped_data.get_group(pd_key)
                except KeyError:
                    # XXX we are adding this to allow backwards compatibility
                    # with the empty artists that old categorical plots would
                    # add (before 0.12), which we may decide to break, in which
                    # case this option could be removed
                    data_subset = data.loc[[]]

                if data_subset.empty and not allow_empty:
                    continue

                sub_vars = dict(zip(grouping_vars, key))

                yield sub_vars, data_subset.copy()

        else:

            yield {}, data.copy()
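
    # Illustrative sketch, not library code: typical use of iter_data by a
    # plotting function, assuming a hypothetical plotter ``p`` with a hue
    # semantic assigned:
    #
    #     for sub_vars, sub_data in p.iter_data(["hue"]):
    #         ...  # sub_vars is like {"hue": "a"}; sub_data is that subset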

    @property
    def comp_data(self):
        """Dataframe with numeric x and y, after unit conversion and log scaling."""
        if not hasattr(self, "ax"):
            # Probably a good idea, but will need a bunch of tests updated
            # Most of these tests should just use the external interface
            # Then this can be re-enabled.
            # raise AttributeError("No Axes attached to plotter")
            return self.plot_data

        if not hasattr(self, "_comp_data"):

            comp_data = (
                self.plot_data
                .copy(deep=False)
                .drop(["x", "y"], axis=1, errors="ignore")
            )

            for var in "yx":
                if var not in self.variables:
                    continue

                parts = []
                grouped = self.plot_data[var].groupby(self.converters[var], sort=False)
                for converter, orig in grouped:
                    orig = orig.mask(orig.isin([np.inf, -np.inf]), np.nan)
                    orig = orig.dropna()
                    if var in self.var_levels:
                        # TODO this should happen in some centralized location
                        # it is similar to GH2419, but more complicated because
                        # supporting `order` in categorical plots is tricky
                        orig = orig[orig.isin(self.var_levels[var])]
                    comp = pd.to_numeric(converter.convert_units(orig)).astype(float)
                    transform = converter.get_transform().transform
                    parts.append(pd.Series(transform(comp), orig.index, name=orig.name))
                if parts:
                    comp_col = pd.concat(parts)
                else:
                    comp_col = pd.Series(dtype=float, name=var)
                comp_data.insert(0, var, comp_col)

            self._comp_data = comp_data

        return self._comp_data
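
    # Illustrative sketch, not library code: the transform step in comp_data
    # on a log-scaled axis applies the scale's (base-10) log transform:
    #
    #     ax.set_xscale("log")
    #     ax.xaxis.get_transform().transform([100.])   # -> array([2.])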

    def _get_axes(self, sub_vars):
        """Return an Axes object based on existence of row/col variables."""
        row = sub_vars.get("row", None)
        col = sub_vars.get("col", None)
        if row is not None and col is not None:
            return self.facets.axes_dict[(row, col)]
        elif row is not None:
            return self.facets.axes_dict[row]
        elif col is not None:
            return self.facets.axes_dict[col]
        elif self.ax is None:
            return self.facets.ax
        else:
            return self.ax

    def _attach(
        self,
        obj,
        allowed_types=None,
        log_scale=None,
    ):
        """Associate the plotter with an Axes manager and initialize its units.

        Parameters
        ----------
        obj : :class:`matplotlib.axes.Axes` or :class:`FacetGrid`
            Structural object that we will eventually plot onto.
        allowed_types : str or list of str
            If provided, raise when either the x or y variable does not have
            one of the declared seaborn types.
        log_scale : bool, number, or pair of bools or numbers
            If not False, set the axes to use log scaling, with the given
            base or defaulting to 10. If a tuple, interpreted as separate
            arguments for the x and y axes.

        """
        from .axisgrid import FacetGrid
        if isinstance(obj, FacetGrid):
            self.ax = None
            self.facets = obj
            ax_list = obj.axes.flatten()
            if obj.col_names is not None:
                self.var_levels["col"] = obj.col_names
            if obj.row_names is not None:
                self.var_levels["row"] = obj.row_names
        else:
            self.ax = obj
            self.facets = None
            ax_list = [obj]

        # Identify which "axis" variables we have defined
        axis_variables = set("xy").intersection(self.variables)

        # -- Verify the types of our x and y variables here.
        # This doesn't really make complete sense being here, but it's a fine
        # place for it, given the current system.
        # (Note that for some plots, there might be more complicated restrictions)
        # e.g. the categorical plots have their own check that is specific to the
        # non-categorical axis.
        if allowed_types is None:
            allowed_types = ["numeric", "datetime", "categorical"]
        elif isinstance(allowed_types, str):
            allowed_types = [allowed_types]

        for var in axis_variables:
            var_type = self.var_types[var]
            if var_type not in allowed_types:
                err = (
                    f"The {var} variable is {var_type}, but one of "
                    f"{allowed_types} is required"
                )
                raise TypeError(err)

        # -- Get axis objects for each row in plot_data for type conversions and scaling

        facet_dim = {"x": "col", "y": "row"}

        self.converters = {}
        for var in axis_variables:
            other_var = {"x": "y", "y": "x"}[var]

            converter = pd.Series(index=self.plot_data.index, name=var, dtype=object)
            share_state = getattr(self.facets, f"_share{var}", True)

            # Simplest cases are that we have a single axes, all axes are shared,
            # or sharing is only on the orthogonal facet dimension. In these cases,
            # all datapoints get converted the same way, so use the first axis
            if share_state is True or share_state == facet_dim[other_var]:
                converter.loc[:] = getattr(ax_list[0], f"{var}axis")

            else:

                # Next simplest case is when no axes are shared, and we can
                # use the axis objects within each facet
                if share_state is False:
                    for axes_vars, axes_data in self.iter_data():
                        ax = self._get_axes(axes_vars)
                        converter.loc[axes_data.index] = getattr(ax, f"{var}axis")

                # In the more complicated case, the axes are shared within each
                # "file" of the facetgrid. In that case, we need to subset the data
                # for that file and assign it the first axis in the slice of the grid
                else:

                    names = getattr(self.facets, f"{share_state}_names")
                    for i, level in enumerate(names):
                        idx = (i, 0) if share_state == "row" else (0, i)
                        axis = getattr(self.facets.axes[idx], f"{var}axis")
                        converter.loc[self.plot_data[share_state] == level] = axis

            # Store the converter vector, which we use elsewhere (e.g. comp_data)
            self.converters[var] = converter

            # Now actually update the matplotlib objects to do the conversion we want
            grouped = self.plot_data[var].groupby(self.converters[var], sort=False)
            for converter, seed_data in grouped:
                if self.var_types[var] == "categorical":
                    if self._var_ordered[var]:
                        order = self.var_levels[var]
                    else:
                        order = None
                    seed_data = categorical_order(seed_data, order)
                converter.update_units(seed_data)

        # -- Set numerical axis scales

        # First unpack the log_scale argument
        if log_scale is None:
            scalex = scaley = False
        else:
            # Allow single value or x, y tuple
            try:
                scalex, scaley = log_scale
            except TypeError:
                scalex = log_scale if self.var_types.get("x") == "numeric" else False
                scaley = log_scale if self.var_types.get("y") == "numeric" else False

        # Now use it
        for axis, scale in zip("xy", (scalex, scaley)):
            if scale:
                for ax in ax_list:
                    set_scale = getattr(ax, f"set_{axis}scale")
                    if scale is True:
                        set_scale("log", nonpositive="mask")
                    else:
                        set_scale("log", base=scale, nonpositive="mask")

        # For categorical y, we want the "first" level to be at the top of the axis
        if self.var_types.get("y", None) == "categorical":
            for ax in ax_list:
                try:
                    ax.yaxis.set_inverted(True)
                except AttributeError:  # mpl < 3.1
                    if not ax.yaxis_inverted():
                        ax.invert_yaxis()

        # TODO -- Add axes labels
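
    # Illustrative sketch, not library code: how the log_scale unpacking above
    # behaves for each accepted input shape:
    #
    #     log_scale=(True, 2)  # tuple unpacks: scalex=True, scaley=2
    #     log_scale=2          # unpacking raises TypeError, so both numeric
    #                          # axes fall back to base-2 log scaling
    #     log_scale=True       # same fallback; numeric axes get log base 10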

    def _get_scale_transforms(self, axis):
        """Return functions implementing the scale transform and its inverse."""
        if self.ax is None:
            axis_list = [getattr(ax, f"{axis}axis") for ax in self.facets.axes.flat]
            scales = {axis.get_scale() for axis in axis_list}
            if len(scales) > 1:
                # It is a simplifying assumption that faceted axes will always have
                # the same scale (even if they are unshared and have distinct limits).
                # Nothing in the seaborn API allows you to create a FacetGrid with
                # a mixture of scales, although it's possible via matplotlib.
                # This is constraining, but no more so than previous behavior that
                # only (properly) handled log scales, and there are some places where
                # it would be much too complicated to use axes-specific transforms.
                err = "Cannot determine transform with mixed scales on faceted axes."
                raise RuntimeError(err)
            transform_obj = axis_list[0].get_transform()
        else:
            # This case is more straightforward
            transform_obj = getattr(self.ax, f"{axis}axis").get_transform()

        return transform_obj.transform, transform_obj.inverted().transform
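
    # Illustrative sketch, not library code: the forward/inverse pair returned
    # above, assuming a log-scaled x axis:
    #
    #     fwd, inv = self._get_scale_transforms("x")
    #     fwd([1, 10, 100])   # -> array([0., 1., 2.])
    #     inv(fwd([50]))      # -> array([50.]) (up to floating point)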

    def _add_axis_labels(self, ax, default_x="", default_y=""):
        """Add axis labels if not present, set visibility to match ticklabels."""
        # TODO ax could default to None and use attached axes if present
        # but what to do about the case of facets? Currently using FacetGrid's
        # set_axis_labels method, which doesn't add labels to the interior even
        # when the axes are not shared. Maybe that makes sense?
        if not ax.get_xlabel():
            x_visible = any(t.get_visible() for t in ax.get_xticklabels())
            ax.set_xlabel(self.variables.get("x", default_x), visible=x_visible)
        if not ax.get_ylabel():
            y_visible = any(t.get_visible() for t in ax.get_yticklabels())
            ax.set_ylabel(self.variables.get("y", default_y), visible=y_visible)

    def add_legend_data(
        self, ax, func, common_kws=None, attrs=None, semantic_kws=None,
    ):
        """Add labeled artists to represent the different plot semantics."""
        verbosity = self.legend
        if isinstance(verbosity, str) and verbosity not in ["auto", "brief", "full"]:
            err = "`legend` must be 'auto', 'brief', 'full', or a boolean."
            raise ValueError(err)
        elif verbosity is True:
            verbosity = "auto"

        keys = []
        legend_kws = {}
        common_kws = {} if common_kws is None else common_kws.copy()
        semantic_kws = {} if semantic_kws is None else semantic_kws.copy()

        # Assign a legend title if there is only going to be one sub-legend,
        # otherwise, subtitles will be inserted into the texts list with an
        # invisible handle (which is a hack)
        titles = {
            title for title in
            (self.variables.get(v, None) for v in ["hue", "size", "style"])
            if title is not None
        }
        title = "" if len(titles) != 1 else titles.pop()
        title_kws = dict(
            visible=False, color="w", s=0, linewidth=0, marker="", dashes=""
        )

        def update(var_name, val_name, **kws):

            key = var_name, val_name
            if key in legend_kws:
                legend_kws[key].update(**kws)
            else:
                keys.append(key)
                legend_kws[key] = dict(**kws)

        if attrs is None:
            attrs = {"hue": "color", "size": ["linewidth", "s"], "style": None}
        for var, names in attrs.items():
            self._update_legend_data(
                update, var, verbosity, title, title_kws, names, semantic_kws.get(var),
            )

        legend_data = {}
        legend_order = []

        # Don't allow color=None so we can set a neutral color for size/style legends
        if common_kws.get("color", False) is None:
            common_kws.pop("color")

        for key in keys:

            _, label = key
            kws = legend_kws[key]
            level_kws = {}
            use_attrs = [
                *self._legend_attributes,
                *common_kws,
                *[attr for var_attrs in semantic_kws.values() for attr in var_attrs],
            ]
            for attr in use_attrs:
                if attr in kws:
                    level_kws[attr] = kws[attr]
            artist = func(label=label, **{"color": ".2", **common_kws, **level_kws})
            if _version_predates(mpl, "3.5.0"):
                if isinstance(artist, mpl.lines.Line2D):
                    ax.add_line(artist)
                elif isinstance(artist, mpl.patches.Patch):
                    ax.add_patch(artist)
                elif isinstance(artist, mpl.collections.Collection):
                    ax.add_collection(artist)
            else:
                ax.add_artist(artist)
            legend_data[key] = artist
            legend_order.append(key)

        self.legend_title = title
        self.legend_data = legend_data
        self.legend_order = legend_order

    def _update_legend_data(
        self,
        update,
        var,
        verbosity,
        title,
        title_kws,
        attr_names,
        other_props,
    ):
        """Generate legend tick values and formatted labels."""
        brief_ticks = 6
        mapper = getattr(self, f"_{var}_map", None)
        if mapper is None:
            return

        brief = mapper.map_type == "numeric" and (
            verbosity == "brief"
            or (verbosity == "auto" and len(mapper.levels) > brief_ticks)
        )
        if brief:
            if isinstance(mapper.norm, mpl.colors.LogNorm):
                locator = mpl.ticker.LogLocator(numticks=brief_ticks)
            else:
                locator = mpl.ticker.MaxNLocator(nbins=brief_ticks)
            limits = min(mapper.levels), max(mapper.levels)
            levels, formatted_levels = locator_to_legend_entries(
                locator, limits, self.plot_data[var].infer_objects().dtype
            )
        elif mapper.levels is None:
            levels = formatted_levels = []
        else:
            levels = formatted_levels = mapper.levels

        if not title and self.variables.get(var, None) is not None:
            update((self.variables[var], "title"), self.variables[var], **title_kws)

        other_props = {} if other_props is None else other_props

        for level, formatted_level in zip(levels, formatted_levels):
            if level is not None:
                attr = mapper(level)
                if isinstance(attr_names, list):
                    attr = {name: attr for name in attr_names}
                elif attr_names is not None:
                    attr = {attr_names: attr}
                attr.update({k: v[level] for k, v in other_props.items() if level in v})
                update(self.variables[var], formatted_level, **attr)
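
    # Illustrative sketch, not library code: the "brief" tick selection above,
    # using the real matplotlib locator API with hypothetical limits:
    #
    #     locator = mpl.ticker.MaxNLocator(nbins=6)
    #     locator.tick_values(0, 100)   # -> roughly [0, 20, 40, 60, 80, 100]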

    # XXX If the scale_* methods are going to modify the plot_data structure, they
    # can't be called twice. That means that if they are called twice, they should
    # raise. Alternatively, we could store an original version of plot_data and each
    # time they are called they operate on the store, not the current state.

    def scale_native(self, axis, *args, **kwargs):

        # Default, defer to matplotlib
        raise NotImplementedError

    def scale_numeric(self, axis, *args, **kwargs):

        # Feels needed for completeness, what should it do?
        # Perhaps handle log scaling? Set the ticker/formatter/limits?
        raise NotImplementedError

    def scale_datetime(self, axis, *args, **kwargs):

        # Use pd.to_datetime to convert strings or numbers to datetime objects
        # Note, use day-resolution for numeric->datetime to match matplotlib
        raise NotImplementedError

    def scale_categorical(self, axis, order=None, formatter=None):
        """
        Enforce categorical (fixed-scale) rules for the data on given axis.

        Parameters
        ----------
        axis : "x" or "y"
            Axis of the plot to operate on.
        order : list
            Order that unique values should appear in.
        formatter : callable
            Function mapping values to a string representation.

        Returns
        -------
        self

        """
        # This method both modifies the internal representation of the data
        # (converting it to string) and sets some attributes on self. It might be
        # a good idea to have a separate object attached to self that contains the
        # information in those attributes (i.e. whether to enforce variable order
        # across facets, the order to use) similar to the SemanticMapping objects
        # we have for semantic variables. That object could also hold the converter
        # objects that get used, if we can decouple those from an existing axis
        # (cf. https://github.com/matplotlib/matplotlib/issues/19229).
        # There are some interactions with faceting information that would need
        # to be thought through, since the converters to use depend on facets.
        # If we go that route, these methods could become "borrowed" methods similar
        # to what happens with the alternate semantic mapper constructors, although
        # that approach is kind of fussy and confusing.

        # TODO this method could also set the grid state? Since we like to have no
        # grid on the categorical axis by default. Again, a case where we'll need to
        # store information until we use it, so best to have a way to collect the
        # attributes that this method sets.

        # TODO if we are going to set visual properties of the axes with these methods,
        # then we could do the steps currently in CategoricalPlotter._adjust_cat_axis

        # TODO another, and distinct idea, is to expose a cut= param here
  1126. _check_argument("axis", ["x", "y"], axis)
  1127. # Categorical plots can be "univariate" in which case they get an anonymous
  1128. # category label on the opposite axis.
  1129. if axis not in self.variables:
  1130. self.variables[axis] = None
  1131. self.var_types[axis] = "categorical"
  1132. self.plot_data[axis] = ""
  1133. # If the "categorical" variable has a numeric type, sort the rows so that
  1134. # the default result from categorical_order has those values sorted after
  1135. # they have been coerced to strings. The reason for this is so that later
  1136. # we can get facet-wise orders that are correct.
  1137. # XXX Should this also sort datetimes?
  1138. # It feels more consistent, but technically will be a default change
  1139. # If so, should also change categorical_order to behave that way
  1140. if self.var_types[axis] == "numeric":
  1141. self.plot_data = self.plot_data.sort_values(axis, kind="mergesort")
  1142. # Now get a reference to the categorical data vector and remove na values
  1143. cat_data = self.plot_data[axis].dropna()
  1144. # Get the initial categorical order, which we do before string
  1145. # conversion to respect the original types of the order list.
  1146. # Track whether the order is given explicitly so that we can know
  1147. # whether or not to use the order constructed here downstream
  1148. self._var_ordered[axis] = order is not None or cat_data.dtype.name == "category"
  1149. order = pd.Index(categorical_order(cat_data, order), name=axis)
        # Then convert data to strings. This is because in matplotlib,
        # "categorical" data really mean "string" data, so doing this ensures
        # that artists will be drawn on the categorical axis with a fixed scale.
        # TODO implement formatter here; check that it returns strings?
        if formatter is not None:
            cat_data = cat_data.map(formatter)
            order = order.map(formatter)
        else:
            cat_data = cat_data.astype(str)
            order = order.astype(str)
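        # For example (illustrative note, not in the original source): with
        # formatter=lambda x: f"{x:g}", the numeric level 2.0 maps to "2",
        # whereas the astype(str) fallback above would produce "2.0".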
        # Update the levels list with the type-converted order variable
        self.var_levels[axis] = order

        # Now ensure that seaborn will use categorical rules internally
        self.var_types[axis] = "categorical"

        # Put the string-typed categorical vector back into the plot_data structure
        self.plot_data[axis] = cat_data

        return self
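
    # Usage sketch (illustrative, not part of the original file): assuming the
    # enclosing class is seaborn's VectorPlotter and `tips` is a hypothetical
    # DataFrame with "day" and "tip" columns, a caller might write:
    #
    #     p = VectorPlotter(data=tips, variables=dict(x="day", y="tip"))
    #     p.scale_categorical("x", order=["Thur", "Fri", "Sat", "Sun"])
    #
    # After the call, p.plot_data["x"] holds strings, p.var_types["x"] is
    # "categorical", and p.var_levels["x"] is the string-converted order.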


class VariableType(UserString):
    """
    Prevent comparisons elsewhere in the library from using the wrong name.

    Errors are simple assertions because users should not be able to trigger
    them. If that changes, they should be more verbose.

    """
    # TODO we can replace this with typing.Literal on Python 3.8+
    allowed = "numeric", "datetime", "categorical"

    def __init__(self, data):
        assert data in self.allowed, data
        super().__init__(data)

    def __eq__(self, other):
        assert other in self.allowed, other
        return self.data == other
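

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): the assertion guard makes a typo'd type name fail loudly
# instead of silently comparing unequal.
def _variable_type_guard_example():
    t = VariableType("numeric")
    assert t == "numeric"
    try:
        t == "number"  # not an allowed name -> AssertionError, not False
    except AssertionError:
        pass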


def variable_type(vector, boolean_type="numeric"):
    """
    Determine whether a vector contains numeric, categorical, or datetime data.

    This function differs from the pandas typing API in two ways:

    - Python sequences or object-typed PyData objects are considered numeric if
      all of their entries are numeric.
    - String or mixed-type data are considered categorical even if not
      explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.

    Parameters
    ----------
    vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
        Input data to test.
    boolean_type : 'numeric' or 'categorical'
        Type to use for vectors containing only 0s and 1s (and NAs).

    Returns
    -------
    var_type : 'numeric', 'categorical', or 'datetime'
        Name identifying the type of data in the vector.
    """
    vector = pd.Series(vector)

    # If a categorical dtype is set, infer categorical
    if isinstance(vector.dtype, pd.CategoricalDtype):
        return VariableType("categorical")

    # Special-case all-na data, which is always "numeric"
    if pd.isna(vector).all():
        return VariableType("numeric")

    # At this point, drop nans to simplify further type inference
    vector = vector.dropna()

    # Special-case binary/boolean data, allow caller to determine
    # This triggers a numpy warning when vector has strings/objects
    # https://github.com/numpy/numpy/issues/6784
    # Because we reduce with .all(), we are agnostic about whether the
    # comparison returns a scalar or vector, so we will ignore the warning.
    # It triggers a separate DeprecationWarning when the vector has datetimes:
    # https://github.com/numpy/numpy/issues/13548
    # This is considered a bug by numpy and will likely go away.
    with warnings.catch_warnings():
        warnings.simplefilter(
            action='ignore', category=(FutureWarning, DeprecationWarning)
        )
        if np.isin(vector, [0, 1]).all():
            return VariableType(boolean_type)

    # Defer to positive pandas tests
    if pd.api.types.is_numeric_dtype(vector):
        return VariableType("numeric")

    if pd.api.types.is_datetime64_dtype(vector):
        return VariableType("datetime")

    # --- If we get to here, we need to check the entries

    # Check for a collection where everything is a number
    def all_numeric(x):
        for x_i in x:
            if not isinstance(x_i, Number):
                return False
        return True

    if all_numeric(vector):
        return VariableType("numeric")

    # Check for a collection where everything is a datetime
    def all_datetime(x):
        for x_i in x:
            if not isinstance(x_i, (datetime, np.datetime64)):
                return False
        return True

    if all_datetime(vector):
        return VariableType("datetime")

    # Otherwise, our final fallback is to consider things categorical
    return VariableType("categorical")
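

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): exercising the inference rules above. Note that binary
# vectors defer to the caller's boolean_type choice.
def _variable_type_example():
    import pandas as pd

    assert variable_type(pd.Series([1., 2., 3.])) == "numeric"
    assert variable_type(pd.Series(["a", "b", "c"])) == "categorical"
    assert variable_type(pd.Series(pd.to_datetime(["2021-01-01"]))) == "datetime"

    # Vectors of 0s and 1s are ambiguous, so the caller decides
    assert variable_type(pd.Series([0, 1, 0, 1])) == "numeric"
    assert variable_type(pd.Series([0, 1]), boolean_type="categorical") == "categorical"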


def infer_orient(x=None, y=None, orient=None, require_numeric=True):
    """Determine how the plot should be oriented based on the data.

    For historical reasons, the convention is to call a plot "horizontally"
    or "vertically" oriented based on the axis representing its dependent
    variable. Practically, this is used when determining the axis for
    numerical aggregation.

    Parameters
    ----------
    x, y : Vector data or None
        Positional data vectors for the plot.
    orient : string or None
        Specified orientation. If not None, can be "x" or "y", or otherwise
        must start with "v" or "h".
    require_numeric : bool
        If set, raise when the implied dependent variable is not numeric.

    Returns
    -------
    orient : "x" or "y"

    Raises
    ------
    ValueError: When `orient` is an unknown string.
    TypeError: When dependent variable is not numeric, with `require_numeric`

    """
    x_type = None if x is None else variable_type(x)
    y_type = None if y is None else variable_type(y)

    nonnumeric_dv_error = "{} orientation requires numeric `{}` variable."
    single_var_warning = "{} orientation ignored with only `{}` specified."

    if x is None:
        if str(orient).startswith("h"):
            warnings.warn(single_var_warning.format("Horizontal", "y"))
        if require_numeric and y_type != "numeric":
            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
        return "x"

    elif y is None:
        if str(orient).startswith("v"):
            warnings.warn(single_var_warning.format("Vertical", "x"))
        if require_numeric and x_type != "numeric":
            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
        return "y"

    elif str(orient).startswith("v") or orient == "x":
        if require_numeric and y_type != "numeric":
            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
        return "x"

    elif str(orient).startswith("h") or orient == "y":
        if require_numeric and x_type != "numeric":
            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
        return "y"

    elif orient is not None:
        err = (
            "`orient` must start with 'v' or 'h' or be None, "
            f"but `{repr(orient)}` was passed."
        )
        raise ValueError(err)

    elif x_type != "categorical" and y_type == "categorical":
        return "y"

    elif x_type != "numeric" and y_type == "numeric":
        return "x"

    elif x_type == "numeric" and y_type != "numeric":
        return "y"

    elif require_numeric and "numeric" not in (x_type, y_type):
        err = "Neither the `x` nor `y` variable appears to be numeric."
        raise TypeError(err)

    else:
        return "x"
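

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): with no explicit `orient`, the categorical axis determines
# the orientation; an explicit value always wins.
def _infer_orient_example():
    cats = ["a", "b", "a", "b"]
    nums = [1, 2, 3, 4]

    assert infer_orient(x=cats, y=nums) == "x"  # numeric dependent variable on y
    assert infer_orient(x=nums, y=cats) == "y"  # numeric dependent variable on x
    assert infer_orient(x=nums, y=nums, orient="h") == "y"  # explicit orient wins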


def unique_dashes(n):
    """Build an arbitrarily long list of unique dash styles for lines.

    Parameters
    ----------
    n : int
        Number of unique dash specs to generate.

    Returns
    -------
    dashes : list of strings or tuples
        Valid arguments for the ``dashes`` parameter on
        :class:`matplotlib.lines.Line2D`. The first spec is a solid
        line (``""``), the remainder are sequences of long and short
        dashes.

    """
    # Start with dash specs that are well distinguishable
    dashes = [
        "",
        (4, 1.5),
        (1, 1),
        (3, 1.25, 1.5, 1.25),
        (5, 1, 1, 1),
    ]

    # Now programmatically build as many as we need
    p = 3
    while len(dashes) < n:

        # Take combinations of long and short dashes
        a = itertools.combinations_with_replacement([3, 1.25], p)
        b = itertools.combinations_with_replacement([4, 1], p)

        # Interleave the combinations, reversing one of the streams
        segment_list = itertools.chain(*zip(
            list(a)[1:-1][::-1],
            list(b)[1:-1]
        ))

        # Now insert the gaps
        for segments in segment_list:
            gap = min(segments)
            spec = tuple(itertools.chain(*((seg, gap) for seg in segments)))
            dashes.append(spec)

        p += 1

    return dashes[:n]
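

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): the first spec is always a solid line and every spec is
# distinct, so styles stay distinguishable as n grows.
def _unique_dashes_example():
    dashes = unique_dashes(8)
    assert dashes[0] == ""          # solid line comes first
    assert dashes[1] == (4, 1.5)    # then the hand-picked specs
    assert len(set(dashes)) == 8    # all eight specs are unique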


def unique_markers(n):
    """Build an arbitrarily long list of unique marker styles for points.

    Parameters
    ----------
    n : int
        Number of unique marker specs to generate.

    Returns
    -------
    markers : list of string or tuples
        Values for defining :class:`matplotlib.markers.MarkerStyle` objects.
        All markers will be filled.

    """
    # Start with marker specs that are well distinguishable
    markers = [
        "o",
        "X",
        (4, 0, 45),
        "P",
        (4, 0, 0),
        (4, 1, 0),
        "^",
        (4, 1, 45),
        "v",
    ]

    # Now generate more from regular polygons of increasing order
    s = 5
    while len(markers) < n:
        a = 360 / (s + 1) / 2
        markers.extend([
            (s + 1, 1, a),
            (s + 1, 0, a),
            (s, 1, 0),
            (s, 0, 0),
        ])
        s += 1

    # Convert to MarkerStyle object, using only exactly what we need
    # markers = [mpl.markers.MarkerStyle(m) for m in markers[:n]]

    return markers[:n]
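

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): tuple specs use matplotlib's (numsides, style, angle) form,
# so the generated markers are rotated regular polygons and stars.
def _unique_markers_example():
    markers = unique_markers(12)
    assert markers[0] == "o"           # circle comes first
    assert markers[9] == (6, 1, 30.0)  # six-pointed star, rotated 30 degrees
    assert len(set(markers)) == 12     # all specs are unique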


def categorical_order(vector, order=None):
    """Return a list of unique data values.

    Determine an ordered list of levels in ``vector``.

    Parameters
    ----------
    vector : list, array, Categorical, or Series
        Vector of "categorical" values
    order : list-like, optional
        Desired order of category levels to override the order determined
        from the ``vector`` object.

    Returns
    -------
    order : list
        Ordered list of category levels not including null values.

    """
    if order is None:
        if hasattr(vector, "categories"):
            order = vector.categories
        else:
            try:
                order = vector.cat.categories
            except (TypeError, AttributeError):
                order = pd.Series(vector).unique()

                if variable_type(vector) == "numeric":
                    order = np.sort(order)

        order = filter(pd.notnull, order)

    return list(order)
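

# Illustrative sketch (not part of the original module; the helper name is
# hypothetical): levels default to order of appearance, numeric levels are
# sorted, and an explicit order is passed through unchanged.
def _categorical_order_example():
    import pandas as pd

    s = pd.Series(["b", "a", "c", "a", None])
    assert categorical_order(s) == ["b", "a", "c"]  # appearance order, nulls dropped
    assert categorical_order(s, ["c", "b", "a"]) == ["c", "b", "a"]  # explicit order wins
    assert categorical_order(pd.Series([3, 1, 2])) == [1, 2, 3]  # numeric levels get sorted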