"""
Calculation and generation of output, and driver for producing plots
"""
import os
import copy
from io import IOBase
from copy import deepcopy
import numpy as np
import matplotlib
import basta.fileio as fio
from basta.constants import sydsun as sydc
from basta.constants import parameters, statdata
from basta.utils_distances import compute_distance_from_mag
from basta.distances import get_absorption, get_EBV_along_LOS
from basta import utils_general as util
from basta import stats, plot_corner, plot_kiel
from basta.downloader import get_basta_dir
# Change matplotlib backend before loading pyplot (must precede the import)
matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# Set the style of all plots
plt.style.use(os.path.join(get_basta_dir(), "plots.mplstyle"))

# Define a color dictionary for easier change of color
colors = {"l0": "#D55E00", "l1": "#009E73", "l2": "#0072B2"}
def _sampled_values(param, Grid, selectedmodels, noofind, nonzeroprop, indices):
    """
    Return the values of `param` for the posterior-sampled models.

    The grid values are first restricted to the models with a finite log-PDF
    (`nonzeroprop`) and then picked with `indices`, which therefore index into
    that restricted set.
    """
    x = util.get_parameter_values(param, Grid, selectedmodels, noofind)
    return x[nonzeroprop][indices]


def _kiel_intervals(
    Grid, selectedmodels, noofind, nonzeroprop, sampled_indices, gridtype, metalname
):
    """
    Compute the sampled quantiles needed to draw a Kiel diagram.

    Returns
    -------
    lp_interval : array
        Quantiles ``statdata.quantiles[1:]`` of the library parameter
        ('massini' for track grids, 'age' otherwise).
    feh_interval : array
        Quantiles ``statdata.quantiles[1:]`` of the metallicity `metalname`.
    Teffout, loggout : array
        Quantiles ``statdata.quantiles`` of Teff and logg.
    """
    library_param = "massini" if "tracks" in gridtype.lower() else "age"

    def sampled(param):
        return _sampled_values(
            param, Grid, selectedmodels, noofind, nonzeroprop, sampled_indices
        )

    lp_interval = np.quantile(sampled(library_param), statdata.quantiles[1:])
    feh_interval = np.quantile(sampled(metalname), statdata.quantiles[1:])
    Teffout = np.quantile(sampled("Teff"), statdata.quantiles)
    loggout = np.quantile(sampled("logg"), statdata.quantiles)
    return lp_interval, feh_interval, Teffout, loggout


def _save_corner_plot(
    samples, cornerplots, plotin, plotout, nameinplot, ckwargs, cornerfile, description
):
    """
    Draw a corner plot of `samples` and save it to `cornerfile`.

    Failures are printed instead of raised, so a broken plot does not abort the
    run. `description` (e.g. "corner plot") is used in the printed messages.
    The current figure is closed in all cases.
    """
    try:
        keys = parameters.get_keys(cornerplots)
        plot_corner.corner(
            samples,
            labels=keys[1],
            truth_color=keys[3],
            plotin=plotin,
            plotout=plotout,
            nameinplot=nameinplot,
            **ckwargs,
        )
        plt.savefig(cornerfile)
        print("Saved " + description + " to " + cornerfile + ".")
    except Exception as error:
        print(description.capitalize() + " failed with the error:", error)
    finally:
        # Also close on failure, so a half-drawn figure does not leak
        plt.close()


def _write_ascii_row(target, header, row):
    """
    Write one row of results (plus a header when needed) to an ascii output.

    Parameters
    ----------
    target : IOBase or str
        Either an open binary file object, which is read from the start to
        check for an existing header line and then written to, or a filename,
        which is (over)written by ``np.savetxt``.
    header : str
        Column names; written on a line starting with "# ".
    row : list
        Values of the row, written with ``%s`` formatting.

    Returns
    -------
    name : str
        Name of the file written to.
    """
    values = np.asarray(row).reshape(1, len(row))
    if isinstance(target, IOBase):
        # Only write the header if the file does not already start with one
        target.seek(0)
        if b"#" not in target.readline():
            target.write(("# " + header + "\n").encode())
        np.savetxt(target, values, fmt="%s", delimiter=" ")
        return target.name
    # np.savetxt prepends its own "# " comment marker to the (str) header
    np.savetxt(target, values, fmt="%s", header=header, delimiter=" ")
    return target


def compute_posterior(
    starid,
    selectedmodels,
    Grid,
    inputparams,
    outfilename,
    gridtype,
    debug=False,
    developermode=False,
    validationmode=False,
    compareinputoutput=False,
):
    """
    Compute the posterior distributions, write them to the ascii outputs and
    produce the requested plots.

    Parameters
    ----------
    starid : str
        Unique identifier of current target.
    selectedmodels : dict
        Contains information on all models with a non-zero likelihood. Each
        value provides ``logPDF`` (and ``chi2``, ``bayw``, ``magw``, ``IMFw``
        for the debug output).
    Grid : hdf5 object
        The already loaded grid, containing the tracks/isochrones.
    inputparams : dict
        Dict containing input from xml-file. Note that the dnu/numax entries
        of ``inputparams["fitparams"]`` are rescaled in place.
    outfilename : str
        Path prefix for all plot files (e.g. ``"_corner.<plotfmt>"`` is
        appended).
    gridtype : str
        Type of the grid (as read from the grid in bastamain) containing either
        'tracks' or 'isochrones'.
    debug : bool, optional
        Debug flag for developers. Adds likelihood/prior corner plots and a
        plot of the individual terms of the PDF.
    developermode : bool, optional
        If True, experimental features will be used in run.
    validationmode : bool, optional
        If True, assume a validation run with changed behaviour
    compareinputoutput : bool, optional
        If True (or if `developermode` is True), compare output to input and
        draw a Kiel diagram on disagreement if none was requested.
    """
    # Load settings
    asciifile = inputparams.get("asciioutput")
    asciifile_dist = inputparams.get("asciioutput_dist")
    centroid = inputparams["centroid"]
    uncert = inputparams["uncert"]
    plottype = inputparams["plotfmt"]
    # Star name to print inside the plots (False disables it)
    nameinplot = starid if inputparams.get("nameinplot") else False

    # Lists of params (copy to avoid problems when running multiple stars)
    outparams = deepcopy(inputparams["asciiparams"])
    cornerplots = deepcopy(inputparams["cornerplots"])
    params = util.unique_unsort(outparams + cornerplots)
    # "distance" is removed from `outparams` once distances are processed, so
    # remember whether it was requested for the final status message
    distance_in_output = "distance" in outparams

    # List of params for plotting
    kielplots = inputparams["kielplots"]
    fitparams = inputparams["fitparams"]
    # Copy taken before dnu/numax in `fitparams` are rescaled in place below
    fitpar_kiel = deepcopy(fitparams)

    # Column names and values for the ascii outputs
    hout = ["starid"]
    out = [starid]
    hout_dist = ["starid"]
    out_dist = [starid]

    # Generate PDF values
    logy = np.concatenate([ts.logPDF for ts in selectedmodels.values()])
    noofind = len(logy)
    nonzeroprop = np.isfinite(logy)
    logy = logy[nonzeroprop]
    nsamples = min(statdata.nsamples, noofind)

    # Likelihood is only defined up to a multiplicative constant of
    # proportionality, therefore we subtract max(logy) from logy: the greatest
    # argument to np.exp is then 0, the sum is at least 1, and we avoid
    # dividing by zero when normalizing.
    lk = logy - np.amax(logy)
    p = np.exp(lk - np.log(np.sum(np.exp(lk))))
    sampled_indices = np.random.choice(np.arange(len(p)), p=p, size=nsamples)

    if debug:
        cs = np.concatenate([ts.chi2 for ts in selectedmodels.values()])
        cs_nz = cs[nonzeroprop]
        # Weights exp(logy + chi2/2) (prior part) and exp(-chi2/2) (likelihood
        # part). Shift the exponents by their maximum before np.exp to avoid
        # overflow/underflow; the shift cancels in the normalisation.
        logws = logy + 0.5 * cs_nz
        ws = np.exp(logws - np.amax(logws))
        ws /= np.sum(ws)
        logexpcs = -0.5 * cs_nz
        expcs = np.exp(logexpcs - np.amax(logexpcs))
        expcs /= np.sum(expcs)
        lsampled_indices = np.random.choice(np.arange(len(p)), p=expcs, size=nsamples)
        wsampled_indices = np.random.choice(np.arange(len(ws)), p=ws, size=nsamples)

    # Corner plot kwargs
    ckwargs = {
        "show_titles": True,
        "quantiles": statdata.quantiles,
        "smooth": 1,
        "smooth1d": "kde",
        "title_kwargs": {"fontsize": 10},
        "plot_datapoints": False,
        "uncert": uncert,
    }

    # Compute distance posterior
    if "distance" in params:
        distanceparams = inputparams["distanceparams"]
        ms = list(distanceparams["filters"])
        # Columns: (distance, extinction) per filter, then joint (d, E(B-V))
        d_samples = np.zeros((nsamples, 2 * (len(ms) + 1)))
        LOS_EBV = get_EBV_along_LOS(distanceparams)
        if "distance" in cornerplots:
            # (centroid, upper, lower) for every column of d_samples
            plotout = np.zeros(3 * (2 * (len(ms) + 1)))

        dinterp, EBVinterp = [], []
        for idm, m in enumerate(ms):
            m_all = np.random.normal(
                distanceparams["m"][m], distanceparams["m_err"][m], noofind
            )
            M_all = util.get_parameter_values(m, Grid, selectedmodels, noofind)
            A_all = np.zeros(noofind)
            # Compute distances and extinction iteratively (starting at A = 0)
            d_all = compute_distance_from_mag(m_all, M_all, A_all)
            for _ in range(3):
                EBV_all = LOS_EBV(d_all)
                A_all = get_absorption(EBV_all, fitparams, m)
                d_all = compute_distance_from_mag(m_all, M_all, A_all)

            # Create posteriors from weighted histograms
            dinterp.append(
                stats.posterior(
                    d_all, nonzeroprop, sampled_indices, nsigma=statdata.nsigma
                )
            )
            EBVinterp.append(
                stats.posterior(
                    EBV_all, nonzeroprop, sampled_indices, nsigma=statdata.nsigma
                )
            )

            d_sampled = d_all[nonzeroprop][sampled_indices]
            A_sampled = A_all[nonzeroprop][sampled_indices]
            M_sampled = M_all[nonzeroprop][sampled_indices]

            # Compute centroid and uncertainties and print them
            xcen, xstdm, xstdp = stats.calc_key_stats(d_sampled, centroid, uncert)
            Acen, Astdm, Astdp = stats.calc_key_stats(A_sampled, centroid, uncert)
            Mcen, Mstdm, Mstdp = stats.calc_key_stats(M_sampled, centroid, uncert)
            if idm == 0:
                print("-----------------------------------------------------")
            util.printparam(
                "d(" + m + ")", xcen, xstdm, xstdp, uncert=uncert, centroid=centroid
            )
            util.printparam(
                "A(" + m + ")", Acen, Astdm, Astdp, uncert=uncert, centroid=centroid
            )
            if "distance" in cornerplots and uncert == "quantiles":
                plotout[6 * idm : 6 * idm + 3] = [xcen, xstdp - xcen, xcen - xstdm]
                plotout[6 * idm + 3 : 6 * idm + 6] = [Acen, Astdp - Acen, Acen - Astdm]
            elif "distance" in cornerplots:
                plotout[6 * idm : 6 * idm + 3] = [xcen, xstdm, xstdm]
                plotout[6 * idm + 3 : 6 * idm + 6] = [Acen, Astdm, Astdm]
            hout_dist, out_dist = util.add_out(
                hout_dist, out_dist, "distance_" + m, xcen, xstdm, xstdp, uncert
            )
            hout_dist, out_dist = util.add_out(
                hout_dist, out_dist, "A_" + m, Acen, Astdm, Astdp, uncert
            )
            hout_dist, out_dist = util.add_out(
                hout_dist, out_dist, "M_" + m, Mcen, Mstdm, Mstdp, uncert
            )
            d_samples[:, 2 * idm] = d_sampled
            d_samples[:, 2 * idm + 1] = A_sampled

        # Compute joint distance and extinction posteriors: product of the
        # per-filter posteriors on the union of their quantile points
        d_array = np.unique(
            [stats.quantile_1D(f.x, f.y, np.linspace(0, 1, 200)) for f in dinterp]
        )
        dposterior = np.prod([f(d_array) for f in dinterp], axis=0)
        EBV_array = np.unique(
            [stats.quantile_1D(f.x, f.y, np.linspace(0, 1, 200)) for f in EBVinterp]
        )
        EBVposterior = np.prod([f(EBV_array) for f in EBVinterp], axis=0)

        if np.nansum(dposterior) == 0 or np.nansum(EBVposterior) == 0:
            derrmessage = (
                "Joint distance posterior could not be computed as the "
                + "distances derived for each magnitude are too different."
            )
            print(derrmessage)
            fio.write_star_to_errfile(starid, inputparams, derrmessage)
            if "distance" in outparams:
                hout_dist, out_dist = util.add_out(
                    hout_dist,
                    out_dist,
                    "distance_joint",
                    np.nan,
                    np.nan,
                    np.nan,
                    uncert,
                )
                hout_dist, out_dist = util.add_out(
                    hout_dist, out_dist, "EBV", np.nan, np.nan, np.nan, uncert
                )
                hout, out = util.add_out(
                    hout, out, "distance", np.nan, np.nan, np.nan, uncert
                )
            if "distance" in cornerplots:
                cornerplots.remove("distance")
        else:
            xcen, xstdm, xstdp = stats.calc_key_stats(
                d_array, centroid, uncert, weights=dposterior
            )
            EBVcen, EBVstdm, EBVstdp = stats.calc_key_stats(
                EBV_array, centroid, uncert, weights=EBVposterior
            )
            util.printparam(
                "d(joint)", xcen, xstdm, xstdp, centroid=centroid, uncert=uncert
            )
            util.printparam(
                "E(B-V)(joint)",
                EBVcen,
                EBVstdm,
                EBVstdp,
                centroid=centroid,
                uncert=uncert,
            )
            if "distance" in cornerplots and uncert == "quantiles":
                plotout[-6:-3] = [xcen, xstdp - xcen, xcen - xstdm]
                plotout[-3:] = [EBVcen, EBVstdp - EBVcen, EBVcen - EBVstdm]
            elif "distance" in cornerplots:
                plotout[-6:-3] = [xcen, xstdm, xstdm]
                plotout[-3:] = [EBVcen, EBVstdm, EBVstdm]

            # Draw samples from the joint distance and E(B-V) posteriors
            d_samples[:, -2] = d_array[
                np.random.choice(
                    np.arange(len(dposterior)),
                    p=dposterior / np.sum(dposterior),
                    size=nsamples,
                )
            ]
            d_samples[:, -1] = EBV_array[
                np.random.choice(
                    np.arange(len(EBVposterior)),
                    p=EBVposterior / np.sum(EBVposterior),
                    size=nsamples,
                )
            ]

            # Create plots
            if "distance" in cornerplots:
                clabels = [
                    label
                    for m in ms
                    for label in ("d(" + m + ")", "A(" + m + ")")
                ] + ["d(joint)", "E(B-V)(joint)"]
                try:
                    plot_corner.corner(
                        d_samples, labels=clabels, plotout=plotout, **ckwargs
                    )
                    cornerfile = outfilename + "_distance_corner." + plottype
                    plt.savefig(cornerfile)
                    plt.close()
                    print("\nSaved distance corner plot to {0}.\n".format(cornerfile))
                except Exception as error:
                    print(
                        "\nDistance corner plot failed with the error:{0}\n".format(
                            error
                        )
                    )
                # Plotting done: Remove keyword
                cornerplots.remove("distance")

            # Add to output array
            if "distance" in outparams:
                hout_dist, out_dist = util.add_out(
                    hout_dist, out_dist, "distance_joint", xcen, xstdm, xstdp, uncert
                )
                hout_dist, out_dist = util.add_out(
                    hout_dist, out_dist, "EBV", EBVcen, EBVstdm, EBVstdp, uncert
                )
                hout, out = util.add_out(
                    hout, out, "distance", xcen, xstdm, xstdp, uncert
                )

        # We have finished using distances: Remove keyword
        params.remove("distance")
        if "distance" in outparams:
            outparams.remove("distance")
    else:
        # Make sure that something is written to the ascii distance files! It
        # will be deleted later...
        hout_dist, out_dist = util.add_out(
            hout_dist, out_dist, "distance_joint", np.nan, np.nan, np.nan, uncert
        )
        hout_dist, out_dist = util.add_out(
            hout_dist, out_dist, "EBV", np.nan, np.nan, np.nan, uncert
        )

    # Allocate arrays
    samples = np.zeros((nsamples, len(cornerplots)))
    if debug:
        lsamples = np.zeros((nsamples, len(cornerplots)))
        wsamples = np.zeros((nsamples, len(cornerplots)))
    # Input (value, uncertainty) per corner parameter; -9999 is left in place
    # for parameters without an input value in `fitparams`
    plotin = np.ones(2 * len(cornerplots)) * -9999
    plotout = np.zeros(3 * len(cornerplots))
    dnu_scales = inputparams.get("dnu_scales", {})
    dnusun = inputparams.get("dnusun", sydc.SUNdnu)
    numsun = inputparams.get("numsun", sydc.SUNnumax)
    for numpar, param in enumerate(params):
        # Generate list of x values
        x = util.get_parameter_values(param, Grid, selectedmodels, noofind)

        # Scale back to muHz before output/plot (input values in `fitparams`
        # are rescaled in place to match)
        if param.startswith("dnu") and param not in ["dnufit", "dnufitMos12"]:
            dnu_rescal = dnu_scales.get(param, 1.00)
            x *= dnusun / dnu_rescal
            if param in fitparams:
                fitparams[param] = np.asarray(fitparams[param]) * dnusun / dnu_rescal
        elif param.startswith("numax"):
            x *= numsun
            if param in fitparams:
                fitparams[param] = np.asarray(fitparams[param]) * numsun
        elif param in ["dnufit", "dnufitMos12"]:
            dnu_rescal = dnu_scales.get(param, 1.00)
            x /= dnu_rescal
            if param in fitparams:
                fitparams[param] = np.asarray(fitparams[param]) / dnu_rescal

        # Compute quantiles (using np.quantile is ~50 times faster than quantile_1D)
        x_sampled = x[nonzeroprop][sampled_indices]
        xcen, xstdm, xstdp = stats.calc_key_stats(x_sampled, centroid, uncert)

        # Print info to log and console
        if numpar == 0:
            print("-----------------------------------------------------")
        util.printparam(param, xcen, xstdm, xstdp, uncert=uncert, centroid=centroid)

        if param in cornerplots:
            idx = cornerplots.index(param)
            if param in fitparams:
                xin, stdin = fitparams[param]
                plotin[2 * idx : 2 * idx + 2] = [xin, stdin]
            samples[:, idx] = x_sampled
            if debug:
                lsamples[:, idx] = x[nonzeroprop][lsampled_indices]
                wsamples[:, idx] = x[nonzeroprop][wsampled_indices]
            if uncert == "quantiles":
                plotout[3 * idx : 3 * idx + 3] = [xcen, xstdp - xcen, xcen - xstdm]
            else:
                plotout[3 * idx : 3 * idx + 3] = [xcen, xstdm, xstdm]
        if param in outparams:
            hout, out = util.add_out(hout, out, param, xcen, xstdm, xstdp, uncert)

    # Write the results (with a header line) to the ascii files. A missing
    # (None) or False target disables the output.
    if asciifile:
        name = _write_ascii_row(asciifile, "".join(h + " " for h in hout), out)
        prefix = "\n" if isinstance(asciifile, IOBase) else ""
        print(prefix + "Saved results to " + name + ".")

    if asciifile_dist:
        name = _write_ascii_row(
            asciifile_dist, "".join(h + " " for h in hout_dist), out_dist
        )
        if distance_in_output:
            print("Saved distance results for different filters to %s." % name)

    # Use correct metallicity (only important for alpha enhancement)
    metalname = "MeH" if "MeH" in fitparams else "FeH"

    # Compare input to output and produce a comparison plot
    if compareinputoutput or developermode:
        comparewarn = util.compare_output_to_input(
            starid, inputparams, hout, out, hout_dist, out_dist, uncert=uncert
        )
        if comparewarn:
            print(
                "DEBUG: The input values of the fitting parameters "
                + "disagree with the outputted values."
            )
            if not len(kielplots):
                print("DEBUG: make Kiel diagram due to warning")
                lp_interval, feh_interval, Teffout, loggout = _kiel_intervals(
                    Grid,
                    selectedmodels,
                    noofind,
                    nonzeroprop,
                    sampled_indices,
                    gridtype,
                    metalname,
                )
                try:
                    fig = plot_kiel.kiel(
                        Grid=Grid,
                        selectedmodels=selectedmodels,
                        fitparams=fitpar_kiel,
                        inputparams=inputparams,
                        lp_interval=lp_interval,
                        feh_interval=feh_interval,
                        Teffout=Teffout,
                        loggout=loggout,
                        gridtype=gridtype,
                        nameinplot=nameinplot,
                        debug=debug,
                        developermode=developermode,
                        validationmode=validationmode,
                    )
                    kielfile = outfilename + "_warn_kiel." + plottype
                    fig.savefig(kielfile)
                    plt.close()
                    print("Saved warning Kiel diagram to " + kielfile + ".")
                except Exception as error:
                    print("Warning Kiel diagram failed with the error:", error)

    # Create corner plots
    if len(cornerplots):
        _save_corner_plot(
            samples,
            cornerplots,
            plotin,
            plotout,
            nameinplot,
            ckwargs,
            outfilename + "_corner." + plottype,
            "corner plot",
        )
        if debug:
            _save_corner_plot(
                lsamples,
                cornerplots,
                plotin,
                plotout,
                nameinplot,
                ckwargs,
                outfilename + "_DEBUG_likelihood_corner." + plottype,
                "likelihood corner plot",
            )
            _save_corner_plot(
                wsamples,
                cornerplots,
                plotin,
                plotout,
                nameinplot,
                ckwargs,
                outfilename + "_DEBUG_prior_corner." + plottype,
                "prior corner plot",
            )

    # Create Kiel diagram
    if len(kielplots):
        # Quantiles of massini/age and metallicity determine what to plot
        lp_interval, feh_interval, Teffout, loggout = _kiel_intervals(
            Grid,
            selectedmodels,
            noofind,
            nonzeroprop,
            sampled_indices,
            gridtype,
            metalname,
        )
        try:
            fig = plot_kiel.kiel(
                Grid=Grid,
                selectedmodels=selectedmodels,
                fitparams=fitpar_kiel,
                inputparams=inputparams,
                lp_interval=lp_interval,
                feh_interval=feh_interval,
                Teffout=Teffout,
                loggout=loggout,
                gridtype=gridtype,
                nameinplot=nameinplot,
                debug=debug,
                developermode=developermode,
                validationmode=validationmode,
                color_by_likelihood=False,
            )
            kielfile = outfilename + "_kiel." + plottype
            fig.savefig(kielfile)
            plt.close()
            print("Saved Kiel diagram to " + kielfile + ".")
        except Exception as error:
            print("Kiel diagram failed with the error:", error)
            # Unlike the corner plots, a failing Kiel diagram aborts the run
            raise

    if debug and len(inputparams["magnitudes"]) > 0:
        print("Make normalised distribution plot of terms in PDF computation")
        # Each term is restricted to the finite-PDF models and shifted so that
        # its maximum is 0
        terms = {}
        for key in ("bayw", "magw", "IMFw"):
            w = np.concatenate([getattr(ts, key) for ts in selectedmodels.values()])
            w = w[nonzeroprop]
            w -= np.amax(w)
            terms[key] = w
        mins = [np.amin(w) for w in terms.values()]
        # Restrict chi2 to the same models so it lines up with `sampled_indices`
        csw = -0.5 * cs[nonzeroprop]
        csw -= np.amax(csw)
        panels = [
            (terms["bayw"], "b.", "Bayesian weights"),
            (terms["magw"], "g.", "Absolute magnitude"),
            (terms["IMFw"], "c.", "IMF"),
            (csw, "m.", r"$\chi^2$ part"),
        ]
        for param in ["massini"]:  # params:
            fig, axs = plt.subplots(5, sharex=True)
            x = util.get_parameter_values(param, Grid, selectedmodels, noofind)
            x_nz = x[nonzeroprop]
            axs[0].plot(x_nz, logy, "k.", label="Posterior", ms=3, alpha=0.1)
            for ax, (w, fmt, label) in zip(axs[1:], panels):
                ax.plot(
                    x_nz[sampled_indices],
                    w[sampled_indices],
                    fmt,
                    label=label,
                    ms=3,
                    alpha=0.1,
                )
            plt.xlabel(param)
            shadowaxes = fig.add_subplot(111, frame_on=False)
            plt.tick_params(
                labelcolor="none", top=False, bottom=False, left=False, right=False
            )
            shadowaxes.set_ylabel("Scaled probability")
            for ax in axs[1:]:
                ax.set_ylim([min(mins) - 0.5, 0.5])
            handles, labels = [], []
            for ax in axs:
                ax_handles, ax_labels = ax.get_legend_handles_labels()
                handles += ax_handles
                labels += ax_labels
            fig.legend(handles, labels, loc="upper center", ncol=5)
            distfile = outfilename + "_DEBUG_dist" + param + "." + plottype
            fig.savefig(distfile)
            plt.close()
            print("Saved distribution plot to " + distfile + ".")