123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import numpy as np
- import matplotlib as mpl
- from seaborn.utils import _version_predates
def MarkerStyle(marker=None, fillstyle=None):
    """
    Build a matplotlib MarkerStyle, also accepting an existing MarkerStyle.

    Old matplotlib (< 3.3.0) could not build a MarkerStyle from another
    MarkerStyle (https://github.com/matplotlib/matplotlib/pull/16692).

    - If ``marker`` is not a MarkerStyle, it is wrapped normally.
    - If it is one and no ``fillstyle`` is requested, it is returned
      unchanged.
    - Otherwise its underlying marker spec is re-wrapped with the requested
      fill style.
    """
    if not isinstance(marker, mpl.markers.MarkerStyle):
        return mpl.markers.MarkerStyle(marker, fillstyle)
    if fillstyle is None:
        # Nothing to change: hand back the caller's object itself
        return marker
    return mpl.markers.MarkerStyle(marker.get_marker(), fillstyle)
def norm_from_scale(scale, norm):
    """
    Produce a Normalize object given a Scale and min/max domain limits.

    Parameters
    ----------
    scale : matplotlib scale object or None
        Scale whose ``get_transform().transform`` maps data into the space
        where normalization happens. If None (and ``norm`` is not already
        a Normalize), no norm is built and None is returned.
    norm : matplotlib.colors.Normalize, (vmin, vmax) pair, or None
        An existing Normalize is returned unchanged, even when ``scale`` is
        None. A pair sets the limits of the new norm. None leaves the limits
        unset, so they are autoscaled from the data on the first call.

    Returns
    -------
    matplotlib.colors.Normalize or None

    Raises
    ------
    TypeError
        If ``norm`` is not None, not a Normalize, and not iterable.
    ValueError
        If ``norm`` is iterable but does not hold exactly two values.
    """
    # Matplotlib has an internal helper that would simplify this, but it is
    # not public API yet. It may become part of the API at some point:
    # https://github.com/matplotlib/matplotlib/issues/20329
    if isinstance(norm, mpl.colors.Normalize):
        return norm

    if scale is None:
        return None

    if norm is None:
        vmin = vmax = None
    else:
        try:
            vmin, vmax = norm
        except (TypeError, ValueError) as err:
            # Keep the original exception class (TypeError for a
            # non-iterable, ValueError for the wrong number of values) so
            # existing handlers still match, but explain what is accepted.
            err_cls = TypeError if isinstance(err, TypeError) else ValueError
            msg = (
                "`norm` must be None, a matplotlib Normalize object, or a "
                f"(vmin, vmax) pair; got {norm!r}."
            )
            raise err_cls(msg) from err

    class ScaledNorm(mpl.colors.Normalize):

        def __call__(self, value, clip=None):
            # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py
            # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE
            value, is_scalar = self.process_value(value)
            self.autoscale_None(value)
            if self.vmin > self.vmax:
                raise ValueError("vmin must be less or equal to vmax")
            if self.vmin == self.vmax:
                return np.full_like(value, 0)
            if clip is None:
                clip = self.clip
            if clip:
                value = np.clip(value, self.vmin, self.vmax)
            # ***** Seaborn changes start ****
            # Normalize in the scale's transformed space rather than linearly
            t_value = self.transform(value).reshape(np.shape(value))
            t_vmin, t_vmax = self.transform([self.vmin, self.vmax])
            # ***** Seaborn changes end *****
            if not np.isfinite([t_vmin, t_vmax]).all():
                raise ValueError("Invalid vmin or vmax")
            t_value -= t_vmin
            t_value /= (t_vmax - t_vmin)
            t_value = np.ma.masked_invalid(t_value, copy=False)
            return t_value[0] if is_scalar else t_value

    new_norm = ScaledNorm(vmin, vmax)
    # Bound per instance; ScaledNorm.__call__ reads it via self.transform
    new_norm.transform = scale.get_transform().transform

    return new_norm
def scale_factory(scale, axis, **kwargs):
    """
    Backwards compatibility for creation of independent scales.

    Matplotlib scales require an Axis object for instantiation on < 3.4.
    But the axis is not used, aside from extraction of the axis_name in
    LogScale.

    Parameters
    ----------
    scale : str
        Scale name passed to ``matplotlib.scale.scale_factory``. When it is
        a string, a minimal stand-in Axis carrying only ``axis_name`` is
        constructed.
    axis : str
        Axis name, e.g. "x" or "y". On matplotlib < 3.4, when its first
        character is "x" or "y", that character is used to spell the legacy
        ``base{axis}`` / ``nonpos{axis}`` keywords.
    **kwargs
        Forwarded to the scale constructor. On matplotlib < 3.4, ``base``
        and ``nonpositive`` are renamed to their legacy spellings.

    Returns
    -------
    The scale object returned by ``matplotlib.scale.scale_factory``.
    """
    modify_transform = False
    # The caller's values are popped out of ``kwargs`` below (and replaced by
    # legacy spellings), so they must be remembered here rather than re-read
    # from ``kwargs`` afterwards, which would always yield the defaults.
    base = nonpos = None
    if _version_predates(mpl, "3.4"):
        if axis[0] in "xy":
            modify_transform = True
            axis = axis[0]
            base = kwargs.pop("base", None)
            if base is not None:
                kwargs[f"base{axis}"] = base
            nonpos = kwargs.pop("nonpositive", None)
            if nonpos is not None:
                kwargs[f"nonpos{axis}"] = nonpos

    if isinstance(scale, str):
        # Minimal stand-in exposing only the ``axis_name`` attribute
        class Axis:
            axis_name = axis
        axis = Axis()

    scale = mpl.scale.scale_factory(scale, axis, **kwargs)

    if modify_transform:
        transform = scale.get_transform()
        transform.base = 10 if base is None else base
        if nonpos == "mask":
            # Setting a private attribute, but we only get here
            # on an old matplotlib, so this won't break going forwards
            transform._clip = False

    return scale
def set_scale_obj(ax, axis, scale):
    """
    Apply a matplotlib scale object to one axis of ``ax``, across versions.

    On matplotlib >= 3.4 the scale instance is passed straight to
    ``Axes.set`` as ``{axis}scale``.

    Older versions only accept a scale *name*. There, the scale is set by its
    ``name`` (forwarding the forward/inverse functions for ``"function"``
    scales), and then the object's default locators and formatters are
    applied to the matching Axis. Scales whose ``name`` is None are skipped
    entirely on that path.
    """
    if not _version_predates(mpl, "3.4"):
        ax.set(**{f"{axis}scale": scale})
        return

    # Passing a BaseScale instance to Axes.set_{x,y}scale arrived in
    # matplotlib 3.4.0 (GH: matplotlib/matplotlib/pull/19089). Before that
    # only the scale name can be passed, which is restrictive only for custom
    # scales: users would also need to add them to the scale registry.
    if scale.name is None:
        # Hack to support our custom Formatter-less CatScale
        return

    extra_kws = {}
    if scale.name == "function":
        func_transform = scale.get_transform()
        extra_kws["functions"] = (func_transform._forward, func_transform._inverse)

    setter = getattr(ax, f"set_{axis}scale")
    setter(scale.name, **extra_kws)

    target_axis = getattr(ax, f"{axis}axis")
    scale.set_default_locators_and_formatters(target_axis)
def get_colormap(name):
    """
    Look up a colormap by name, bridging the matplotlib 3.6 interface change.

    Prefers the ``matplotlib.colormaps`` registry. If that raises
    AttributeError, falls back to ``matplotlib.cm.get_cmap``.
    """
    try:
        cmap = mpl.colormaps[name]
    except AttributeError:
        # No colormap registry object available: use the legacy getter
        cmap = mpl.cm.get_cmap(name)
    return cmap
def register_colormap(name, cmap):
    """
    Register ``cmap`` under ``name``, bridging the matplotlib 3.6 interface change.

    With the ``matplotlib.colormaps`` registry, nothing happens if ``name`` is
    already registered. If either the registry or its ``register`` method is
    unavailable (AttributeError), ``matplotlib.cm.register_cmap`` is used
    instead.
    """
    try:
        if name in mpl.colormaps:
            return
        # Keep this call inside the ``try``: a missing ``register`` method
        # must also trigger the legacy fallback below.
        mpl.colormaps.register(cmap, name=name)
    except AttributeError:
        mpl.cm.register_cmap(name, cmap)
def set_layout_engine(fig, engine):
    """
    Handle changes to auto layout engine interface in 3.6.

    Parameters
    ----------
    fig : matplotlib Figure (or Figure-like object)
        Figure whose automatic layout should be configured.
    engine : str
        ``"tight"``, ``"constrained"`` or ``"none"``.

        - If ``fig`` has ``set_layout_engine``, the value is passed through
          unchanged.
        - On the fallback path, any other value is ignored (no calls are
          made).

    Notes
    -----
    ``Figure.set_layout_engine`` replaces the previously active engine. The
    fallback emulates that by switching the competing mechanism off. This
    way, tight and constrained layout are never both enabled on the same
    figure.
    """
    if hasattr(fig, "set_layout_engine"):
        fig.set_layout_engine(engine)
        return

    # _version_predates(mpl, 3.6): only the per-mechanism boolean setters exist
    if engine == "tight":
        fig.set_constrained_layout(False)
        fig.set_tight_layout(True)
    elif engine == "constrained":
        fig.set_tight_layout(False)
        fig.set_constrained_layout(True)
    elif engine == "none":
        fig.set_tight_layout(False)
        fig.set_constrained_layout(False)
def share_axis(ax0, ax1, which):
    """
    Make ``ax1`` share its ``which`` axis with ``ax0`` after both exist.

    ``which`` names the axis (e.g. "x" or "y").

    - On matplotlib >= 3.5, ``ax1.share{which}(ax0)`` is called.
    - On older versions, both Axes are joined in ``ax0``'s shared-axes
      group.
    """
    if not _version_predates(mpl, "3.5"):
        share_with = getattr(ax1, f"share{which}")
        share_with(ax0)
        return

    shared_group = getattr(ax0, f"get_shared_{which}_axes")()
    shared_group.join(ax1, ax0)
def get_legend_handles(legend):
    """
    Return the handles of ``legend`` across the matplotlib 3.7 attribute rename.

    The attribute is ``legendHandles`` before 3.7 and ``legend_handles`` from
    3.7 on.
    """
    attr_name = (
        "legendHandles" if _version_predates(mpl, "3.7") else "legend_handles"
    )
    return getattr(legend, attr_name)
|