Source code for plot_corner

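"""
Production of corner plots.

Modified from a fork of corner.py (https://github.com/dfm/corner.py).
Original code: Copyright (c) 2013-2020 Daniel Foreman-Mackey
Full license: see the LICENSE file distributed with corner.py.

This modified version:
 - adds the observed quantities to the corner plots,
 - adds colours to the plots,
 - adds a KDE to the 1-D (non-contour) panels,
 - has been cleaned up.
"""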
import logging
import colorsys

import numpy as np
import matplotlib
from scipy.stats import gaussian_kde
from scipy.ndimage import gaussian_filter
from matplotlib.ticker import MaxNLocator, NullLocator
from matplotlib.colors import LinearSegmentedColormap, colorConverter
from matplotlib.ticker import ScalarFormatter
import matplotlib.colors as mc

import matplotlib.pyplot as plt

# Default font keyword arguments for axis labels and panel titles; used as the
# default value of ``label_kwargs`` and ``title_kwargs`` in ``corner``.
# It is shared between calls as a default argument, so treat it as read-only.
fontdic = {"size": 12}

# Public API of this module (``lighten_color`` is an internal helper).
__all__ = ["corner", "hist2d"]

def _auto_bin_edges(x, lims, binrule_fallback, name):
    """Return histogram bin edges for ``x`` on the interval ``lims``.

    NumPy's "auto" rule is tried first. If it raises ``MemoryError`` or would
    produce more than 1000 edges, ``binrule_fallback`` is used instead.
    ``name`` is only used in the printed warnings.
    """
    try:
        edges = np.histogram_bin_edges(x, bins="auto", range=lims)
    except MemoryError:
        edges = None
    else:
        if len(edges) > 1000:
            print(
                "Parameter {0} resulted in {1} bins, raising MemoryError".format(
                    name, len(edges)
                )
            )
            edges = None

    if edges is None:
        print(
            "WARNING! Using 'auto' as bin-rule causes a memory crash! "
            "Switching to '{0}'".format(binrule_fallback),
            "for the parameter '{0}'!".format(name),
        )
        edges = np.histogram_bin_edges(x, bins=binrule_fallback, range=lims)
    return edges


def corner(
    xs,
    smooth=None,
    smooth1d="kde",
    labels=None,
    label_kwargs=fontdic,
    show_titles=False,
    title_fmt=".3f",
    title_kwargs=fontdic,
    truth_color="#4682b4",
    scale_hist=False,
    quantiles=None,
    max_n_ticks=5,
    use_math_text=False,
    reverse=False,
    plotin=None,
    plotout=None,
    autobins=True,
    binrule_fallback="scott",
    uncert="quantiles",
    kde_points=250,
    kde_method="silverman",
    nameinplot=False,
    **hist2d_kwargs,
):
    """
    Make a corner plot showing the projections of a data set in a
    multi-dimensional space. kwargs are passed to hist2d() or used for
    `matplotlib` styling.

    Parameters
    ----------
    xs : array_like[nsamples, ndim]
        The samples. This should be a 1- or 2-dimensional array. For a 1-D
        array this results in a simple histogram. For a 2-D array, the zeroth
        axis is the list of samples and the next axis are the dimensions of
        the space.
    smooth : float, optional
        The standard deviation for the Gaussian kernel passed to
        `scipy.ndimage.gaussian_filter` to smooth the 2-D histograms. If
        `None` (default), no smoothing is applied.
    smooth1d : str, float or None, optional
        "kde" (default): draw a Gaussian KDE (normalised to a peak of 1) on
        top of a lightly smoothed, shaded histogram. A float: smooth the 1-D
        histogram with a Gaussian filter of that standard deviation. `None`:
        draw a plain step histogram.
    labels : None or iterable (ndim,), optional
        A list of names for the dimensions.
    label_kwargs : dict, optional
        Extra keyword arguments for `set_xlabel` / `set_ylabel`
        (default: the module-level ``fontdic``).
    show_titles : bool, optional
        Display a title above each 1-D histogram showing the value from
        ``plotout`` with its upper and lower errors.
    title_fmt : string, optional
        The format string for the values given in titles. If you explicitly
        set ``show_titles=True`` and ``title_fmt=None``, the labels will be
        shown as the titles. (default: ``.3f``)
    title_kwargs : dict, optional
        Extra keyword arguments for `set_title` (default: ``fontdic``).
    truth_color : str or sequence, optional
        A ``matplotlib`` colour used for all panels, or a sequence (or a dict
        keyed by integer dimension index) with one colour per dimension.
        Per-dimension colours are passed through ``lighten_color(c, 0.5)``.
    scale_hist : bool, optional
        Should the 1-D histograms be scaled in such a way that the zero line
        is visible?
    quantiles : iterable, optional
        Accepted for compatibility with the upstream ``corner`` API; it is
        currently not used (the reference lines come from ``plotout``).
    max_n_ticks : int, optional
        Maximum number of ticks to try to use.
    use_math_text : bool, optional
        If true, then axis tick labels for very large or small exponents will
        be displayed as powers of 10 rather than using `e`.
    reverse : bool, optional
        If true, plot the corner plot starting in the upper-right corner
        instead of the usual bottom-left corner.
    plotin : sequence (2 * ndim,), optional
        Flat list of ``(value, std)`` pairs of reference input values, one
        pair per dimension. A value of -9999 skips that dimension.
    plotout : sequence (3 * ndim,), optional
        Flat list of ``(value, plus, minus)`` triplets, one per dimension,
        drawn as a solid line at ``value`` and dashed lines at
        ``value + plus`` and ``value - minus``. If `None`, no lines are drawn
        and titles (if requested) show only the labels.
    autobins : bool or int or array_like[ndim,], optional
        If True, automatically determine bin edges. Otherwise, the number of
        bins to use in histograms, either as a fixed value for all
        dimensions, or as a list of integers for each dimension.
    binrule_fallback : str, optional
        Binning rule used when the 'auto' rule fails (raises MemoryError or
        yields more than 1000 bin edges, typically due to many zeros).
    uncert : str, optional
        'quantiles' shows asymmetric errors in titles; any other value shows
        a symmetric ``\\pm`` error using the ``plus`` entry of ``plotout``.
    kde_points : int, optional
        Number of points at which the KDE is evaluated. More points give a
        smoother curve at a higher computational cost.
    kde_method : str, optional
        Bandwidth selection method passed as ``bw_method`` to
        `scipy.stats.gaussian_kde`. Default is Silverman's rule.
    nameinplot : str or bool, optional
        If truthy, used as the title of the top-right (empty) panel, e.g. a
        star identifier.
    **hist2d_kwargs, optional
        Any remaining keyword arguments are sent to `hist2d` to generate the
        2-D histogram plots.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the ``ndim x ndim`` grid of panels.

    Raises
    ------
    ValueError
        If ``autobins`` is a sequence whose length differs from ``ndim``.
    numpy.linalg.LinAlgError
        Re-raised if the KDE cannot be constructed for a parameter.
    """
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {}

    # Deal with 1D sample lists.
    xs = np.atleast_1d(xs)
    if len(xs.shape) == 1:
        xs = np.atleast_2d(xs)
    else:
        assert len(xs.shape) == 2, "The input sample array must be 1- or 2-D."
        xs = xs.T
    assert xs.shape[0] <= xs.shape[1], (
        "I don't believe that you want more " "dimensions than samples!"
    )
    K = len(xs)
    # Names used only in warnings; fall back to the index when unlabelled.
    names = list(labels) if labels is not None else list(range(K))

    # Parse the parameter ranges.
    # --> Set dummy ranges [v-1, v+1] for parameters that never change.
    mins = np.array([x.min() for x in xs])
    maxs = np.array([x.max() for x in xs])
    constant = mins == maxs
    mins[constant] -= 1
    maxs[constant] += 1
    prange = np.transpose((mins, maxs)).tolist()

    # Parse the bin specifications.
    if isinstance(autobins, bool) and autobins:
        bins = [
            _auto_bin_edges(x, np.sort(prange[i]), binrule_fallback, names[i])
            for i, x in enumerate(xs)
        ]
    else:
        try:
            bins = [int(autobins) for _ in prange]
        except TypeError:
            if len(autobins) != len(prange):
                raise ValueError("Dimension mismatch between bins and range")
            bins = [int(b) for b in autobins]

    # Some magic numbers for pretty axis layout.
    factor = 2.0  # size of one side of one panel
    if reverse:
        lbdim = 0.2 * factor  # size of left/bottom margin
        trdim = 0.5 * factor  # size of top/right margin
    else:
        lbdim = 0.5 * factor  # size of left/bottom margin
        trdim = 0.2 * factor  # size of top/right margin
    whspace = 0.05  # w/hspace size
    plotdim = factor * K + factor * (K - 1.0) * whspace
    dim = lbdim + plotdim + trdim

    # Create a new figure and format it.
    fig, axes = plt.subplots(K, K, figsize=(dim, dim))
    lb = lbdim / dim
    tr = (lbdim + plotdim) / dim
    fig.subplots_adjust(
        left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
    )

    def _panel(row, col):
        # plt.subplots(1, 1) returns a single Axes rather than an array.
        if K == 1:
            return axes
        if reverse:
            return axes[K - row - 1, K - col - 1]
        return axes[row, col]

    def _panel_color(idx):
        if isinstance(truth_color, str):
            return truth_color
        return lighten_color(truth_color[idx], 0.5)

    # Set up the default histogram keywords.
    color = "k"
    hist_kwargs = {"color": color}
    if smooth1d is None:
        hist_kwargs["histtype"] = hist_kwargs.get("histtype", "step")
    use_kde = smooth1d is not None and smooth1d == "kde"

    for i, x in enumerate(xs):
        # Deal with masked arrays.
        if hasattr(x, "compressed"):
            x = x.compressed()
        ax = _panel(i, i)
        tcolor = _panel_color(i)
        lims = np.sort(prange[i])

        # Plot the histograms.
        if smooth1d is None:
            n, _, _ = ax.hist(x, bins=bins[i], range=lims, **hist_kwargs)
        else:
            n, b = np.histogram(x, bins=bins[i], range=lims)
            # Smooth in floating point: gaussian_filter keeps the input dtype,
            # so integer counts would otherwise be truncated.
            n = n.astype(float)
            if not use_kde:
                n = gaussian_filter(n, smooth1d)
                x0 = np.array(list(zip(b[:-1], b[1:]))).flatten()
                y0 = np.array(list(zip(n, n))).flatten()
            else:
                try:
                    kernel = gaussian_kde(x, bw_method=kde_method)
                except np.linalg.LinAlgError:
                    print("WARNING! Unable to create KDE. Skipping plot...")
                    raise
                x0 = np.linspace(np.amin(x), np.amax(x), num=kde_points)
                y0 = kernel(x0)
                y0 /= np.amax(y0)
                # Shaded, lightly smoothed histogram behind the KDE curve.
                n = gaussian_filter(n, 1)
                x0_hist = np.array(list(zip(b[:-1], b[1:]))).flatten()
                y0_hist = np.array(list(zip(n, n))).flatten() / np.amax(n)
                ax.fill_between(
                    x0_hist, y0_hist, y2=-1, interpolate=True, color=tcolor, alpha=0.15
                )
            ax.plot(x0, y0, **hist_kwargs)
            ax.fill_between(x0, y0, y2=-1, interpolate=True, color=tcolor, alpha=0.15)

        # Plot the reported value and its errors.
        if plotout is not None:
            q = plotout[3 * i]
            p = plotout[3 * i + 1]
            m = plotout[3 * i + 2]
            ax.axvline(q, ls="solid", color=color)
            ax.axvline(q + p, ls="dashed", color=color)
            ax.axvline(q - m, ls="dashed", color=color)

        # Plot input parameters when they are given (-9999 marks "absent").
        if plotin is not None and plotin[2 * i] != -9999:
            inx = plotin[2 * i]
            instd = plotin[2 * i + 1]
            ax.axvline(inx, ls="dashdot", color="0.4")
            ax.axvline(inx - instd, ls="dotted", color="0.4")
            ax.axvline(inx + instd, ls="dotted", color="0.4")

        if show_titles:
            title = None
            if title_fmt is not None and plotout is not None:
                # Format the value display.
                fmt = "{{0:{0}}}".format(title_fmt).format
                if uncert == "quantiles":
                    title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
                    title = title.format(fmt(q), fmt(m), fmt(p))
                else:
                    title = r"${{{0}}}\pm{{{1}}}$"
                    title = title.format(fmt(q), fmt(p))
                # Add in the column name if it's given.
                if labels is not None:
                    title = "{0} = {1}".format(labels[i], title)
            elif labels is not None:
                title = "{0}".format(labels[i])

            if title is not None:
                if reverse:
                    ax.set_xlabel(title, **title_kwargs)
                else:
                    ax.set_title(title, **title_kwargs)

        # Set up the axes. The KDE curve is normalised to a peak of 1, so its
        # limits must come from the curve rather than from the raw counts.
        ax.set_xlim(prange[i])
        maxn = np.amax(y0) if use_kde else np.max(n)
        if scale_hist or use_kde:
            ax.set_ylim(-0.1 * maxn, 1.05 * maxn)
        else:
            ax.set_ylim(0, 1.05 * maxn)
        ax.set_yticklabels([])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
            ax.yaxis.set_major_locator(NullLocator())

        if i < K - 1:
            ax.set_xticklabels([])
        else:
            if reverse:
                ax.xaxis.tick_top()
            for tick in ax.get_xticklabels():
                tick.set_rotation(45)
            if labels is not None:
                if reverse:
                    ax.set_title(labels[i], y=1.25, **label_kwargs)
                else:
                    ax.set_xlabel(labels[i], **label_kwargs)
                    ax.xaxis.set_label_coords(0.5, -0.35)
            # use MathText for axes ticks
            ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))

        for j, y in enumerate(xs):
            ax = _panel(i, j)
            if j > i:
                # Panels above the diagonal are left empty.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                if j == K - 1 and i == 0:
                    ax.set_title(nameinplot if nameinplot else "")
                continue
            if j == i:
                continue

            # Deal with masked arrays.
            if hasattr(y, "compressed"):
                y = y.compressed()

            hist2d(
                y,
                x,
                ax=ax,
                prange=[prange[j], prange[i]],
                color=_panel_color(j),
                smooth=smooth,
                bins=[bins[j], bins[i]],
                **hist2d_kwargs,
            )

            if max_n_ticks == 0:
                ax.xaxis.set_major_locator(NullLocator())
                ax.yaxis.set_major_locator(NullLocator())
            else:
                ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
                ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))

            if i < K - 1:
                ax.set_xticklabels([])
            else:
                if reverse:
                    ax.xaxis.tick_top()
                for tick in ax.get_xticklabels():
                    tick.set_rotation(45)
                if labels is not None:
                    ax.set_xlabel(labels[j], **label_kwargs)
                    if reverse:
                        ax.xaxis.set_label_coords(0.5, 1.4)
                    else:
                        ax.xaxis.set_label_coords(0.5, -0.35)
                # use MathText for axes ticks
                ax.xaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))

            if j > 0:
                ax.set_yticklabels([])
            else:
                if reverse:
                    ax.yaxis.tick_right()
                for tick in ax.get_yticklabels():
                    tick.set_rotation(45)
                if labels is not None:
                    if reverse:
                        ax.set_ylabel(labels[i], rotation=-90, **label_kwargs)
                        ax.yaxis.set_label_coords(1.3, 0.5)
                    else:
                        ax.set_ylabel(labels[i], **label_kwargs)
                        ax.yaxis.set_label_coords(-0.35, 0.5)
                # use MathText for axes ticks
                ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=use_math_text))

    return fig
def hist2d(
    x,
    y,
    bins=20,
    prange=None,
    weights=None,
    levels=None,
    smooth=None,
    ax=None,
    color=None,
    plot_datapoints=True,
    plot_density=True,
    plot_contours=True,
    no_fill_contours=True,
    fill_contours=True,
    contour_kwargs=None,
    contourf_kwargs=None,
    data_kwargs=None,
    **kwargs,
):
    """
    Plot a 2-D histogram of samples.

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or sequence, optional
        Passed to `numpy.histogram2d` as ``bins``.
    prange : [[xmin, xmax], [ymin, ymax]], optional
        Histogram range and axis limits. Defaults to the data extent. The
        upstream-style keyword ``range=`` is accepted as an alias.
    weights : array_like[nsamples,], optional
        Per-sample weights passed to `numpy.histogram2d`.
    levels : array_like
        The contour levels to draw (enclosed probability mass). Defaults to
        0.5, 1, 1.5 and 2 "sigma".
    smooth : float, optional
        Standard deviation of the Gaussian filter applied to the histogram.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram.
    color : matplotlib colour, optional
        Base colour of points, density map and contours (default "k").
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap.
    plot_contours : bool
        Draw the contours.
    no_fill_contours : bool
        Add no filling at all to the contours (unlike setting
        ``fill_contours=False``, which still adds a white fill at the densest
        points).
    fill_contours : bool
        Fill the contours.
    contour_kwargs : dict
        Any additional keyword arguments to pass to the `contour` method.
    contourf_kwargs : dict
        Any additional keyword arguments to pass to the `contourf` method.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points.

    The ``*_kwargs`` dicts are copied before defaults are filled in, so a
    dict shared between calls (e.g. by ``corner``) is never modified.
    Other unknown keyword arguments are ignored.
    """
    if ax is None:
        ax = plt.gca()

    # Accept the upstream ``range=`` keyword as an alias for ``prange``.
    legacy_range = kwargs.pop("range", None)
    if prange is None:
        prange = legacy_range
    # Set the default range based on the data range if not provided.
    if prange is None:
        prange = [[x.min(), x.max()], [y.min(), y.max()]]

    # Set up the default plotting arguments.
    if color is None:
        color = "k"

    # Choose the default "sigma" contour levels.
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # This is the color map for the density plot, over-plotted to indicate the
    # density of the points near the center.
    density_cmap = LinearSegmentedColormap.from_list(
        "density_cmap", [color, (1, 1, 1, 0)]
    )

    # This color map is used to hide the points at the high density areas.
    white_cmap = LinearSegmentedColormap.from_list(
        "white_cmap", [(1, 1, 1), (1, 1, 1)], N=2
    )

    # This "color map" is the list of colors for the contour levels if the
    # contours are filled: increasing alpha towards the densest level.
    rgba_color = colorConverter.to_rgba(color)
    contour_cmap = [list(rgba_color) for _ in levels] + [rgba_color]
    for i in range(len(levels)):
        contour_cmap[i][-1] *= float(i) / (len(levels) + 1)

    # We'll make the 2D histogram to directly estimate the density.
    H, X, Y = np.histogram2d(
        x.flatten(),
        y.flatten(),
        bins=bins,
        range=list(map(np.sort, prange)),
        weights=weights,
    )

    if smooth is not None:
        H = gaussian_filter(H, smooth)

    # Contours only make sense when both coordinates actually vary.
    has_spread = not (np.all(x == x[0]) or np.all(y == y[0]))

    if has_spread:
        # Compute the density levels enclosing the requested probability mass.
        if plot_contours or plot_density:
            Hflat = H.flatten()
            inds = np.argsort(Hflat)[::-1]
            Hflat = Hflat[inds]
            sm = np.cumsum(Hflat)
            sm /= sm[-1]
            V = np.empty(len(levels))
            for i, v0 in enumerate(levels):
                try:
                    V[i] = Hflat[sm <= v0][-1]
                except IndexError:
                    # No bin lies below this mass fraction: use the peak.
                    V[i] = Hflat[0]
            V.sort()
            m = np.diff(V) == 0
            if np.any(m):
                logging.warning("Too few points to create valid contours")
            # Contour levels must be strictly increasing; nudge duplicates.
            while np.any(m):
                V[np.where(m)[0][0]] *= 1.0 - 1e-4
                m = np.diff(V) == 0
            V.sort()

        # Compute the bin centers.
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # Extend the array for the sake of the contours at the plot edges.
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate(
            [
                X1[0] + np.array([-2, -1]) * np.diff(X1[:2]),
                X1,
                X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]),
            ]
        )
        Y2 = np.concatenate(
            [
                Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]),
                Y1,
                Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]),
            ]
        )

    if plot_datapoints:
        # Copy so that a caller's (possibly shared) dict is not modified.
        data_kwargs = dict(data_kwargs) if data_kwargs is not None else {}
        data_kwargs.setdefault("color", color)
        data_kwargs.setdefault("ms", 2.0)
        data_kwargs.setdefault("mec", "none")
        data_kwargs.setdefault("alpha", 0.1)
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)

    if has_spread:
        # Plot the base fill to hide the densest data points.
        if (plot_contours or plot_density) and not no_fill_contours:
            ax.contourf(
                X2, Y2, H2.T, [V.min(), H.max()], cmap=white_cmap, antialiased=False
            )

        if plot_contours and fill_contours:
            contourf_kwargs = (
                dict(contourf_kwargs) if contourf_kwargs is not None else {}
            )
            contourf_kwargs.setdefault("colors", contour_cmap)
            contourf_kwargs.setdefault("antialiased", False)
            ax.contourf(
                X2,
                Y2,
                H2.T,
                np.concatenate([[0], V, [H.max() * (1 + 1e-4)]]),
                **contourf_kwargs,
            )
        # Plot the density map. This can't be plotted at the same time as the
        # contour fills.
        elif plot_density:
            ax.pcolor(X, Y, H.max() - H.T, cmap=density_cmap)

        # Plot the contour edge colors.
        if plot_contours:
            contour_kwargs = dict(contour_kwargs) if contour_kwargs is not None else {}
            contour_kwargs.setdefault("colors", color)
            ax.contour(X2, Y2, H2.T, V, **contour_kwargs)

    ax.set_xlim(prange[0])
    ax.set_ylim(prange[1])
def lighten_color(color, amount=0.5):
    """
    Scale the HLS lightness of a colour by ``amount``.

    The lightness is multiplied by ``amount`` and clipped to [0, 1]. Despite
    the function name, ``amount < 1`` therefore darkens the colour and
    ``amount > 1`` lightens it; the hue and saturation are unchanged.

    Input can be a matplotlib colour name, a hex string, or an RGB tuple.
    Returns an ``(r, g, b)`` tuple of floats.

    Examples:
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    try:
        c = mc.cnames[color]
    except (KeyError, TypeError):
        # Not a named colour (KeyError) or unhashable input such as a list
        # (TypeError): let matplotlib parse the value directly.
        c = color
    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(c))
    return colorsys.hls_to_rgb(hue, max(0, min(1, amount * lightness)), saturation)