_compat.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import numpy as np
  2. import matplotlib as mpl
  3. from seaborn.utils import _version_predates
  4. def MarkerStyle(marker=None, fillstyle=None):
  5. """
  6. Allow MarkerStyle to accept a MarkerStyle object as parameter.
  7. Supports matplotlib < 3.3.0
  8. https://github.com/matplotlib/matplotlib/pull/16692
  9. """
  10. if isinstance(marker, mpl.markers.MarkerStyle):
  11. if fillstyle is None:
  12. return marker
  13. else:
  14. marker = marker.get_marker()
  15. return mpl.markers.MarkerStyle(marker, fillstyle)
  16. def norm_from_scale(scale, norm):
  17. """Produce a Normalize object given a Scale and min/max domain limits."""
  18. # This is an internal maplotlib function that simplifies things to access
  19. # It is likely to become part of the matplotlib API at some point:
  20. # https://github.com/matplotlib/matplotlib/issues/20329
  21. if isinstance(norm, mpl.colors.Normalize):
  22. return norm
  23. if scale is None:
  24. return None
  25. if norm is None:
  26. vmin = vmax = None
  27. else:
  28. vmin, vmax = norm # TODO more helpful error if this fails?
  29. class ScaledNorm(mpl.colors.Normalize):
  30. def __call__(self, value, clip=None):
  31. # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py
  32. # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE
  33. value, is_scalar = self.process_value(value)
  34. self.autoscale_None(value)
  35. if self.vmin > self.vmax:
  36. raise ValueError("vmin must be less or equal to vmax")
  37. if self.vmin == self.vmax:
  38. return np.full_like(value, 0)
  39. if clip is None:
  40. clip = self.clip
  41. if clip:
  42. value = np.clip(value, self.vmin, self.vmax)
  43. # ***** Seaborn changes start ****
  44. t_value = self.transform(value).reshape(np.shape(value))
  45. t_vmin, t_vmax = self.transform([self.vmin, self.vmax])
  46. # ***** Seaborn changes end *****
  47. if not np.isfinite([t_vmin, t_vmax]).all():
  48. raise ValueError("Invalid vmin or vmax")
  49. t_value -= t_vmin
  50. t_value /= (t_vmax - t_vmin)
  51. t_value = np.ma.masked_invalid(t_value, copy=False)
  52. return t_value[0] if is_scalar else t_value
  53. new_norm = ScaledNorm(vmin, vmax)
  54. new_norm.transform = scale.get_transform().transform
  55. return new_norm
  56. def scale_factory(scale, axis, **kwargs):
  57. """
  58. Backwards compatability for creation of independent scales.
  59. Matplotlib scales require an Axis object for instantiation on < 3.4.
  60. But the axis is not used, aside from extraction of the axis_name in LogScale.
  61. """
  62. modify_transform = False
  63. if _version_predates(mpl, "3.4"):
  64. if axis[0] in "xy":
  65. modify_transform = True
  66. axis = axis[0]
  67. base = kwargs.pop("base", None)
  68. if base is not None:
  69. kwargs[f"base{axis}"] = base
  70. nonpos = kwargs.pop("nonpositive", None)
  71. if nonpos is not None:
  72. kwargs[f"nonpos{axis}"] = nonpos
  73. if isinstance(scale, str):
  74. class Axis:
  75. axis_name = axis
  76. axis = Axis()
  77. scale = mpl.scale.scale_factory(scale, axis, **kwargs)
  78. if modify_transform:
  79. transform = scale.get_transform()
  80. transform.base = kwargs.get("base", 10)
  81. if kwargs.get("nonpositive") == "mask":
  82. # Setting a private attribute, but we only get here
  83. # on an old matplotlib, so this won't break going forwards
  84. transform._clip = False
  85. return scale
  86. def set_scale_obj(ax, axis, scale):
  87. """Handle backwards compatability with setting matplotlib scale."""
  88. if _version_predates(mpl, "3.4"):
  89. # The ability to pass a BaseScale instance to Axes.set_{}scale was added
  90. # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089
  91. # Workaround: use the scale name, which is restrictive only if the user
  92. # wants to define a custom scale; they'll need to update the registry too.
  93. if scale.name is None:
  94. # Hack to support our custom Formatter-less CatScale
  95. return
  96. method = getattr(ax, f"set_{axis}scale")
  97. kws = {}
  98. if scale.name == "function":
  99. trans = scale.get_transform()
  100. kws["functions"] = (trans._forward, trans._inverse)
  101. method(scale.name, **kws)
  102. axis_obj = getattr(ax, f"{axis}axis")
  103. scale.set_default_locators_and_formatters(axis_obj)
  104. else:
  105. ax.set(**{f"{axis}scale": scale})
  106. def get_colormap(name):
  107. """Handle changes to matplotlib colormap interface in 3.6."""
  108. try:
  109. return mpl.colormaps[name]
  110. except AttributeError:
  111. return mpl.cm.get_cmap(name)
  112. def register_colormap(name, cmap):
  113. """Handle changes to matplotlib colormap interface in 3.6."""
  114. try:
  115. if name not in mpl.colormaps:
  116. mpl.colormaps.register(cmap, name=name)
  117. except AttributeError:
  118. mpl.cm.register_cmap(name, cmap)
  119. def set_layout_engine(fig, engine):
  120. """Handle changes to auto layout engine interface in 3.6"""
  121. if hasattr(fig, "set_layout_engine"):
  122. fig.set_layout_engine(engine)
  123. else:
  124. # _version_predates(mpl, 3.6)
  125. if engine == "tight":
  126. fig.set_tight_layout(True)
  127. elif engine == "constrained":
  128. fig.set_constrained_layout(True)
  129. elif engine == "none":
  130. fig.set_tight_layout(False)
  131. fig.set_constrained_layout(False)
  132. def share_axis(ax0, ax1, which):
  133. """Handle changes to post-hoc axis sharing."""
  134. if _version_predates(mpl, "3.5"):
  135. group = getattr(ax0, f"get_shared_{which}_axes")()
  136. group.join(ax1, ax0)
  137. else:
  138. getattr(ax1, f"share{which}")(ax0)
  139. def get_legend_handles(legend):
  140. """Handle legendHandles attribute rename."""
  141. if _version_predates(mpl, "3.7"):
  142. return legend.legendHandles
  143. else:
  144. return legend.legend_handles