Source code for utils_general
"""
General mix of utility functions
"""
import sys
from io import IOBase
import numpy as np
[docs]
def h5py_to_array(xs):
    """
    Read a whole HDF5 dataset (or any array-like) into a fresh NumPy array.

    Parameters
    ----------
    xs : h5py_dataset
        Dataset obtained through h5py; it must expose ``shape``, ``dtype``
        and support ``[:]`` slicing.

    Returns
    -------
    res : array_like
        In-memory copy of the data with the same shape and dtype; it does
        not share memory with `xs`.
    """
    target = np.empty(shape=xs.shape, dtype=xs.dtype)
    # A single bulk slice read, written into the preallocated buffer
    target[:] = xs[:]
    return target
[docs]
def prt_center(text, llen):
    """
    Print `text` padded with spaces so it sits in the middle of a line.

    Parameters
    ----------
    text : str
        The text string to print
    llen : int
        Width of the line to centre within. If `text` is longer than
        `llen`, no padding is added.

    Returns
    -------
    None
    """
    # The same margin goes on both sides; an odd leftover space is dropped.
    # int(... / 2) (not //) keeps float widths working with str repetition.
    margin = int((llen - len(text)) / 2) * " "
    print(f"{margin}{text}{margin}")
[docs]
class Logger(object):
    """
    Class used to redefine stdout to terminal and an output file.

    Every message written is forwarded both to the stream that was
    ``sys.stdout`` when the logger was created and to ``<outfilename>.log``,
    which is opened in append mode.

    Parameters
    ----------
    outfilename : str
        Path to the output file, without the ".log" extension (it is added).
    """

    # Credit: http://stackoverflow.com/a/14906787
    def __init__(self, outfilename):
        self.terminal = sys.stdout
        self.log = open(outfilename + ".log", "a")

    def write(self, message):
        """Write `message` to both the terminal and the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """
        Flush both underlying streams.

        Required for Python 3 stdout compatibility. Actually flushing the log
        means its content is on disk when the caller asks for it, instead of
        sitting in the file buffer until it happens to be flushed.
        """
        self.terminal.flush()
        # Stay safe if flush is called (e.g. at interpreter shutdown) after
        # the log has been closed.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file (the terminal stream is left open). Idempotent."""
        if not self.log.closed:
            self.log.close()
[docs]
def list_metallicities(Grid, defaultpath, inputparams, limits):
    """
    Get a list of metallicities in the grid that we loop over

    Parameters
    ----------
    Grid : hdf5 object
        The already loaded grid, containing the tracks/isochrones.
    defaultpath : str
        Path in Grid
    inputparams : dict
        Dictionary of all controls and input (currently unused).
    limits : dict
        Dict of flat priors used in run. If it holds a "MeH" (preferred) or
        "FeH" entry, its first two elements are used as inclusive lower and
        upper bounds on the metallicity.

    Returns
    -------
    metal : range or numpy.ndarray
        List of possible metallicities that should be looped over in
        `bastamain`. For a `defaultpath` containing "grid" this is the
        single dummy iteration ``range(1)``; otherwise it is a float array
        parsed from the "<name>=<value>" sub-group names of
        ``Grid[defaultpath]``.
    """
    if "grid" in defaultpath:
        # No metallicity sub-groups: loop exactly once
        return range(1)

    # Sub-groups are presumably named like "FeH=-0.50"; take everything after
    # the first "=" so any prefix length works (not only 4-character ones).
    metal = np.asarray(
        [
            float(name.partition("=")[2])
            for name in Grid[defaultpath].keys()
            if "=" in name
        ],
        dtype=float,
    )

    metal_name = "MeH" if "MeH" in limits else "FeH"
    if metal_name in limits:
        lower, upper = limits[metal_name][0], limits[metal_name][1]
        metal = metal[(metal >= lower) & (metal <= upper)]
    return metal
[docs]
def unique_unsort(params):
    """
    Remove duplicates from `params` while keeping first-appearance order.

    ``numpy.unique`` always returns its result sorted, so the index of the
    first occurrence of each unique value is used to restore the original
    order.

    Parameters
    ----------
    params : list
        List of parameters

    Returns
    -------
    params : list
        Each distinct element once, in the order it first appeared.
    """
    _, first_seen = np.unique(params, return_index=True)
    first_seen.sort()
    return [params[k] for k in first_seen]
[docs]
def compare_output_to_input(
    starid, inputparams, hout, out, hout_dist, out_dist, uncert="quantiles", sigmacut=1
):
    """
    This function compares the outputted value of all fitting parameters
    to the input that was fitted.

    If one or more fitting parameters deviates by at least 'sigmacut' times
    the combined (input and output, added in quadrature) symmetric
    uncertainty from their input parameter, a warning is printed and a line
    for 'starid' is appended to the warning output.

    Parameters
    ----------
    starid : str
        Unique identifier of current target.
    inputparams : dict
        Dict containing input from xml-file. The keys "warnoutput",
        "fitparams" and "magnitudes" are read, plus "distanceparams" when
        magnitudes are given. If "warnoutput" is False nothing is compared;
        otherwise it is an open file object or a path to append to.
    hout : list
        List of column headers for output
    out : list
        List of output values for the columns given in `hout`.
    hout_dist : list
        List of column headers for the distance output.
    out_dist : list
        List of output values for the columns given in `hout_dist`.
    uncert : str, optional
        Type of reported uncertainty used in the outputs, "quantiles"
        (value followed by its plus/minus errors) or "std" (value followed
        by one symmetric error).
    sigmacut : float, optional
        Number of standard deviation used for determining when to issue
        a warning.

    Returns
    -------
    comparewarn : bool
        Flag to determine whether or not a warning was raised.
    """
    if inputparams["warnoutput"] is False:
        return False
    fitparams = inputparams["fitparams"]
    warnfile = inputparams["warnoutput"]
    ps = []
    sigmas = []

    def check(name, values, idx, reference, reference_err):
        # Record `name` when its output is at least `sigmacut` sigma off
        sigma = _io_sigma(values, idx, reference, reference_err, uncert)
        if sigma >= sigmacut:
            ps.append(name)
            sigmas.append(sigma)

    houtlist = list(hout)
    for p in fitparams:
        if p in houtlist:
            xin, xinerr = fitparams[p]
            check(p, out, houtlist.index(p), xin, xinerr)

    if len(inputparams["magnitudes"]) > 0:
        magnitudes = inputparams["magnitudes"]
        distparams = inputparams["distanceparams"]
        hdist = list(hout_dist)
        for m in list(distparams["filters"]):
            mdist = "M_" + m
            if mdist in hdist:
                prior = magnitudes[m]
                check(
                    mdist,
                    out_dist,
                    hdist.index(mdist),
                    prior["median"],
                    (prior["errp"] + prior["errm"]) / 2,
                )

        if "distance" in hdist:
            # Prefer the joint distance; fall back to plain "distance" instead
            # of crashing when no joint column was written.
            key = "distance_joint" if "distance_joint" in hdist else "distance"
            priordistqs = distparams["priordistance"]
            # The error terms only make sense (non-negative) if the prior is
            # ordered (median, lower, upper), so compare against index 0.
            priorerrm = priordistqs[0] - priordistqs[1]
            priorerrp = priordistqs[2] - priordistqs[0]
            check(
                "distance",
                out_dist,
                hdist.index(key),
                priordistqs[0],
                (priorerrp + priorerrm) / 2,
            )

    comparewarn = bool(ps)
    if comparewarn:
        print("A >%s sigma difference was found between input and output of" % sigmacut)
        print(ps)
        print("with sigma differences of")
        print(sigmas)
        if isinstance(warnfile, IOBase):
            warnfile.write("{}\t{}\t{}\n".format(starid, ps, sigmas))
        else:
            with open(warnfile, "a") as wf:
                wf.write("{}\t{}\t{}\n".format(starid, ps, sigmas))
    return comparewarn


def _io_sigma(values, idx, reference, reference_err, uncert):
    """
    Deviation of ``values[idx]`` from `reference` in combined sigmas.

    The output error is the mean of the two errors stored right after the
    value when `uncert` is "quantiles", otherwise the single error stored
    after it. It is added in quadrature to `reference_err`.
    """
    if uncert == "quantiles":
        outerr = (values[idx + 1] + values[idx + 2]) / 2
    else:
        outerr = values[idx + 1]
    serr = np.sqrt(reference_err**2 + outerr**2)
    return np.abs(values[idx] - reference) / serr
[docs]
def inflog(x):
    """
    Natural logarithm that maps zero to -inf silently.

    Equivalent to ``np.log(x)``, except that NumPy's divide-by-zero
    floating-point warning (triggered by log(0)) is suppressed, so zeros
    become -inf without a warning.
    """
    quiet = np.errstate(divide="ignore")
    with quiet:
        result = np.log(x)
    return result
[docs]
def add_out(hout, out, par, x, xm, xp, uncert):
    """
    Append one parameter's header names and values to the output lists.

    Both lists are extended in place and also returned.

    Parameters
    ----------
    hout : list
        Names in header
    out : list
        Parameter values
    par : str
        Parameter name
    x : float
        Centroid value
    xm : float
        Lower bound (for "quantiles") or the symmetric uncertainty (otherwise)
    xp : float, None
        Upper bound for "quantiles"; unused (may be None) otherwise
    uncert : str
        Type of reported uncertainty, "quantiles" or "std"

    Returns
    -------
    hout : list
        Header list with added names
    out : list
        Parameter list with added entries
    """
    if uncert != "quantiles":
        # Symmetric case: xm already is the error itself
        hout += [par, par + "_err"]
        out += [x, xm]
        return hout, out
    hout += [par] + [par + suffix for suffix in ("_errp", "_errm")]
    # xp / xm are bounds, so store their distances from the centroid
    out += [x, xp - x, x - xm]
    return hout, out
def normfactor(alphas, ms):
    """
    Normalisation constants of a continuous multi-segment power law.

    Algorithm from App. A in Pflamm-Altenburg & Kroupa (2006)
    https://ui.adsabs.harvard.edu/abs/2006MNRAS.373..295P/abstract

    The constants satisfy ``ks[0] * ms[1]**alphas[0] == 1`` and
    ``ks[i-1] * ms[i]**alphas[i-1] == ks[i] * ms[i]**alphas[i]`` for every
    break ``ms[i]`` (i >= 1). In other words, ``ks[i] * m**alphas[i]`` joins
    continuously across the segments. Any number of segments is supported.

    Parameters
    ----------
    alphas : sequence of float
        Exponent of each power-law segment.
    ms : sequence of float
        Mass boundaries; ``ms[1]`` ... ``ms[len(alphas) - 1]`` are the break
        points used (``ms[0]`` is not used). Must contain at least
        ``max(2, len(alphas))`` entries when `alphas` is non-empty.

    Returns
    -------
    ks : numpy.ndarray
        One normalisation constant per segment.
    """
    nseg = len(alphas)
    ks = np.zeros(nseg)
    if nseg == 0:
        return ks
    ks[0] = (1 / ms[1]) ** alphas[0]
    # prefix = prod_{j=1}^{i-1} (ms[j+1] / ms[j]) ** alphas[j], built up in the
    # same left-to-right order as the original closed-form expressions.
    prefix = 1.0
    for i in range(1, nseg):
        ks[i] = prefix * (1 / ms[i]) ** alphas[i]
        if i + 1 < nseg:
            prefix *= (ms[i + 1] / ms[i]) ** alphas[i]
    return ks
[docs]
def get_parameter_values(parameter, Grid, selectedmodels, noofind):
    """
    Gather the values of one parameter for all selected models.

    Parameters
    ----------
    parameter : str
        Name of the parameter to collect.
    Grid : hdf5 object
        Grid used as a fallback source, read at ``"<modelpath>/<parameter>"``.
    selectedmodels : dict
        Maps a model path to an object exposing ``logPDF`` (its length gives
        the number of values for that path), ``paramvalues`` (values already
        stored per parameter) and ``index`` (selection into the grid dataset).
    noofind : int
        Length of the result; presumably the total of all ``len(logPDF)``.

    Returns
    -------
    x_all : array
        Parameter values, model after model, in the iteration order of
        `selectedmodels`.
    """
    values = np.zeros(noofind)
    offset = 0
    for modelpath in selectedmodels:
        track = selectedmodels[modelpath]
        stop = offset + len(track.logPDF)
        try:
            values[offset:stop] = track.paramvalues[parameter]
        except Exception:
            # Not usable from the stored values: read it from the grid instead
            values[offset:stop] = Grid[modelpath + "/" + parameter][track.index]
        offset = stop
    return values
[docs]
def printparam(param, xmed, xstdm, xstdp, uncert="quantiles", centroid="median"):
    """
    Pretty-print one output parameter with its uncertainty to stdout.

    Parameters
    ----------
    param : str
        Name of parameter
    xmed : float
        Centroid value (median or mean)
    xstdm : float
        Lower bound for "quantiles" (printed as ``xmed - xstdm``), otherwise
        the symmetric uncertainty (printed as is).
    xstdp : float
        Upper bound for "quantiles" (printed as ``xstdp - xmed``); unused
        otherwise.
    uncert : str, optional
        Type of reported uncertainty, "quantiles" or "std"
    centroid : str, optional
        Label of the centroid, "median" or "mean"

    Returns
    -------
    None
    """
    # Column widths fit the longest possible parameter name ("E(B-V)(joint)")
    row = "{0:9} {1:13} : {2:12.6f}"
    print(row.format(centroid, param, xmed))
    if uncert != "quantiles":
        print(row.format("stdev", param, xstdm))
    else:
        print(row.format("err_plus", param, xstdp - xmed))
        print(row.format("err_minus", param, xmed - xstdm))
    print("-----------------------------------------------------")