"""
Production of corner plots.
Modified from a fork of https://github.com/dfm/corner.py .
Original code: Copyright (c) 2013-2020 Daniel Foreman-Mackey
Full license: https://github.com/dfm/corner.py/blob/main/LICENSE
This modified version:
- Add the observed quantities to the corner plots
- Colours for the plots
- Add KDE to non-contour panels
- Cleaned
"""
import colorsys
import logging
import matplotlib as mpl
import matplotlib.colors as mc
import numpy as np
from matplotlib.colors import LinearSegmentedColormap, colorConverter
from matplotlib.ticker import MaxNLocator, NullLocator, ScalarFormatter
from scipy.ndimage import gaussian_filter # type: ignore[import]
from scipy.stats import gaussian_kde # type: ignore[import]
# Select the non-interactive Agg backend before pyplot is imported, so figures
# are rendered off-screen.
mpl.use("Agg")
import matplotlib.pyplot as plt
# Default text styling; used as the default for both ``label_kwargs`` and
# ``title_kwargs`` of ``corner``.
fontdic = {"size": 12}
# Names exported by ``from <module> import *``.
__all__ = ["corner", "hist2d"]
[docs]
def _corner_tick_formatter():
    """Build the tick formatter used on the labelled corner-plot axes.

    A new instance is returned on every call: a matplotlib formatter keeps a
    reference to the axis it was attached to, so one instance must not be
    shared between several axes.
    """
    formatter = ScalarFormatter()
    formatter.set_scientific(True)
    formatter.set_useMathText(True)
    formatter.set_powerlimits((-2, 4))
    return formatter


def _corner_rotate(ticklabels, angle=45):
    """Rotate every tick label in ``ticklabels`` by ``angle`` degrees."""
    for ticklabel in ticklabels:
        ticklabel.set_rotation(angle)


def _corner_dimension_name(labels, index):
    """Return the label of dimension ``index`` or a placeholder if unlabelled."""
    if labels is None:
        return f"#{index}"
    return labels[index]


def _corner_step_outline(edges, counts):
    """Return the x/y vertices that trace a histogram as a step function."""
    xstep = np.column_stack((edges[:-1], edges[1:])).ravel()
    ystep = np.repeat(counts, 2)
    return xstep, ystep


def _corner_sample_summary(samples, uncert):
    """Summarise ``samples`` as (central value, upper error, lower error)."""
    if uncert == "quantiles":
        low, mid, high = np.percentile(samples, [16, 50, 84])
        return mid, high - mid, mid - low
    spread = np.std(samples)
    return np.mean(samples), spread, spread


def corner(
    xs,
    smooth=None,
    smooth1d="kde",
    labels=None,
    label_kwargs=fontdic,
    show_titles=False,
    title_fmt=".3f",
    title_kwargs=fontdic,
    truth_color="#4682b4",
    scale_hist=False,
    quantiles=None,
    max_n_ticks=5,
    use_math_text=False,
    reverse=False,
    plotin=None,
    plotout=None,
    autobins=True,
    binrule_fallback="scott",
    uncert="quantiles",
    kde_points=250,
    kde_method="silverman",
    nameinplot=False,
    **hist2d_kwargs,
):
    """
    Make a corner plot showing the projections of a data set in a
    multi-dimensional space. kwargs are passed to hist2d().

    Parameters
    ----------
    xs : array_like[nsamples, ndim]
        The samples. This should be a 1- or 2-dimensional array. For a 1-D
        array this results in a simple histogram. For a 2-D array, the zeroth
        axis is the list of samples and the next axis are the dimensions of
        the space.
    smooth : float, optional
        Forwarded to `hist2d` as the standard deviation of the Gaussian kernel
        used to smooth the 2-D histograms. If `None` (default), no smoothing
        is applied.
    smooth1d : str, float or None, optional
        If "kde" (default), a Kernel Density Estimate scaled to a peak of one
        is drawn on top of the (peak-normalised) filled histogram. If `None`,
        a plain step histogram is drawn. Any other value is used as the
        standard deviation passed to `scipy.ndimage.gaussian_filter` to
        smooth the 1-D histogram counts.
    labels : None or iterable (ndim,), optional
        A list of names for the dimensions.
    label_kwargs : dict, optional
        Any extra keyword arguments to send to the `set_xlabel` and
        `set_ylabel` methods (and to `set_title` for the diagonal label when
        ``reverse`` is set).
    show_titles : bool, optional
        Displays a title above each 1-D histogram showing the central value
        with its upper and lower errors. The numbers are taken from
        ``plotout`` when it is given; otherwise they are computed from the
        samples (median with 16th/84th percentiles for
        ``uncert="quantiles"``, mean and standard deviation otherwise).
    title_fmt : string, optional
        The format string for the numbers given in titles. If you explicitly
        set ``show_titles=True`` and ``title_fmt=None``, the labels will be
        shown as the titles. (default: ``.3f``)
    title_kwargs : dict, optional
        Any extra keyword arguments to send to the `set_title` command.
    truth_color : str or indexable, optional
        A ``matplotlib`` style color used for all panels, or a container that
        is indexed with the integer dimension index to obtain one color per
        dimension. Colors taken from a container are passed through
        `lighten_color` with an amount of 0.5 before use.
    scale_hist : bool, optional
        Should the 1-D histograms be scaled in such a way that the zero line
        is visible?
    quantiles : iterable, optional
        Accepted for compatibility with the upstream API; this implementation
        does not use it.
    max_n_ticks : int, optional
        Maximum number of ticks to try to use. Zero removes all ticks.
    use_math_text : bool, optional
        Accepted for compatibility with the upstream API; this implementation
        does not consult it and always formats tick labels with MathText.
    reverse : bool, optional
        If true, plot the corner plot starting in the upper-right corner
        instead of the usual bottom-left corner.
    plotin : sequence (2 * ndim,), optional
        Reference input values to indicate on the 1-D panels, stored flat as
        ``(value, std)`` for each dimension. A value of ``-9999`` marks a
        dimension without input and is skipped.
    plotout : sequence (3 * ndim,), optional
        Reference output values to indicate on the 1-D panels, stored flat as
        ``(value, upper error, lower error)`` for each dimension.
    autobins : bool or int or array_like[ndim,], optional
        If True, automatically determine bin edges. Otherwise, the number of
        bins to use in histograms, either as a fixed value for all
        dimensions, or as one bin specification for each dimension.
    binrule_fallback : str, optional
        Bin rule used instead of "auto" for a dimension where auto-binning
        produces more than 1000 bin edges or raises a ``MemoryError``.
    uncert : str, optional
        If uncertainties are given in terms of 'quantiles' or 'std' (standard
        deviation), included here to change formatting when reporting
        inferred quantities in titles.
    kde_points : int, optional
        Number of points to sample the KDE on. The higher number of points,
        the smoother the KDE, but the longer computation time.
    kde_method : str, optional
        Method used to select the bandwidth in the gaussian KDE. Passed
        directly to `scipy.stats.gaussian_kde`. Default is "silverman".
    nameinplot : str or bool, optional
        Identifier to write as the title of one of the empty corner panels.
        Nothing is written when it evaluates to false.
    **hist2d_kwargs, optional
        Any remaining keyword arguments are sent to `hist2d` to generate the
        2-D histogram plots.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The newly created figure holding the corner plot.

    Raises
    ------
    AssertionError
        If ``xs`` has more than two dimensions or more dimensions than
        samples.
    ValueError
        If a per-dimension ``autobins`` does not match the number of
        dimensions.
    """
    if title_kwargs is None:
        title_kwargs = {}
    if label_kwargs is None:
        label_kwargs = {}

    # Deal with 1D sample lists.
    xs = np.atleast_1d(xs)
    if len(xs.shape) == 1:
        xs = np.atleast_2d(xs)
    else:
        assert len(xs.shape) == 2, "The input sample array must be 1- or 2-D."
        xs = xs.T
    assert (
        xs.shape[0] <= xs.shape[1]
    ), "I don't believe that you want more dimensions than samples!"

    # Parse the parameter ranges.
    # --> Set dummy ranges [v-1, v+1] for parameters that never change..
    mins = np.array([x.min() for x in xs])
    maxs = np.array([x.max() for x in xs])
    constant = mins == maxs
    mins[constant] -= 1
    maxs[constant] += 1
    prange = np.transpose((mins, maxs)).tolist()
    if len(prange) != xs.shape[0]:
        raise ValueError("Dimension mismatch between samples and range")

    # Parse the bin specifications.
    if isinstance(autobins, bool) and autobins:
        bins = []
        for i, x in enumerate(xs):
            name = _corner_dimension_name(labels, i)
            limits = np.sort(prange[i])
            try:
                xbin = np.histogram_bin_edges(x, bins="auto", range=limits)
                if len(xbin) > 1000:
                    print(
                        f"Parameter {name} resulted in {len(xbin)} bins, raising MemoryError"
                    )
                    raise MemoryError
            except MemoryError:
                print(
                    "WARNING! Using 'auto' as bin-rule causes a memory crash! "
                    f"Switching to '{binrule_fallback}'",
                    f"for the parameter '{name}'!",
                )
                xbin = np.histogram_bin_edges(x, bins=binrule_fallback, range=limits)
            bins.append(xbin)
    else:
        try:
            bins = [int(autobins) for _ in prange]
        except TypeError:
            # Not a scalar: expect one bin specification per dimension.
            if len(autobins) != len(prange):
                raise ValueError("Dimension mismatch between bins and range")
            bins = list(autobins)

    # Some magic numbers for pretty axis layout.
    K = len(xs)
    factor = 2.0  # size of one side of one panel
    if reverse:
        lbdim = 0.2 * factor  # size of left/bottom margin
        trdim = 0.5 * factor  # size of top/right margin
    else:
        lbdim = 0.5 * factor  # size of left/bottom margin
        trdim = 0.2 * factor  # size of top/right margin
    whspace = 0.05  # w/hspace size
    plotdim = factor * K + factor * (K - 1.0) * whspace
    dim = lbdim + plotdim + trdim

    # Create a new figure
    fig, axes = plt.subplots(K, K, figsize=(dim, dim))

    # Format the figure.
    lb = lbdim / dim
    tr = (lbdim + plotdim) / dim
    fig.subplots_adjust(
        left=lb, bottom=lb, right=tr, top=tr, wspace=whspace, hspace=whspace
    )

    # Set up the default histogram keywords.
    color = "k"
    hist_kwargs = {"color": color}
    if smooth1d is None:
        hist_kwargs["histtype"] = "step"

    for i, x in enumerate(xs):
        # Deal with masked arrays.
        if hasattr(x, "compressed"):
            x = x.compressed()

        # plt.subplots(1, 1) returns a bare Axes rather than an array.
        if K == 1:
            ax = axes
        elif reverse:
            ax = axes[K - i - 1, K - i - 1]
        else:
            ax = axes[i, i]

        if isinstance(truth_color, str):
            tcolor = truth_color
        else:
            tcolor = lighten_color(truth_color[i], 0.5)

        # Plot the histograms.
        limits = np.sort(prange[i])
        if smooth1d is None:
            n, _, _ = ax.hist(x, bins=bins[i], range=limits, **hist_kwargs)
        else:
            if gaussian_filter is None:
                raise ImportError("Please install scipy for smoothing")
            n, b = np.histogram(x, bins=bins[i], range=limits)
            if smooth1d != "kde":
                n = gaussian_filter(n, smooth1d)
                x0, y0 = _corner_step_outline(b, n)
                ax.plot(x0, y0, **hist_kwargs)
                ax.fill_between(
                    x0, y0, y2=-1, interpolate=True, color=tcolor, alpha=0.15
                )
            else:
                y0 = None
                try:
                    kernel = gaussian_kde(x, bw_method=kde_method)
                    x0 = np.linspace(np.amin(x), np.amax(x), num=kde_points)
                    y0 = kernel(x0)
                    y0 /= np.amax(y0)
                    n = gaussian_filter(n, 1)
                    ax.plot(x0, y0, **hist_kwargs)
                    ax.fill_between(
                        x0, y0, y2=-1, interpolate=True, color=tcolor, alpha=0.15
                    )
                except np.linalg.LinAlgError:
                    print("WARNING! Unable to create a KDE...")
                    y0 = None
                # Peak-normalised histogram, drawn with or without the KDE.
                x0_hist, y0_hist = _corner_step_outline(b, n)
                y0_hist = y0_hist / np.amax(n)
                ax.fill_between(
                    x0_hist, y0_hist, y2=-1, interpolate=True, color=tcolor, alpha=0.15
                )
                if y0 is None:
                    # No KDE curve: let the histogram define the y-limits below.
                    y0 = y0_hist

        # Plot quantiles
        if plotout is not None:
            center = plotout[3 * i]
            plus = plotout[3 * i + 1]
            minus = plotout[3 * i + 2]
            ax.axvline(center, ls="solid", color=color)
            ax.axvline(center + plus, ls="dashed", color=color)
            ax.axvline(center - minus, ls="dashed", color=color)

        # Plot input parameters when they are given
        if plotin is not None:
            if plotin[2 * i] != -9999:
                inx = plotin[2 * i]
                instd = plotin[2 * i + 1]
                ax.axvline(inx, ls="dashdot", color="0.4")
                ax.axvline(inx - instd, ls="dotted", color="0.4")
                ax.axvline(inx + instd, ls="dotted", color="0.4")

        if show_titles:
            title = None
            if title_fmt is not None:
                if plotout is None:
                    # Nothing supplied by the caller: summarise the samples.
                    center, plus, minus = _corner_sample_summary(x, uncert)
                # Format the quantile display.
                fmt = f"{{0:{title_fmt}}}".format
                if uncert == "quantiles":
                    title = r"${{{0}}}_{{-{1}}}^{{+{2}}}$"
                    title = title.format(fmt(center), fmt(minus), fmt(plus))
                else:
                    title = r"${{{0}}}\pm{{{1}}}$"
                    title = title.format(fmt(center), fmt(plus))
                # Add in the column name if it's given.
                if labels is not None:
                    title = f"{labels[i]} = {title}"
            elif labels is not None:
                title = f"{labels[i]}"
            if title is not None:
                if reverse:
                    ax.set_xlabel(title, **title_kwargs)
                else:
                    ax.set_title(title, **title_kwargs)

        # Set up the axes.
        ax.set_xlim(prange[i])
        if scale_hist:
            maxn = np.max(n)
            ax.set_ylim(-0.1 * maxn, 1.05 * maxn)
        elif smooth1d == "kde":
            maxn = np.amax(y0)
            ax.set_ylim(-0.1 * maxn, 1.05 * maxn)
        else:
            ax.set_ylim(0, 1.05 * np.max(n))
        ax.set_yticklabels([])
        if max_n_ticks == 0:
            ax.xaxis.set_major_locator(NullLocator())
            ax.yaxis.set_major_locator(NullLocator())
        else:
            ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
            ax.yaxis.set_major_locator(NullLocator())

        if i < K - 1:
            ax.set_xticklabels([])
        else:
            if reverse:
                ax.xaxis.tick_top()
            _corner_rotate(ax.get_xticklabels())
            if labels is not None:
                if reverse:
                    ax.set_title(labels[i], y=1.25, **label_kwargs)
                else:
                    ax.set_xlabel(labels[i], **label_kwargs)
                    ax.xaxis.set_label_coords(0.5, -0.35)
            # use MathText for axes ticks
            ax.xaxis.set_major_formatter(_corner_tick_formatter())

        for j, y in enumerate(xs):
            if K == 1:
                ax = axes
            elif reverse:
                ax = axes[K - i - 1, K - j - 1]
            else:
                ax = axes[i, j]

            if j > i:
                # Unused half of the grid: blank the panel.
                ax.set_frame_on(False)
                ax.set_xticks([])
                ax.set_yticks([])
                if j == K - 1 and i == 0:
                    ax.set_title(nameinplot if nameinplot else "")
                continue
            if j == i:
                continue

            if isinstance(truth_color, str):
                tcolor = truth_color
            else:
                tcolor = lighten_color(truth_color[j], 0.5)

            # Deal with masked arrays.
            if hasattr(y, "compressed"):
                y = y.compressed()

            # hist2d names its limits argument ``prange``.
            hist2d(
                y,
                x,
                ax=ax,
                prange=[prange[j], prange[i]],
                color=tcolor,
                smooth=smooth,
                bins=[bins[j], bins[i]],
                **hist2d_kwargs,
            )

            if max_n_ticks == 0:
                ax.xaxis.set_major_locator(NullLocator())
                ax.yaxis.set_major_locator(NullLocator())
            else:
                ax.xaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))
                ax.yaxis.set_major_locator(MaxNLocator(max_n_ticks, prune="lower"))

            if i < K - 1:
                ax.set_xticklabels([])
            else:
                if reverse:
                    ax.xaxis.tick_top()
                _corner_rotate(ax.get_xticklabels())
                if labels is not None:
                    ax.set_xlabel(labels[j], **label_kwargs)
                    if reverse:
                        ax.xaxis.set_label_coords(0.5, 1.4)
                    else:
                        ax.xaxis.set_label_coords(0.5, -0.35)
                # use MathText for axes ticks
                ax.xaxis.set_major_formatter(_corner_tick_formatter())

            if j > 0:
                ax.set_yticklabels([])
            else:
                if reverse:
                    ax.yaxis.tick_right()
                _corner_rotate(ax.get_yticklabels())
                if labels is not None:
                    if reverse:
                        ax.set_ylabel(labels[i], rotation=-90, **label_kwargs)
                        ax.yaxis.set_label_coords(1.3, 0.5)
                    else:
                        ax.set_ylabel(labels[i], **label_kwargs)
                        ax.yaxis.set_label_coords(-0.35, 0.5)
                # use MathText for axes ticks
                ax.yaxis.set_major_formatter(_corner_tick_formatter())

    return fig
[docs]
def hist2d(
    x,
    y,
    bins=20,
    prange=None,
    weights=None,
    levels=None,
    smooth=None,
    ax=None,
    color=None,
    plot_datapoints=True,
    plot_density=True,
    plot_contours=True,
    no_fill_contours=True,
    fill_contours=True,
    contour_kwargs=None,
    contourf_kwargs=None,
    data_kwargs=None,
    **kwargs,
):
    """
    Plot a 2-D histogram of samples.

    Parameters
    ----------
    x : array_like[nsamples,]
        The samples.
    y : array_like[nsamples,]
        The samples.
    bins : int or array_like
        Bin specification forwarded to `numpy.histogram2d`.
    prange : array_like[2, 2], optional
        The ``[[xmin, xmax], [ymin, ymax]]`` limits used for the histogram
        and for the axis limits. Defaults to the extent of the data. For
        compatibility with the upstream API the same limits may be given
        with the keyword ``range`` instead.
    weights : array_like[nsamples,], optional
        Sample weights forwarded to `numpy.histogram2d`.
    levels : array_like
        The contour levels to draw. Defaults to the 0.5, 1, 1.5 and 2 "sigma"
        levels of a 2-D Gaussian.
    smooth : float, optional
        Standard deviation passed to `scipy.ndimage.gaussian_filter` to
        smooth the histogram. No smoothing is applied if `None`.
    ax : matplotlib.Axes
        A axes instance on which to add the 2-D histogram. Defaults to the
        current axes.
    color : str, optional
        A ``matplotlib`` style color for the points, density and contours.
        Defaults to black.
    plot_datapoints : bool
        Draw the individual data points.
    plot_density : bool
        Draw the density colormap. It is only drawn when the contours are
        not filled.
    plot_contours : bool
        Draw the contours.
    no_fill_contours : bool
        Add no filling at all to the contours (unlike setting
        ``fill_contours=False``, which still adds a white fill at the densest
        points).
    fill_contours : bool
        Fill the contours.
    contour_kwargs : dict
        Any additional keyword arguments to pass to the `contour` method.
    contourf_kwargs : dict
        Any additional keyword arguments to pass to the `contourf` method.
    data_kwargs : dict
        Any additional keyword arguments to pass to the `plot` method when
        adding the individual data points.
    **kwargs
        Apart from ``range`` (see ``prange``), extra keywords are ignored.

    Returns
    -------
    ax : matplotlib.Axes
        The axes that were drawn on.

    Notes
    -----
    The dictionaries given as ``contour_kwargs``, ``contourf_kwargs`` and
    ``data_kwargs`` are copied before defaults are filled in, so the caller's
    objects are left untouched and can be reused with a different ``color``.
    """
    if ax is None:
        ax = plt.gca()

    # Callers written against the upstream API pass the limits as ``range``;
    # accept that spelling instead of silently dropping it.
    if prange is None:
        prange = kwargs.get("range")
    # Set the default range based on the data range if not provided.
    if prange is None:
        prange = [[x.min(), x.max()], [y.min(), y.max()]]

    # Set up the default plotting arguments.
    if color is None:
        color = "k"

    # Choose the default "sigma" contour levels,
    # https://corner.readthedocs.io/en/latest/pages/sigmas.html
    if levels is None:
        levels = 1.0 - np.exp(-0.5 * np.arange(0.5, 2.1, 0.5) ** 2)

    # This is the color map for the density plot, over-plotted to indicate the
    # density of the points near the center.
    density_cmap = LinearSegmentedColormap.from_list(
        "density_cmap", [color, (1, 1, 1, 0)]
    )

    # This color map is used to hide the points at the high density areas.
    white_cmap = LinearSegmentedColormap.from_list(
        "white_cmap", [(1, 1, 1), (1, 1, 1)], N=2
    )

    # This "color map" is the list of colors for the contour levels if the
    # contours are filled.
    rgba_color = colorConverter.to_rgba(color)
    contour_cmap = [list(rgba_color) for _ in levels] + [rgba_color]
    for i, _l in enumerate(levels):
        contour_cmap[i][-1] *= float(i) / (len(levels) + 1)

    # We'll make the 2D histogram to directly estimate the density.
    H, X, Y = np.histogram2d(
        x.flatten(),
        y.flatten(),
        bins=bins,
        range=list(map(np.sort, prange)),
        weights=weights,
    )
    if smooth is not None:
        if gaussian_filter is None:
            raise ImportError("Please install scipy for smoothing")
        H = gaussian_filter(H, smooth)

    # A dimension that never changes has no density structure to contour.
    x_varies = not np.all(x == x[0])
    y_varies = not np.all(y == y[0])
    has_extent = x_varies and y_varies

    # Compute the density levels.
    if has_extent and (plot_contours or plot_density):
        Hflat = H.flatten()
        inds = np.argsort(Hflat)[::-1]
        Hflat = Hflat[inds]
        sm = np.cumsum(Hflat)
        sm /= sm[-1]
        V = np.empty(len(levels))
        for i, v0 in enumerate(levels):
            try:
                V[i] = Hflat[sm <= v0][-1]
            except IndexError:
                # No bin lies below this level: fall back to the densest bin.
                V[i] = Hflat[0]
        V.sort()
        m = np.diff(V) == 0
        if np.any(m):
            logging.warning("Too few points to create valid contours")
        # Nudge duplicated levels apart so they are strictly increasing.
        while np.any(m):
            V[np.where(m)[0][0]] *= 1.0 - 1e-4
            m = np.diff(V) == 0
        V.sort()

        # Compute the bin centers.
        X1, Y1 = 0.5 * (X[1:] + X[:-1]), 0.5 * (Y[1:] + Y[:-1])

        # Extend the array for the sake of the contours at the plot edges.
        H2 = H.min() + np.zeros((H.shape[0] + 4, H.shape[1] + 4))
        H2[2:-2, 2:-2] = H
        H2[2:-2, 1] = H[:, 0]
        H2[2:-2, -2] = H[:, -1]
        H2[1, 2:-2] = H[0]
        H2[-2, 2:-2] = H[-1]
        H2[1, 1] = H[0, 0]
        H2[1, -2] = H[0, -1]
        H2[-2, 1] = H[-1, 0]
        H2[-2, -2] = H[-1, -1]
        X2 = np.concatenate(
            [
                X1[0] + np.array([-2, -1]) * np.diff(X1[:2]),
                X1,
                X1[-1] + np.array([1, 2]) * np.diff(X1[-2:]),
            ]
        )
        Y2 = np.concatenate(
            [
                Y1[0] + np.array([-2, -1]) * np.diff(Y1[:2]),
                Y1,
                Y1[-1] + np.array([1, 2]) * np.diff(Y1[-2:]),
            ]
        )

    if plot_datapoints:
        # Work on a copy so defaults do not leak into the caller's dict.
        data_kwargs = dict(data_kwargs) if data_kwargs is not None else {}
        data_kwargs.setdefault("color", color)
        data_kwargs.setdefault("ms", 2.0)
        data_kwargs.setdefault("mec", "none")
        data_kwargs.setdefault("alpha", 0.1)
        ax.plot(x, y, "o", zorder=-1, rasterized=True, **data_kwargs)

    if has_extent:
        # Plot the base fill to hide the densest data points.
        if (plot_contours or plot_density) and not no_fill_contours:
            ax.contourf(
                X2, Y2, H2.T, [V.min(), H.max()], cmap=white_cmap, antialiased=False
            )

        if plot_contours and fill_contours:
            contourf_kwargs = (
                dict(contourf_kwargs) if contourf_kwargs is not None else {}
            )
            contourf_kwargs.setdefault("colors", contour_cmap)
            contourf_kwargs.setdefault("antialiased", False)
            ax.contourf(
                X2,
                Y2,
                H2.T,
                np.concatenate([[0], V, [H.max() * (1 + 1e-4)]]),
                **contourf_kwargs,
            )
        # Plot the density map. This can't be plotted at the same time as the
        # contour fills.
        elif plot_density:
            ax.pcolor(X, Y, H.max() - H.T, cmap=density_cmap)

        # Plot the contour edge colors.
        if plot_contours:
            contour_kwargs = dict(contour_kwargs) if contour_kwargs is not None else {}
            contour_kwargs.setdefault("colors", color)
            ax.contour(X2, Y2, H2.T, V, **contour_kwargs)

    # Set axis limits if plotting dimension
    if x_varies:
        ax.set_xlim(prange[0])
    if y_varies:
        ax.set_ylim(prange[1])
    return ax
def lighten_color(color, amount=0.5):
    """
    Rescale the lightness of a color in HLS space.

    The color is converted to HLS, its lightness is multiplied by ``amount``
    and clipped to the interval [0, 1], and the result is converted back to
    RGB. Hue and saturation are left as they are. An ``amount`` below one
    therefore gives a darker color and an ``amount`` above one a lighter one.

    Parameters
    ----------
    color : str or tuple
        A matplotlib color name, a hex string, or an RGB tuple.
    amount : float, optional
        Factor applied to the lightness.

    Returns
    -------
    tuple
        The adjusted color as an ``(r, g, b)`` tuple.

    Examples
    --------
    >> lighten_color('g', 0.3)
    >> lighten_color('#F034A3', 0.6)
    >> lighten_color((.3,.55,.1), 0.5)
    """
    # Named colors are translated through matplotlib's table; anything the
    # table cannot look up is handed to ``to_rgb`` unchanged.
    try:
        resolved = mc.cnames[color]
    except Exception:
        resolved = color

    hue, lightness, saturation = colorsys.rgb_to_hls(*mc.to_rgb(resolved))
    clipped = min(1, amount * lightness)
    clipped = max(0, clipped)
    return colorsys.hls_to_rgb(hue, clipped, saturation)