12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262 |
- """Functions to visualize matrices of data."""
- import warnings
- import matplotlib as mpl
- from matplotlib.collections import LineCollection
- import matplotlib.pyplot as plt
- from matplotlib import gridspec
- import numpy as np
- import pandas as pd
- try:
- from scipy.cluster import hierarchy
- _no_scipy = False
- except ImportError:
- _no_scipy = True
- from . import cm
- from .axisgrid import Grid
- from ._compat import get_colormap
- from .utils import (
- despine,
- axis_ticklabels_overlap,
- relative_luminance,
- to_utf8,
- _draw_figure,
- )
- __all__ = ["heatmap", "clustermap"]
- def _index_to_label(index):
- """Convert a pandas index or multiindex to an axis label."""
- if isinstance(index, pd.MultiIndex):
- return "-".join(map(to_utf8, index.names))
- else:
- return index.name
- def _index_to_ticklabels(index):
- """Convert a pandas index or multiindex into ticklabels."""
- if isinstance(index, pd.MultiIndex):
- return ["-".join(map(to_utf8, i)) for i in index.values]
- else:
- return index.values
def _convert_colors(colors):
    """Convert a flat or nested sequence of colors to RGB tuples.

    The first element decides the layout: if matplotlib can parse it as a
    color the input is treated as one flat list, otherwise (``ValueError``)
    it is treated as a list of color lists.
    """
    to_rgb = mpl.colors.to_rgb

    try:
        # Probe the first element, then convert the whole flat list. The
        # conversion stays inside the ``try`` so a ValueError from a later
        # element also falls through to the nested interpretation.
        to_rgb(colors[0])
        flat = [to_rgb(color) for color in colors]
    except ValueError:
        return [[to_rgb(color) for color in group] for group in colors]
    return flat
- def _matrix_mask(data, mask):
- """Ensure that data and mask are compatible and add missing values.
- Values will be plotted for cells where ``mask`` is ``False``.
- ``data`` is expected to be a DataFrame; ``mask`` can be an array or
- a DataFrame.
- """
- if mask is None:
- mask = np.zeros(data.shape, bool)
- if isinstance(mask, np.ndarray):
- # For array masks, ensure that shape matches data then convert
- if mask.shape != data.shape:
- raise ValueError("Mask must have the same shape as data.")
- mask = pd.DataFrame(mask,
- index=data.index,
- columns=data.columns,
- dtype=bool)
- elif isinstance(mask, pd.DataFrame):
- # For DataFrame masks, ensure that semantic labels match data
- if not mask.index.equals(data.index) \
- and mask.columns.equals(data.columns):
- err = "Mask must have the same index and columns as data."
- raise ValueError(err)
- # Add any cells with missing data to the mask
- # This works around an issue where `plt.pcolormesh` doesn't represent
- # missing data properly
- mask = mask | pd.isnull(data)
- return mask
- class _HeatMapper:
- """Draw a heatmap plot of a matrix with nice labels and colormaps."""
- def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
- annot_kws, cbar, cbar_kws,
- xticklabels=True, yticklabels=True, mask=None):
- """Initialize the plotting object."""
- # We always want to have a DataFrame with semantic information
- # and an ndarray to pass to matplotlib
- if isinstance(data, pd.DataFrame):
- plot_data = data.values
- else:
- plot_data = np.asarray(data)
- data = pd.DataFrame(plot_data)
- # Validate the mask and convert to DataFrame
- mask = _matrix_mask(data, mask)
- plot_data = np.ma.masked_where(np.asarray(mask), plot_data)
- # Get good names for the rows and columns
- xtickevery = 1
- if isinstance(xticklabels, int):
- xtickevery = xticklabels
- xticklabels = _index_to_ticklabels(data.columns)
- elif xticklabels is True:
- xticklabels = _index_to_ticklabels(data.columns)
- elif xticklabels is False:
- xticklabels = []
- ytickevery = 1
- if isinstance(yticklabels, int):
- ytickevery = yticklabels
- yticklabels = _index_to_ticklabels(data.index)
- elif yticklabels is True:
- yticklabels = _index_to_ticklabels(data.index)
- elif yticklabels is False:
- yticklabels = []
- if not len(xticklabels):
- self.xticks = []
- self.xticklabels = []
- elif isinstance(xticklabels, str) and xticklabels == "auto":
- self.xticks = "auto"
- self.xticklabels = _index_to_ticklabels(data.columns)
- else:
- self.xticks, self.xticklabels = self._skip_ticks(xticklabels,
- xtickevery)
- if not len(yticklabels):
- self.yticks = []
- self.yticklabels = []
- elif isinstance(yticklabels, str) and yticklabels == "auto":
- self.yticks = "auto"
- self.yticklabels = _index_to_ticklabels(data.index)
- else:
- self.yticks, self.yticklabels = self._skip_ticks(yticklabels,
- ytickevery)
- # Get good names for the axis labels
- xlabel = _index_to_label(data.columns)
- ylabel = _index_to_label(data.index)
- self.xlabel = xlabel if xlabel is not None else ""
- self.ylabel = ylabel if ylabel is not None else ""
- # Determine good default values for the colormapping
- self._determine_cmap_params(plot_data, vmin, vmax,
- cmap, center, robust)
- # Sort out the annotations
- if annot is None or annot is False:
- annot = False
- annot_data = None
- else:
- if isinstance(annot, bool):
- annot_data = plot_data
- else:
- annot_data = np.asarray(annot)
- if annot_data.shape != plot_data.shape:
- err = "`data` and `annot` must have same shape."
- raise ValueError(err)
- annot = True
- # Save other attributes to the object
- self.data = data
- self.plot_data = plot_data
- self.annot = annot
- self.annot_data = annot_data
- self.fmt = fmt
- self.annot_kws = {} if annot_kws is None else annot_kws.copy()
- self.cbar = cbar
- self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy()
- def _determine_cmap_params(self, plot_data, vmin, vmax,
- cmap, center, robust):
- """Use some heuristics to set good defaults for colorbar and range."""
- # plot_data is a np.ma.array instance
- calc_data = plot_data.astype(float).filled(np.nan)
- if vmin is None:
- if robust:
- vmin = np.nanpercentile(calc_data, 2)
- else:
- vmin = np.nanmin(calc_data)
- if vmax is None:
- if robust:
- vmax = np.nanpercentile(calc_data, 98)
- else:
- vmax = np.nanmax(calc_data)
- self.vmin, self.vmax = vmin, vmax
- # Choose default colormaps if not provided
- if cmap is None:
- if center is None:
- self.cmap = cm.rocket
- else:
- self.cmap = cm.icefire
- elif isinstance(cmap, str):
- self.cmap = get_colormap(cmap)
- elif isinstance(cmap, list):
- self.cmap = mpl.colors.ListedColormap(cmap)
- else:
- self.cmap = cmap
- # Recenter a divergent colormap
- if center is not None:
- # Copy bad values
- # in mpl<3.2 only masked values are honored with "bad" color spec
- # (see https://github.com/matplotlib/matplotlib/pull/14257)
- bad = self.cmap(np.ma.masked_invalid([np.nan]))[0]
- # under/over values are set for sure when cmap extremes
- # do not map to the same color as +-inf
- under = self.cmap(-np.inf)
- over = self.cmap(np.inf)
- under_set = under != self.cmap(0)
- over_set = over != self.cmap(self.cmap.N - 1)
- vrange = max(vmax - center, center - vmin)
- normlize = mpl.colors.Normalize(center - vrange, center + vrange)
- cmin, cmax = normlize([vmin, vmax])
- cc = np.linspace(cmin, cmax, 256)
- self.cmap = mpl.colors.ListedColormap(self.cmap(cc))
- self.cmap.set_bad(bad)
- if under_set:
- self.cmap.set_under(under)
- if over_set:
- self.cmap.set_over(over)
- def _annotate_heatmap(self, ax, mesh):
- """Add textual labels with the value in each cell."""
- mesh.update_scalarmappable()
- height, width = self.annot_data.shape
- xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)
- for x, y, m, color, val in zip(xpos.flat, ypos.flat,
- mesh.get_array().flat, mesh.get_facecolors(),
- self.annot_data.flat):
- if m is not np.ma.masked:
- lum = relative_luminance(color)
- text_color = ".15" if lum > .408 else "w"
- annotation = ("{:" + self.fmt + "}").format(val)
- text_kwargs = dict(color=text_color, ha="center", va="center")
- text_kwargs.update(self.annot_kws)
- ax.text(x, y, annotation, **text_kwargs)
- def _skip_ticks(self, labels, tickevery):
- """Return ticks and labels at evenly spaced intervals."""
- n = len(labels)
- if tickevery == 0:
- ticks, labels = [], []
- elif tickevery == 1:
- ticks, labels = np.arange(n) + .5, labels
- else:
- start, end, step = 0, n, tickevery
- ticks = np.arange(start, end, step) + .5
- labels = labels[start:end:step]
- return ticks, labels
- def _auto_ticks(self, ax, labels, axis):
- """Determine ticks and ticklabels that minimize overlap."""
- transform = ax.figure.dpi_scale_trans.inverted()
- bbox = ax.get_window_extent().transformed(transform)
- size = [bbox.width, bbox.height][axis]
- axis = [ax.xaxis, ax.yaxis][axis]
- tick, = axis.set_ticks([0])
- fontsize = tick.label1.get_size()
- max_ticks = int(size // (fontsize / 72))
- if max_ticks < 1:
- return [], []
- tick_every = len(labels) // max_ticks + 1
- tick_every = 1 if tick_every == 0 else tick_every
- ticks, labels = self._skip_ticks(labels, tick_every)
- return ticks, labels
- def plot(self, ax, cax, kws):
- """Draw the heatmap on the provided Axes."""
- # Remove all the Axes spines
- despine(ax=ax, left=True, bottom=True)
- # setting vmin/vmax in addition to norm is deprecated
- # so avoid setting if norm is set
- if kws.get("norm") is None:
- kws.setdefault("vmin", self.vmin)
- kws.setdefault("vmax", self.vmax)
- # Draw the heatmap
- mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws)
- # Set the axis limits
- ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))
- # Invert the y axis to show the plot in matrix form
- ax.invert_yaxis()
- # Possibly add a colorbar
- if self.cbar:
- cb = ax.figure.colorbar(mesh, cax, ax, **self.cbar_kws)
- cb.outline.set_linewidth(0)
- # If rasterized is passed to pcolormesh, also rasterize the
- # colorbar to avoid white lines on the PDF rendering
- if kws.get('rasterized', False):
- cb.solids.set_rasterized(True)
- # Add row and column labels
- if isinstance(self.xticks, str) and self.xticks == "auto":
- xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
- else:
- xticks, xticklabels = self.xticks, self.xticklabels
- if isinstance(self.yticks, str) and self.yticks == "auto":
- yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
- else:
- yticks, yticklabels = self.yticks, self.yticklabels
- ax.set(xticks=xticks, yticks=yticks)
- xtl = ax.set_xticklabels(xticklabels)
- ytl = ax.set_yticklabels(yticklabels, rotation="vertical")
- plt.setp(ytl, va="center") # GH2484
- # Possibly rotate them if they overlap
- _draw_figure(ax.figure)
- if axis_ticklabels_overlap(xtl):
- plt.setp(xtl, rotation="vertical")
- if axis_ticklabels_overlap(ytl):
- plt.setp(ytl, rotation="horizontal")
- # Add the axis labels
- ax.set(xlabel=self.xlabel, ylabel=self.ylabel)
- # Annotate the cells with the formatted values
- if self.annot:
- self._annotate_heatmap(ax, mesh)
def heatmap(
    data, *,
    vmin=None, vmax=None, cmap=None, center=None, robust=False,
    annot=None, fmt=".2g", annot_kws=None,
    linewidths=0, linecolor="white",
    cbar=True, cbar_kws=None, cbar_ax=None,
    square=False, xticklabels="auto", yticklabels="auto",
    mask=None, ax=None,
    **kwargs
):
    """Plot rectangular data as a color-encoded matrix.

    This is an Axes-level function: the heatmap is drawn into the Axes given
    by ``ax``, or into the currently-active Axes when ``ax`` is omitted. Some
    of that Axes space is used for a colorbar unless ``cbar`` is False or a
    dedicated Axes is supplied through ``cbar_ax``.

    Parameters
    ----------
    data : rectangular dataset
        2D dataset that can be coerced into an ndarray. When a Pandas
        DataFrame is given, its index/column information is used to label
        the rows and columns.
    vmin, vmax : floats, optional
        Values to anchor the colormap; when absent they are inferred from
        the data and the other keyword arguments.
    cmap : matplotlib colormap name or object, or list of colors, optional
        The mapping from data values to color space. The default depends on
        whether ``center`` is set.
    center : float, optional
        The value at which to center the colormap when plotting divergent
        data. Setting it changes the default ``cmap`` if none is specified.
    robust : bool, optional
        If True and ``vmin`` or ``vmax`` are absent, the colormap range is
        computed with robust quantiles instead of the extreme values.
    annot : bool or rectangular dataset, optional
        If True, write the data value in each cell. If an array-like with
        the same shape as ``data``, annotate with it instead of the data.
        Note that DataFrames will match on position, not index.
    fmt : str, optional
        String formatting code to use when adding annotations.
    annot_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.axes.Axes.text` when
        ``annot`` is True.
    linewidths : float, optional
        Width of the lines that will divide each cell.
    linecolor : color, optional
        Color of the lines that will divide each cell.
    cbar : bool, optional
        Whether to draw a colorbar.
    cbar_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`.
    cbar_ax : matplotlib Axes, optional
        Axes in which to draw the colorbar, otherwise take space from the
        main Axes.
    square : bool, optional
        If True, set the Axes aspect to "equal" so each cell will be
        square-shaped.
    xticklabels, yticklabels : "auto", bool, list-like, or int, optional
        If True, plot the column names of the dataframe. If False, don't
        plot the column names. If list-like, plot these alternate labels as
        the xticklabels. If an integer, use the column names but plot only
        every n label. If "auto", try to densely plot non-overlapping
        labels.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where ``mask`` is True.
        Cells with missing values are automatically masked.
    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise use the currently-active
        Axes.
    kwargs : other keyword arguments
        All other keyword arguments are passed to
        :meth:`matplotlib.axes.Axes.pcolormesh`.

    Returns
    -------
    ax : matplotlib Axes
        Axes object with the heatmap.

    See Also
    --------
    clustermap : Plot a matrix using hierarchical clustering to arrange the
                 rows and columns.

    Examples
    --------

    .. include:: ../docstrings/heatmap.rst

    """
    # All data/label/colormap preparation happens in the plotter object
    plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,
                          annot_kws, cbar, cbar_kws, xticklabels,
                          yticklabels, mask)

    # Cell borders are expressed through pcolormesh's own keywords
    kwargs.update(linewidths=linewidths, edgecolor=linecolor)

    # Fall back to the current Axes, then draw and hand the Axes back
    ax = plt.gca() if ax is None else ax
    if square:
        ax.set_aspect("equal")
    plotter.plot(ax, cbar_ax, kwargs)
    return ax
- class _DendrogramPlotter:
- """Object for drawing tree of similarities between data rows/columns"""
- def __init__(self, data, linkage, metric, method, axis, label, rotate):
- """Plot a dendrogram of the relationships between the columns of data
- Parameters
- ----------
- data : pandas.DataFrame
- Rectangular data
- """
- self.axis = axis
- if self.axis == 1:
- data = data.T
- if isinstance(data, pd.DataFrame):
- array = data.values
- else:
- array = np.asarray(data)
- data = pd.DataFrame(array)
- self.array = array
- self.data = data
- self.shape = self.data.shape
- self.metric = metric
- self.method = method
- self.axis = axis
- self.label = label
- self.rotate = rotate
- if linkage is None:
- self.linkage = self.calculated_linkage
- else:
- self.linkage = linkage
- self.dendrogram = self.calculate_dendrogram()
- # Dendrogram ends are always at multiples of 5, who knows why
- ticks = 10 * np.arange(self.data.shape[0]) + 5
- if self.label:
- ticklabels = _index_to_ticklabels(self.data.index)
- ticklabels = [ticklabels[i] for i in self.reordered_ind]
- if self.rotate:
- self.xticks = []
- self.yticks = ticks
- self.xticklabels = []
- self.yticklabels = ticklabels
- self.ylabel = _index_to_label(self.data.index)
- self.xlabel = ''
- else:
- self.xticks = ticks
- self.yticks = []
- self.xticklabels = ticklabels
- self.yticklabels = []
- self.ylabel = ''
- self.xlabel = _index_to_label(self.data.index)
- else:
- self.xticks, self.yticks = [], []
- self.yticklabels, self.xticklabels = [], []
- self.xlabel, self.ylabel = '', ''
- self.dependent_coord = self.dendrogram['dcoord']
- self.independent_coord = self.dendrogram['icoord']
- def _calculate_linkage_scipy(self):
- linkage = hierarchy.linkage(self.array, method=self.method,
- metric=self.metric)
- return linkage
- def _calculate_linkage_fastcluster(self):
- import fastcluster
- # Fastcluster has a memory-saving vectorized version, but only
- # with certain linkage methods, and mostly with euclidean metric
- # vector_methods = ('single', 'centroid', 'median', 'ward')
- euclidean_methods = ('centroid', 'median', 'ward')
- euclidean = self.metric == 'euclidean' and self.method in \
- euclidean_methods
- if euclidean or self.method == 'single':
- return fastcluster.linkage_vector(self.array,
- method=self.method,
- metric=self.metric)
- else:
- linkage = fastcluster.linkage(self.array, method=self.method,
- metric=self.metric)
- return linkage
- @property
- def calculated_linkage(self):
- try:
- return self._calculate_linkage_fastcluster()
- except ImportError:
- if np.prod(self.shape) >= 10000:
- msg = ("Clustering large matrix with scipy. Installing "
- "`fastcluster` may give better performance.")
- warnings.warn(msg)
- return self._calculate_linkage_scipy()
- def calculate_dendrogram(self):
- """Calculates a dendrogram based on the linkage matrix
- Made a separate function, not a property because don't want to
- recalculate the dendrogram every time it is accessed.
- Returns
- -------
- dendrogram : dict
- Dendrogram dictionary as returned by scipy.cluster.hierarchy
- .dendrogram. The important key-value pairing is
- "reordered_ind" which indicates the re-ordering of the matrix
- """
- return hierarchy.dendrogram(self.linkage, no_plot=True,
- color_threshold=-np.inf)
- @property
- def reordered_ind(self):
- """Indices of the matrix, reordered by the dendrogram"""
- return self.dendrogram['leaves']
- def plot(self, ax, tree_kws):
- """Plots a dendrogram of the similarities between data on the axes
- Parameters
- ----------
- ax : matplotlib.axes.Axes
- Axes object upon which the dendrogram is plotted
- """
- tree_kws = {} if tree_kws is None else tree_kws.copy()
- tree_kws.setdefault("linewidths", .5)
- tree_kws.setdefault("colors", tree_kws.pop("color", (.2, .2, .2)))
- if self.rotate and self.axis == 0:
- coords = zip(self.dependent_coord, self.independent_coord)
- else:
- coords = zip(self.independent_coord, self.dependent_coord)
- lines = LineCollection([list(zip(x, y)) for x, y in coords],
- **tree_kws)
- ax.add_collection(lines)
- number_of_leaves = len(self.reordered_ind)
- max_dependent_coord = max(map(max, self.dependent_coord))
- if self.rotate:
- ax.yaxis.set_ticks_position('right')
- # Constants 10 and 1.05 come from
- # `scipy.cluster.hierarchy._plot_dendrogram`
- ax.set_ylim(0, number_of_leaves * 10)
- ax.set_xlim(0, max_dependent_coord * 1.05)
- ax.invert_xaxis()
- ax.invert_yaxis()
- else:
- # Constants 10 and 1.05 come from
- # `scipy.cluster.hierarchy._plot_dendrogram`
- ax.set_xlim(0, number_of_leaves * 10)
- ax.set_ylim(0, max_dependent_coord * 1.05)
- despine(ax=ax, bottom=True, left=True)
- ax.set(xticks=self.xticks, yticks=self.yticks,
- xlabel=self.xlabel, ylabel=self.ylabel)
- xtl = ax.set_xticklabels(self.xticklabels)
- ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')
- # Force a draw of the plot to avoid matplotlib window error
- _draw_figure(ax.figure)
- if len(ytl) > 0 and axis_ticklabels_overlap(ytl):
- plt.setp(ytl, rotation="horizontal")
- if len(xtl) > 0 and axis_ticklabels_overlap(xtl):
- plt.setp(xtl, rotation="vertical")
- return self
def dendrogram(
    data, *,
    linkage=None, axis=1, label=True, metric='euclidean',
    method='average', rotate=False, tree_kws=None, ax=None
):
    """Draw a tree diagram of relationships within a matrix

    Parameters
    ----------
    data : pandas.DataFrame
        Rectangular data
    linkage : numpy.array, optional
        Linkage matrix; computed from ``data`` when not given
    axis : int, optional
        Which axis to use to calculate linkage. 0 is rows, 1 is columns.
    label : bool, optional
        If True, label the dendrogram at leaves with column or row names
    metric : str, optional
        Distance metric. Anything valid for scipy.spatial.distance.pdist
    method : str, optional
        Linkage method to use. Anything valid for
        scipy.cluster.hierarchy.linkage
    rotate : bool, optional
        When plotting the matrix, whether to rotate it 90 degrees
        counter-clockwise, so the leaves face right
    tree_kws : dict, optional
        Keyword arguments for the ``matplotlib.collections.LineCollection``
        that is used for plotting the lines of the dendrogram tree.
    ax : matplotlib axis, optional
        Axis to plot on, otherwise uses current axis

    Returns
    -------
    dendrogramplotter : _DendrogramPlotter
        A Dendrogram plotter object.

    Raises
    ------
    RuntimeError
        If scipy could not be imported.

    Notes
    -----
    Access the reordered dendrogram indices with
    dendrogramplotter.reordered_ind

    """
    if _no_scipy:
        raise RuntimeError("dendrogram requires scipy to be installed")

    # Resolve the target Axes only after the (possibly slow) clustering
    plotter = _DendrogramPlotter(
        data,
        linkage=linkage,
        axis=axis,
        metric=metric,
        method=method,
        label=label,
        rotate=rotate,
    )
    target_ax = plt.gca() if ax is None else ax
    return plotter.plot(ax=target_ax, tree_kws=tree_kws)
- class ClusterGrid(Grid):
def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,
             figsize=None, row_colors=None, col_colors=None, mask=None,
             dendrogram_ratio=None, colors_ratio=None, cbar_pos=None):
    """Grid object for organizing clustered heatmap input on to axes

    Formats the data, creates the figure, and lays out the Axes for the
    two dendrograms, the optional row/column color strips, the heatmap
    and the optional colorbar. Raises ``RuntimeError`` when scipy could
    not be imported.
    """
    if _no_scipy:
        raise RuntimeError("ClusterGrid requires scipy to be available")

    def _as_pair(value):
        # Accept either a (row, col) pair or one value used for both
        try:
            first, second = value
        except TypeError:
            first = second = value
        return first, second

    self.data = data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)

    self.data2d = self.format_data(self.data, pivot_kws, z_score,
                                   standard_scale)

    self.mask = _matrix_mask(self.data2d, mask)

    self._figure = plt.figure(figsize=figsize)

    self.row_colors, self.row_color_labels = \
        self._preprocess_colors(data, row_colors, axis=0)
    self.col_colors, self.col_color_labels = \
        self._preprocess_colors(data, col_colors, axis=1)

    row_dendrogram_ratio, col_dendrogram_ratio = _as_pair(dendrogram_ratio)
    row_colors_ratio, col_colors_ratio = _as_pair(colors_ratio)

    width_ratios = self.dim_ratios(self.row_colors,
                                   row_dendrogram_ratio,
                                   row_colors_ratio)
    height_ratios = self.dim_ratios(self.col_colors,
                                    col_dendrogram_ratio,
                                    col_colors_ratio)

    # A color strip adds one extra row / column to the 2x2 base grid
    nrows = 3 if self.col_colors is not None else 2
    ncols = 3 if self.row_colors is not None else 2

    self.gs = gridspec.GridSpec(nrows, ncols,
                                width_ratios=width_ratios,
                                height_ratios=height_ratios)

    fig = self._figure
    self.ax_row_dendrogram = fig.add_subplot(self.gs[-1, 0])
    self.ax_col_dendrogram = fig.add_subplot(self.gs[0, -1])
    self.ax_row_dendrogram.set_axis_off()
    self.ax_col_dendrogram.set_axis_off()

    self.ax_row_colors = None
    self.ax_col_colors = None

    if self.row_colors is not None:
        self.ax_row_colors = fig.add_subplot(self.gs[-1, 1])
    if self.col_colors is not None:
        self.ax_col_colors = fig.add_subplot(self.gs[1, -1])

    self.ax_heatmap = fig.add_subplot(self.gs[-1, -1])
    if cbar_pos is None:
        self.ax_cbar = self.cax = None
    else:
        # Initialize the colorbar axes in the gridspec so that tight_layout
        # works. We will move it where it belongs later. This is a hack.
        self.ax_cbar = fig.add_subplot(self.gs[0, 0])
        self.cax = self.ax_cbar  # Backwards compatibility
    self.cbar_pos = cbar_pos

    self.dendrogram_row = None
    self.dendrogram_col = None
- def _preprocess_colors(self, data, colors, axis):
- """Preprocess {row/col}_colors to extract labels and convert colors."""
- labels = None
- if colors is not None:
- if isinstance(colors, (pd.DataFrame, pd.Series)):
- # If data is unindexed, raise
- if (not hasattr(data, "index") and axis == 0) or (
- not hasattr(data, "columns") and axis == 1
- ):
- axis_name = "col" if axis else "row"
- msg = (f"{axis_name}_colors indices can't be matched with data "
- f"indices. Provide {axis_name}_colors as a non-indexed "
- "datatype, e.g. by using `.to_numpy()``")
- raise TypeError(msg)
- # Ensure colors match data indices
- if axis == 0:
- colors = colors.reindex(data.index)
- else:
- colors = colors.reindex(data.columns)
- # Replace na's with white color
- # TODO We should set these to transparent instead
- colors = colors.astype(object).fillna('white')
- # Extract color values and labels from frame/series
- if isinstance(colors, pd.DataFrame):
- labels = list(colors.columns)
- colors = colors.T.values
- else:
- if colors.name is None:
- labels = [""]
- else:
- labels = [colors.name]
- colors = colors.values
- colors = _convert_colors(colors)
- return colors, labels
def format_data(self, data, pivot_kws, z_score=None,
                standard_scale=None):
    """Extract variables from data or use directly.

    ``data`` is pivoted with ``pivot_kws`` when those are given, then
    optionally z-scored or standard-scaled along the given axis. Raises
    ``ValueError`` when both ``z_score`` and ``standard_scale`` are set.
    """
    # Either the data is already in 2d matrix format, or need to do a pivot
    data2d = data if pivot_kws is None else data.pivot(**pivot_kws)

    want_z = z_score is not None
    want_scale = standard_scale is not None
    if want_z and want_scale:
        raise ValueError(
            'Cannot perform both z-scoring and standard-scaling on data')

    if want_z:
        data2d = self.z_score(data2d, z_score)
    if want_scale:
        data2d = self.standard_scale(data2d, standard_scale)
    return data2d
- @staticmethod
- def z_score(data2d, axis=1):
- """Standarize the mean and variance of the data axis
- Parameters
- ----------
- data2d : pandas.DataFrame
- Data to normalize
- axis : int
- Which axis to normalize across. If 0, normalize across rows, if 1,
- normalize across columns.
- Returns
- -------
- normalized : pandas.DataFrame
- Noramlized data with a mean of 0 and variance of 1 across the
- specified axis.
- """
- if axis == 1:
- z_scored = data2d
- else:
- z_scored = data2d.T
- z_scored = (z_scored - z_scored.mean()) / z_scored.std()
- if axis == 1:
- return z_scored
- else:
- return z_scored.T
- @staticmethod
- def standard_scale(data2d, axis=1):
- """Divide the data by the difference between the max and min
- Parameters
- ----------
- data2d : pandas.DataFrame
- Data to normalize
- axis : int
- Which axis to normalize across. If 0, normalize across rows, if 1,
- normalize across columns.
- Returns
- -------
- standardized : pandas.DataFrame
- Noramlized data with a mean of 0 and variance of 1 across the
- specified axis.
- """
- # Normalize these values to range from 0 to 1
- if axis == 1:
- standardized = data2d
- else:
- standardized = data2d.T
- subtract = standardized.min()
- standardized = (standardized - subtract) / (
- standardized.max() - standardized.min())
- if axis == 1:
- return standardized
- else:
- return standardized.T
def dim_ratios(self, colors, dendrogram_ratio, colors_ratio):
    """Get the proportions of the figure taken up by each axes.

    Returns a list with the dendrogram share, then (only when ``colors``
    is given) the color-strip share, then the remainder for the heatmap.
    """
    ratios = [dendrogram_ratio]

    if colors is not None:
        # Colors are encoded as rgb, so there is an extra dimension:
        # more than two dimensions means several color strips.
        n_colors = len(colors) if np.ndim(colors) > 2 else 1
        ratios.append(n_colors * colors_ratio)

    # Add the ratio for the heatmap itself
    ratios.append(1 - sum(ratios))

    return ratios
@staticmethod
def color_list_to_matrix_and_cmap(colors, ind, axis=0):
    """Turns a list of colors into a numpy matrix and matplotlib colormap

    These arguments can now be plotted using heatmap(matrix, cmap)
    and the provided colors will be plotted.

    Parameters
    ----------
    colors : list of matplotlib colors
        Colors to label the rows or columns of a dataframe. May also be a
        list of such lists, all of the same length.
    ind : list of ints
        Ordering of the rows or columns, to reorder the original colors
        by the clustered dendrogram order
    axis : int
        Which axis this is labeling

    Returns
    -------
    matrix : numpy.array
        A numpy array of integer values, where each indexes into the cmap
    cmap : matplotlib.colors.ListedColormap

    Raises
    ------
    ValueError
        If several color vectors are given and their lengths differ.

    """
    try:
        mpl.colors.to_rgb(colors[0])
    except ValueError:
        # We have a 2D color structure
        m, n = len(colors), len(colors[0])
        if any(len(vector) != n for vector in colors[1:]):
            raise ValueError("Multiple side color vectors must have same size")
    else:
        # We have one vector of colors
        m, n = 1, len(colors)
        colors = [colors]

    # Map from unique colors to colormap index value; a color seen for the
    # first time gets the next free index.
    unique_colors = {}
    codes = [
        [unique_colors.setdefault(color, len(unique_colors))
         for color in vector]
        for vector in colors
    ]
    matrix = np.array(codes, dtype=int).reshape(m, n)

    # Reorder for clustering and transpose for axis
    matrix = matrix[:, ind]
    if axis == 0:
        matrix = matrix.T

    cmap = mpl.colors.ListedColormap(list(unique_colors))
    return matrix, cmap
def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
                     row_linkage, col_linkage, tree_kws):
    """Draw the row and column dendrograms, or blank out their axes.

    For each dimension whose ``*_cluster`` flag is truthy, the result of
    ``dendrogram`` is stored on ``self.dendrogram_row`` /
    ``self.dendrogram_col``. Otherwise the ticks of the corresponding
    dendrogram axes are removed. Both dendrogram axes are despined at the
    end in either case.
    """
    # Options that are identical for both dendrograms.
    shared = dict(metric=metric, method=method, label=False,
                  tree_kws=tree_kws)

    # Row dendrogram
    if not row_cluster:
        self.ax_row_dendrogram.set_xticks([])
        self.ax_row_dendrogram.set_yticks([])
    else:
        self.dendrogram_row = dendrogram(
            self.data2d, axis=0, ax=self.ax_row_dendrogram, rotate=True,
            linkage=row_linkage, **shared
        )

    # Column dendrogram
    if not col_cluster:
        self.ax_col_dendrogram.set_xticks([])
        self.ax_col_dendrogram.set_yticks([])
    else:
        self.dendrogram_col = dendrogram(
            self.data2d, axis=1, ax=self.ax_col_dendrogram,
            linkage=col_linkage, **shared
        )

    for tree_ax in (self.ax_row_dendrogram, self.ax_col_dendrogram):
        despine(ax=tree_ax, bottom=True, left=True)
def plot_colors(self, xind, yind, **kws):
    """Plot the color labels between the dendrograms and the heatmap.

    Parameters
    ----------
    xind : sequence of ints
        Ordering applied to the column colors.
    yind : sequence of ints
        Ordering applied to the row colors.
    kws : key, value mappings
        Keyword arguments forwarded to ``heatmap`` once the keys that
        only make sense for the main matrix have been removed.

    """
    # Remove any custom colormap and centering
    # TODO this code has consistently caused problems when we
    # have missed kwargs that need to be excluded that it might
    # be better to rewrite *in*clusively.
    excluded = (
        'cmap', 'norm', 'center', 'annot', 'vmin', 'vmax', 'robust',
        'xticklabels', 'yticklabels',
    )
    kws = {key: val for key, val in kws.items() if key not in excluded}

    # Row colors
    if self.row_colors is None:
        # NOTE(review): the axes is passed positionally here, unlike the
        # ``ax=`` keyword used for the dendrograms - confirm it is intended.
        despine(self.ax_row_colors, left=True, bottom=True)
    else:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.row_colors, yind, axis=0)

        # ``False`` turns the tick labels off when none were provided
        if self.row_color_labels is None:
            row_color_labels = False
        else:
            row_color_labels = self.row_color_labels

        heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_row_colors,
                xticklabels=row_color_labels, yticklabels=False, **kws)

        # Adjust rotation of labels
        if row_color_labels is not False:
            plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)

    # Column colors
    if self.col_colors is None:
        despine(self.ax_col_colors, left=True, bottom=True)
    else:
        matrix, cmap = self.color_list_to_matrix_and_cmap(
            self.col_colors, xind, axis=1)

        if self.col_color_labels is None:
            col_color_labels = False
        else:
            col_color_labels = self.col_color_labels

        heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_col_colors,
                xticklabels=False, yticklabels=col_color_labels, **kws)

        # Adjust rotation of labels, place on right side
        if col_color_labels is not False:
            self.ax_col_colors.yaxis.tick_right()
            plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
def plot_matrix(self, colorbar_kws, xind, yind, **kws):
    """Draw the reordered heatmap and lay out the figure.

    ``self.data2d`` and ``self.mask`` are reordered in place with
    ``yind`` (rows) and ``xind`` (columns). Tick labels and annotations
    supplied through `kws` are reordered the same way where possible, and
    the remaining `kws` are forwarded to ``heatmap``.

    Raises
    ------
    ValueError
        If an ``annot`` array is given whose shape differs from the data.

    """
    self.data2d = self.data2d.iloc[yind, xind]
    self.mask = self.mask.iloc[yind, xind]

    # Try to reorganize specified tick labels, if provided; values that
    # cannot be indexed this way (e.g. "auto") are passed on unchanged.
    xtl = kws.pop("xticklabels", "auto")
    try:
        xtl = np.asarray(xtl)[xind]
    except (TypeError, IndexError):
        pass

    ytl = kws.pop("yticklabels", "auto")
    try:
        ytl = np.asarray(ytl)[yind]
    except (TypeError, IndexError):
        pass

    # Reorganize the annotations to match the heatmap
    annot = kws.pop("annot", None)
    if annot is not None and annot is not False:
        if isinstance(annot, bool):
            # The data were already reordered above
            annot = self.data2d
        else:
            annot = np.asarray(annot)
            if annot.shape != self.data2d.shape:
                raise ValueError("`data` and `annot` must have same shape.")
            annot = annot[yind][:, xind]

    # Setting ax_cbar=None in clustermap call implies no colorbar
    has_cbar = self.ax_cbar is not None
    kws.setdefault("cbar", has_cbar)

    heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.ax_cbar,
            cbar_kws=colorbar_kws, mask=self.mask,
            xticklabels=xtl, yticklabels=ytl, annot=annot, **kws)

    # Record the label rotation before moving the y ticks to the right,
    # then apply it again to the labels found afterwards.
    labels_before = self.ax_heatmap.get_yticklabels()
    saved_rotation = (
        labels_before[0].get_rotation() if labels_before else None
    )
    self.ax_heatmap.yaxis.set_ticks_position('right')
    self.ax_heatmap.yaxis.set_label_position('right')
    if saved_rotation is not None:
        plt.setp(self.ax_heatmap.get_yticklabels(), rotation=saved_rotation)

    # Turn the colorbar axes off for tight layout so that its
    # ticks don't interfere with the rest of the plot layout.
    # Then move it.
    if has_cbar:
        self.ax_cbar.set_axis_off()
    self._figure.tight_layout(h_pad=.02, w_pad=.02)
    if has_cbar:
        self.ax_cbar.set_axis_on()
        self.ax_cbar.set_position(self.cbar_pos)
def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster,
         row_linkage, col_linkage, tree_kws, **kws):
    """Draw the dendrograms, the color labels and the heatmap.

    Remaining keyword arguments in `kws` are handed to both
    ``plot_colors`` and ``plot_matrix``. Returns ``self``.
    """
    # heatmap square=True sets the aspect ratio on the axes, but that is
    # not compatible with the multi-axes layout of clustergrid
    if kws.get("square", False):
        warnings.warn("``square=True`` ignored in clustermap")
        del kws["square"]

    if colorbar_kws is None:
        colorbar_kws = {}

    self.plot_dendrograms(
        row_cluster, col_cluster, metric, method,
        row_linkage=row_linkage, col_linkage=col_linkage,
        tree_kws=tree_kws,
    )

    # Fall back to the existing order for any dimension that has no
    # dendrogram ordering available.
    try:
        col_order = self.dendrogram_col.reordered_ind
    except AttributeError:
        col_order = np.arange(self.data2d.shape[1])

    try:
        row_order = self.dendrogram_row.reordered_ind
    except AttributeError:
        row_order = np.arange(self.data2d.shape[0])

    self.plot_colors(col_order, row_order, **kws)
    self.plot_matrix(colorbar_kws, col_order, row_order, **kws)
    return self
def clustermap(
    data, *,
    pivot_kws=None, method='average', metric='euclidean',
    z_score=None, standard_scale=None, figsize=(10, 10),
    cbar_kws=None, row_cluster=True, col_cluster=True,
    row_linkage=None, col_linkage=None,
    row_colors=None, col_colors=None, mask=None,
    dendrogram_ratio=.2, colors_ratio=0.03,
    cbar_pos=(.02, .8, .05, .18), tree_kws=None,
    **kwargs
):
    """
    Plot a matrix dataset as a hierarchically-clustered heatmap.

    This function requires scipy to be available.

    Parameters
    ----------
    data : 2D array-like
        Rectangular data for clustering. Cannot contain NAs.
    pivot_kws : dict, optional
        If `data` is a tidy dataframe, can provide keyword arguments for
        pivot to create a rectangular dataframe.
    method : str, optional
        Linkage method to use for calculating clusters. See
        :func:`scipy.cluster.hierarchy.linkage` documentation for more
        information.
    metric : str, optional
        Distance metric to use for the data. See
        :func:`scipy.spatial.distance.pdist` documentation for more options.
        To use different metrics (or methods) for rows and columns, you may
        construct each linkage matrix yourself and provide them as
        `{row,col}_linkage`.
    z_score : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
        for the rows or the columns. Z scores are: z = (x - mean)/std, so
        values in each row (column) will get the mean of the row (column)
        subtracted, then divided by the standard deviation of the row
        (column). This ensures that each row (column) has mean of 0 and
        variance of 1.
    standard_scale : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to standardize that
        dimension, meaning for each row or column, subtract the minimum and
        divide each by its maximum.
    figsize : tuple of (width, height), optional
        Overall size of the figure.
    cbar_kws : dict, optional
        Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to
        add a label to the colorbar.
    {row,col}_cluster : bool, optional
        If ``True``, cluster the {rows, columns}.
    {row,col}_linkage : :class:`numpy.ndarray`, optional
        Precomputed linkage matrix for the rows or columns. See
        :func:`scipy.cluster.hierarchy.linkage` for specific formats.
    {row,col}_colors : list-like or pandas DataFrame/Series, optional
        List of colors to label for either the rows or columns. Useful to
        evaluate whether samples within a group are clustered together. Can
        use nested lists or DataFrame for multiple color levels of labeling.
        If given as a :class:`pandas.DataFrame` or :class:`pandas.Series`,
        labels for the colors are extracted from the DataFrames column names
        or from the name of the Series. DataFrame/Series colors are also
        matched to the data by their index, ensuring colors are drawn in the
        correct order.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where `mask` is True.
        Cells with missing values are automatically masked. Only used for
        visualizing, not for calculating.
    {dendrogram,colors}_ratio : float, or pair of floats, optional
        Proportion of the figure size devoted to the two marginal elements.
        If a pair is given, they correspond to (row, col) ratios.
    cbar_pos : tuple of (left, bottom, width, height), optional
        Position of the colorbar axes in the figure. Setting to ``None`` will
        disable the colorbar.
    tree_kws : dict, optional
        Parameters for the :class:`matplotlib.collections.LineCollection`
        that is used to plot the lines of the dendrogram tree.
    kwargs : other keyword arguments
        All other keyword arguments are passed to :func:`heatmap`.

    Returns
    -------
    :class:`ClusterGrid`
        A :class:`ClusterGrid` instance.

    See Also
    --------
    heatmap : Plot rectangular data as a color-encoded matrix.

    Notes
    -----
    The returned object has a ``savefig`` method that should be used if you
    want to save the figure object without clipping the dendrograms.

    To access the reordered row indices, use:
    ``clustergrid.dendrogram_row.reordered_ind``

    Column indices, use:
    ``clustergrid.dendrogram_col.reordered_ind``

    Examples
    --------

    .. include:: ../docstrings/clustermap.rst

    """
    if _no_scipy:
        raise RuntimeError("clustermap requires scipy to be available")

    # Everything that shapes the grid itself goes to the constructor ...
    grid = ClusterGrid(
        data,
        pivot_kws=pivot_kws,
        z_score=z_score,
        standard_scale=standard_scale,
        figsize=figsize,
        row_colors=row_colors,
        col_colors=col_colors,
        mask=mask,
        dendrogram_ratio=dendrogram_ratio,
        colors_ratio=colors_ratio,
        cbar_pos=cbar_pos,
    )

    # ... while clustering and drawing options go to the plot call.
    return grid.plot(
        metric=metric,
        method=method,
        colorbar_kws=cbar_kws,
        row_cluster=row_cluster,
        col_cluster=col_cluster,
        row_linkage=row_linkage,
        col_linkage=col_linkage,
        tree_kws=tree_kws,
        **kwargs
    )
|