# Source code for bastamain

"""
Main module for running BASTA analysis
"""

import gc
import os
import sys
import time

import h5py

# Import matplotlib after other plotting modules for proper setup
# --> Here in main it is only used for clean-up
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from basta import fileio as fio
from basta import plot_driver, priors, process_output, stats
from basta import utils_general as util
from basta import utils_seismic as su
from basta.constants import freqtypes


# Custom exception
class LibraryError(Exception):
    """Custom BASTA exception.

    NOTE(review): not raised within this module; presumably raised or caught
    elsewhere in BASTA -- confirm before changing or removing it.
    """
def BASTA(
    starid: str,
    gridfile: str,
    inputparams: dict,
    gridid: bool | tuple = False,
    usebayw: bool = True,
    usepriors: tuple = (None,),
    optionaloutputs: bool = False,
    seed: int | None = None,
    debug: bool = False,
    verbose: bool = False,
    developermode: bool = False,
    validationmode: bool = False,
) -> None:
    """
    The BAyesian STellar Algorithm (BASTA).
    (c) 2025, The BASTA Team

    For a description of how to use BASTA, please explore the documentation
    (https://github.com/BASTAcode/BASTA).
    This function is typically called by :func:`xmltools.run_xml()`.

    While the fit runs, standard output is redirected to ``util.Logger`` for
    ``<output>/<starid>``. Standard output is always restored, the grid file
    closed and open plots closed when the run ends. This also happens when no
    model with non-zero likelihood is found or when an exception propagates.
    Otherwise the logger would remain installed as standard output and would be
    picked up as "stdout" by the next run.

    Parameters
    ----------
    starid : str
        Unique identifier for this target.
    gridfile : str
        Path and name of the hdf5 file containing the isochrones or tracks
        used in the fitting
    inputparams : dict
        Dictionary containing most information needed, e.g. controls,
        fitparameters, output options. The ``limits`` dict taken from it is
        modified in place (grid-header cuts are moved out of it, and a
        ``dnufit`` prior may be added).
    gridid : bool or tuple
        For isochrones, a tuple containing (overshooting [f], diffusion [0 or 1],
        mass loss [eta], alpha enhancement [0.0 ... 0.4]) used for selecting a
        science case / path in the library.
    usebayw : bool
        If True, bayesian weights are applied in the computation of the
        likelihood. See :func:`interpolation_helpers.bay_weights()` for details.
    usepriors : tuple
        Tuple of strings containing name of priors (e.g., an IMF).
        See :func:`priors` for details. ``None`` entries (the default) mean
        "no prior" and are skipped.
    optionaloutputs : bool, optional
        If True, saves a 'json' file for each star with the global results and
        the PDF.
    seed : int, optional
        The seed of randomness
    debug : bool, optional
        Activate additional output for debugging (for developers)
    verbose : bool, optional
        Activate a lot (!) of additional output (for developers)
    developermode : bool, optional
        Activate experimental features (for developers)
    validationmode : bool, optional
        Activate validation mode features (for validation purposes only)
    """
    # Enable legacy printing of NumPy data types
    # --> E.g., print 104.14836386995329 instead of np.float64(104.14836386995329)
    #     and 'Teff' instead of np.str_('Teff') to the .log file
    np.set_printoptions(legacy="1.25")

    # Set output directory and filenames
    t0 = time.localtime()
    outputdir = inputparams.get("output")
    outfilename = os.path.join(outputdir, starid)

    # Start the log
    stdout = sys.stdout
    sys.stdout = util.Logger(outfilename)
    completed = False
    try:
        # Pretty printing a header
        util.print_bastaheader(t0=t0, seed=seed, developermode=developermode)

        # Load the desired grid; the context manager closes it on every exit path
        with h5py.File(gridfile, "r") as Grid:
            completed = _run_fit(
                Grid=Grid,
                starid=starid,
                gridfile=gridfile,
                inputparams=inputparams,
                outfilename=outfilename,
                t0=t0,
                gridid=gridid,
                usebayw=usebayw,
                usepriors=usepriors,
                optionaloutputs=optionaloutputs,
                debug=debug,
                verbose=verbose,
                developermode=developermode,
                validationmode=validationmode,
            )
    finally:
        # Recover standard output, close open plots, and try to free memory
        # between multiple runs -- also on the early "no models" exit or errors
        sys.stdout = stdout
        plt.close("all")
        gc.collect()

    if completed:
        print(f"Saved log to {outfilename}.log")


def _run_fit(
    *,
    Grid: "h5py.File",
    starid: str,
    gridfile: str,
    inputparams: dict,
    outfilename: str,
    t0: time.struct_time,
    gridid: bool | tuple,
    usebayw: bool,
    usepriors: tuple,
    optionaloutputs: bool,
    debug: bool,
    verbose: bool,
    developermode: bool,
    validationmode: bool,
) -> bool:
    """Fit one star against an already opened grid (the caller owns log and file).

    Returns
    -------
    bool
        ``False`` if no model obtained a non-zero likelihood (after calling
        :func:`fio.no_models`), ``True`` once all output has been produced.
    """
    # Obtain information from the grid header
    gridtype, gridver, gridtime, grid_is_intpol = util.read_grid_header(Grid)

    # Verbose information on the grid file
    print(f"\nFitting star id: {starid} .")
    print(f"* Using the grid '{gridfile}' of type '{gridtype}'.")
    print(f" - Grid built with BASTA version {gridver}, timestamp: {gridtime}.")
    entryname, defaultpath, difsolarmodel = util.check_gridtype(gridtype, gridid=gridid)

    # Read available weights if not provided by the user
    bayweights, dweight = (
        util.read_grid_bayweights(Grid, gridtype) if usebayw else (None, None)
    )

    # Get list of parameters. The updated parameter list returned by
    # prepare_distancefitting is not needed in this function.
    cornerplots = inputparams["cornerplots"]
    outparams = inputparams["asciiparams"]
    allparams = list(np.unique(cornerplots + outparams))
    inputparams, _ = util.prepare_distancefitting(
        inputparams=inputparams,
        debug=debug,
        debug_dirpath=outfilename,
        allparams=allparams,
    )

    # Create list of all available input parameters
    fitparams = inputparams.get("fitparams")
    fitfreqs = inputparams["fitfreqs"]
    distparams = inputparams.get("distanceparams", False)
    limits = inputparams.get("limits")

    # Scale dnu and numax using a solar model or default solar values
    inputparams = su.solar_scaling(Grid, inputparams, diffusion=difsolarmodel)

    # Prepare asteroseismic quantities if required
    if fitfreqs["active"]:
        if not all(x in freqtypes.alltypes for x in fitfreqs["fittypes"]):
            print(fitfreqs["fittypes"])
            raise ValueError("Unrecognized frequency fitting parameters!")

        # Obtain/calculate all frequency related quantities
        (
            obskey,
            obs,
            obsfreqdata,
            obsfreqmeta,
            obsintervals,
        ) = su.prepare_obs(inputparams, verbose=verbose, debug=debug)

        # Apply prior on dnufit to mimick the range defined by dnufrac
        if fitfreqs["dnuprior"] and ("dnufit" not in limits):
            dnufit_frac = fitfreqs["dnufrac"] * fitfreqs["dnufit"]
            dnuerr = max(3 * fitfreqs["dnufit_err"], dnufit_frac)
            limits["dnufit"] = [
                fitfreqs["dnufit"] - dnuerr,
                fitfreqs["dnufit"] + dnuerr,
            ]

    # Check if any specified limit in prior is in header, and can be used to
    # skip computation of models, in order to speed up computation
    tracks_headerpath = "header/"
    if "tracks" in gridtype.lower():
        headerpath: str | bool = tracks_headerpath
    elif "isochrones" in gridtype.lower():
        headerpath = tracks_headerpath + defaultpath
        if "FeHini" in limits:
            del limits["FeHini"]
            print("Warning: Dropping prior in FeHini, redundant for isochrones!")
    else:
        headerpath = False

    # Gridcut dictionary containing cutting parameters. Limits on parameters
    # found in the grid header are moved out of ``limits`` to avoid a redundant
    # second check on every model.
    gridcut = {}
    if headerpath:
        for key in Grid[headerpath].keys():
            if key in limits:
                gridcut[key] = limits.pop(key)

    # Counters for the gridcut summary: [number skipped, number checked]
    noofskips = [0, 0]

    # Report the cut on header parameters with a special treatment of diffusion
    if gridcut:
        print("\nCutting in grid based on sampling parameters ('gridcut'):")
        for cpar in gridcut:
            if cpar != "dif":
                print(f"* {cpar}: {gridcut[cpar]}")

        # Diffusion switch printed in a more readable format
        if "dif" in gridcut:
            # As gridcut['dif'] is always either [-inf, 0.5] or [0.5, inf]
            # The location of 0.5 can be used as the switch
            switch = np.where(np.array(gridcut["dif"]) == 0.5)[0][0]
            print(
                "* Only considering tracks with diffusion turned",
                "{:s}!".format(["on", "off"][switch]),
            )

    util.print_fitparams(fitparams=fitparams)
    if fitfreqs["active"]:
        util.print_seismic(fitfreqs=fitfreqs, obskey=obskey, obs=obs)
    util.print_distances(distparams, inputparams["asciiparams"])
    util.print_additional(inputparams)
    util.print_weights(bayweights, gridtype)
    util.print_priors(limits, usepriors)

    # ``None`` entries mean "no prior"; getattr(priors, None) would raise
    activepriors = [prior for prior in usepriors if prior is not None]

    if fitfreqs["active"]:
        # Tolerances (below/above) around obs[0, 0] for the l=0 model frequency
        # at the lowest observed radial order; loop-invariant, so computed once
        l0_lower_tol = min(
            (fitfreqs["dnufrac"] / 2 * fitfreqs["dnufit"]),
            (3 * obs[1, 0]),
        )
        l0_upper_tol = fitfreqs["dnufrac"] * fitfreqs["dnufit"]

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Start likelihood computation
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    # Two loop cases for the outer "metal" loop:
    # - For Garstec and MESA grids, the top level contains only one element ("tracks").
    #   Here the outer loop will run only once.
    # - For BaSTI, the top level is a list of metallicities and the outer loop will run
    #   multiple times.
    metal = util.list_metallicities(Grid, defaultpath, inputparams, limits)

    # Default paths without "grid" are split into "FeH=x.xxxx/" groups whose
    # members are named "<param>=<value>" (mass or age)
    has_feh_groups = "grid" not in defaultpath

    # We assume Garstec grid structure. The path will be updated in the loop for BaSTI
    group_name = defaultpath + "tracks/"

    # Before running the actual loop, all tracks/isochrones are counted to better
    # estimate the progress.
    trackcounter = 0
    for FeH in metal:
        if has_feh_groups:
            group_name = f"{defaultpath}FeH={FeH:.4f}/"
        trackcounter += len(Grid[group_name])

    # Prepare the main loop
    shapewarn = 0
    warn = True
    selectedmodels = {}
    noofind = 0
    noofposind = 0
    # In some cases we need to store quantities computed at runtime
    store_dnusurf = fitfreqs["active"] and fitfreqs["dnufit_in_ratios"]
    store_glitch = fitfreqs["active"] and fitfreqs["glitchfit"]
    dnusurfmodels = {}
    glitchmodels = {}
    print(
        f"\n\nComputing likelihood of models in the grid ({trackcounter} {entryname}) ..."
    )

    # Use a progress bar (with the package tqdm; will write to stderr)
    with tqdm(total=trackcounter, desc="--> Progress", ascii=True) as pbar:
        for FeH in metal:
            if has_feh_groups:
                group_name = f"{defaultpath}FeH={FeH:.4f}/"
            group = Grid[group_name]
            for noingrid, (name, libitem) in enumerate(group.items()):
                # Update progress bar in the start of the loop to count skipped tracks
                pbar.update(1)
                trackpath = group_name + name

                # For grid with interpolated tracks, skip tracks flagged as empty
                if grid_is_intpol and libitem["IntStatus"][()] < 0:
                    continue

                # Check for diffusion
                if "dif" in inputparams:
                    if round(libitem["dif"][0]) != round(float(inputparams["dif"])):
                        continue

                # Check if mass or age is in limits to efficiently skip
                if has_feh_groups:
                    param, val = name.split("=")
                    if param == "mass":
                        param += "ini"
                    # if age or massini is outside limits, skip this iteration
                    if param in limits and (
                        float(val) < limits[param][0] or float(val) > limits[param][1]
                    ):
                        continue

                # Check if track should be skipped from cut in initial parameters
                if gridcut:
                    noofskips[1] += 1
                    if _outside_gridcut(
                        Grid, gridtype, tracks_headerpath, gridcut, name, noingrid
                    ):
                        noofskips[0] += 1
                        continue

                # Check which models have parameters within limits
                index = np.ones(len(libitem["age"]), dtype=bool)
                for param in limits:
                    values = libitem[param][:]
                    index &= values >= limits[param][0]
                    index &= values <= limits[param][1]

                # Check which models have phases as specified
                if "phase" in inputparams:
                    index &= _phase_mask(inputparams["phase"], libitem["phase"][:])

                # Check which models have l=0, lowest n within tolerance
                if fitfreqs["active"]:
                    index &= _lowest_l0_mask(
                        libitem, index, obskey, obs, l0_lower_tol, l0_upper_tol
                    )

                # Only tracks with models within tolerances get statistics
                if not np.any(index):
                    if debug and verbose:
                        print(f"DEBUG: Index not found: {trackpath}")
                    continue

                nmodels = index.sum()
                chi2 = np.zeros(nmodels)
                for param in fitparams:
                    chi2 += (
                        (libitem[param][index] - fitparams[param][0])
                        / fitparams[param][1]
                    ) ** 2.0

                # Frequency (and/or ratio and/or glitch) fitting
                if fitfreqs["active"]:
                    if fitfreqs["dnufit_in_ratios"]:
                        dnusurf = np.zeros(nmodels)
                    if fitfreqs["glitchfit"]:
                        glitchpar = np.zeros((nmodels, 3))
                    for indd, ind in enumerate(np.flatnonzero(index)):
                        chi2_freq, warn, shapewarn, addpars = stats.chi2_astero(
                            obskey,
                            obs,
                            obsfreqmeta,
                            obsfreqdata,
                            obsintervals,
                            libitem,
                            ind,
                            fitfreqs,
                            warnings=warn,
                            shapewarn=shapewarn,
                            debug=debug,
                            verbose=verbose,
                        )
                        chi2[indd] += chi2_freq

                        if fitfreqs["dnufit_in_ratios"]:
                            dnusurf[indd] = addpars["dnusurf"]
                        if fitfreqs["glitchfit"]:
                            glitchpar[indd] = addpars["glitchparams"]

                # Bayesian weights (across tracks/isochrones). The separate
                # bayw/magw/IMFw sums are only reported in debug mode.
                logPDF = 0.0
                bayw = 0.0
                magw = 0.0
                IMFw = 0.0
                if bayweights is not None:
                    for weight in bayweights:
                        trackweight = util.inflog(libitem[weight][()])
                        logPDF += trackweight
                        bayw += trackweight

                    # Within a given track/isochrone; these are called dweights
                    assert dweight is not None
                    modelweight = util.inflog(libitem[dweight][index])
                    logPDF += modelweight
                    bayw += modelweight

                # Multiply by absolute magnitudes, if present
                for f in inputparams["magnitudes"]:
                    mags = inputparams["magnitudes"][f]["prior"]
                    magweight = util.inflog(mags(libitem[f][index]))
                    logPDF += magweight
                    magw += magweight

                # Multiply priors into the weight (each prior evaluated once)
                for prior in activepriors:
                    priorweight = util.inflog(getattr(priors, prior)(libitem, index))
                    logPDF += priorweight
                    IMFw += priorweight

                # Calculate likelihood from weights, priors and chi2
                # PDF = weights * np.exp(-0.5 * chi2)
                logPDF -= 0.5 * chi2
                finite = ~np.isinf(logPDF)
                if debug and verbose:
                    print(
                        "DEBUG: Mass with nonzero likelihood:",
                        libitem["massini"][index][finite],
                    )

                # Sum the number indexes and nonzero indexes
                noofind += len(logPDF)
                noofposind += np.count_nonzero(finite)
                if debug and verbose:
                    print(f"DEBUG: Index found: {trackpath}, {finite}")

                # Store statistical info
                if debug:
                    selectedmodels[trackpath] = stats.priorlogPDF(
                        index, logPDF, chi2, bayw, magw, IMFw
                    )
                else:
                    selectedmodels[trackpath] = stats.Trackstats(index, logPDF, chi2)
                if store_dnusurf:
                    dnusurfmodels[trackpath] = stats.Trackdnusurf(dnusurf)
                if store_glitch:
                    glitchmodels[trackpath] = stats.Trackglitchpar(
                        glitchpar[:, 0],
                        glitchpar[:, 1],
                        glitchpar[:, 2],
                    )
            # End loop over isochrones/tracks
        # End loop over metals

    print(
        f"Done! Computed the likelihood of {noofind!s} models,",
        f"found {noofposind!s} models with non-zero likelihood!\n",
    )
    if gridcut:
        print(
            f"(Note: The use of 'gridcut' skipped {noofskips[0]} out of {noofskips[1]} {gridtype})\n"
        )

    # Raise possible warnings
    if shapewarn == 1:
        print(
            "Warning: Found models with fewer frequencies than observed!",
            "These were set to zero likelihood!",
        )
        if "intpol" in gridfile:
            print(
                "This is probably due to the interpolation scheme. Lookup",
                "`interpolate_frequencies` for more details.",
            )
    if shapewarn == 2:
        print(
            "Warning: Models without frequencies overlapping with observed",
            "ignored due to interpolation of ratios being impossible.",
        )
    if shapewarn == 3:
        print(
            "Warning: Models ignored due to phase shift differences being",
            "unapplicable to models with mixed modes.",
        )
    if noofposind == 0:
        fio.no_models(starid, inputparams, "No models found")
        return False

    # Print a header to signal the start of the output section in the log
    print("\n*****************************************")
    print("** **")
    print("** Output and results from the fit **")
    print("** **")
    print("*****************************************\n")

    # Find and print highest likelihood model info
    maxPDF_path, maxPDF_ind = stats.get_highest_likelihood(
        Grid, selectedmodels, inputparams
    )
    stats.get_lowest_chi2(Grid, selectedmodels, inputparams)

    # Generate posteriors of ascii- and plotparams
    # --> Print posteriors to console and log
    # --> Generate corner plots
    # --> Generate Kiel diagrams
    print("\n\nComputing posterior distributions for the requested output parameters!")
    print("==> Summary statistics printed below ...\n")
    process_output.compute_posterior(
        starid=starid,
        selectedmodels=selectedmodels,
        Grid=Grid,
        inputparams=inputparams,
        outfilename=outfilename,
        gridtype=gridtype,
        debug=debug,
        developermode=developermode,
        validationmode=validationmode,
    )

    # Collect additional output for plotting and saving
    addstats = {}
    if store_dnusurf:
        addstats["dnusurf"] = dnusurfmodels
    if store_glitch:
        addstats["glitchparams"] = glitchmodels

    # Make frequency-related plots (freqplots may be absent from inputparams)
    freqplots = inputparams.get("freqplots")
    if fitfreqs["active"] and freqplots is not None and len(freqplots):
        plot_driver.plot_all_seismic(
            freqplots,
            Grid=Grid,
            fitfreqs=fitfreqs,
            obsfreqmeta=obsfreqmeta,
            obsfreqdata=obsfreqdata,
            obskey=obskey,
            obs=obs,
            obsintervals=obsintervals,
            selectedmodels=selectedmodels,
            path=maxPDF_path,
            ind=maxPDF_ind,
            plotfname=outfilename + "_{0}." + inputparams["plotfmt"],
            nameinplot=inputparams["nameinplot"],
            **addstats,
            debug=debug,
        )
    else:
        print(
            "Did not get any frequency file input, skipping ratios and echelle plots."
        )

    # Save dictionary with full statistics
    if optionaloutputs:
        pfname = outfilename + ".json"
        fio.save_selectedmodels(pfname, selectedmodels)
        print(f"Saved dictionary to {pfname}")

    # Print time of completion
    t1 = time.localtime()
    print(
        f"\nFinished on {time.strftime('%Y-%m-%d %H:%M:%S', t1)}",
        f"(runtime {time.mktime(t1) - time.mktime(t0)} s).\n",
    )
    return True


def _outside_gridcut(
    Grid: "h5py.File",
    gridtype: str,
    tracks_headerpath: str,
    gridcut: dict,
    name: str,
    noingrid: int,
) -> bool:
    """Return True if a track/isochrone lies outside any ``gridcut`` range.

    For track grids the value is read from the grid header at position
    ``noingrid``. For isochrone grids, only the age is evaluated, parsed from
    the group name after its first four characters (the ``"age="`` prefix).
    """
    for param in gridcut:
        if "tracks" in gridtype.lower():
            value = Grid[tracks_headerpath][param][noingrid]
        elif "isochrones" in gridtype.lower():
            # For isochrones, metallicity is already cut from the
            # metal list and lookup of age is simplest and fastest.
            # NOTE(review): for any other header parameter ``value`` is not
            # updated, so the previous parameter's value is compared -- confirm
            # that isochrone headers only produce "age" cuts.
            if param == "age":
                value = float(name[4:])
        # If value is outside cut limits, skip looking at the rest
        if not (value >= gridcut[param][0] and value <= gridcut[param][1]):
            return True
    return False


def _phase_mask(phase, modelphases: np.ndarray) -> np.ndarray:
    """Return a boolean mask of models in the requested evolutionary phase(s).

    ``phase`` is a single phase name or a tuple of names; each name is mapped
    to the integer code compared against ``modelphases``.
    """
    # Mapping of verbose input phases to internal numbers
    pmap = {
        "pre-ms": 1,
        "solar": 2,
        "rgb": 3,
        "flash": 4,
        "clump": 5,
        "agb": 6,
    }
    names = phase if isinstance(phase, tuple) else (phase,)
    return np.isin(modelphases, [pmap[p] for p in names])


def _lowest_l0_mask(libitem, index, obskey, obs, lower_tol, upper_tol) -> np.ndarray:
    """Return a mask of models whose l=0 mode matches the first observed mode.

    For every model selected by ``index``, the model l=0 frequency with the
    same radial order as ``obskey[1, 0]`` must satisfy
    ``obs[0, 0] - lower_tol <= f`` and ``f - obs[0, 0] <= upper_tol``.
    Models without an l=0 mode of that radial order are rejected (previously
    this relied on the truth value of an empty array, which recent NumPy
    versions reject).
    """
    mask = np.zeros(len(index), dtype=bool)
    for ind in np.flatnonzero(index):
        mod = su.transform_obj_array(libitem["osc"][ind])
        modkey = su.transform_obj_array(libitem["osckey"][ind])
        modkeyl0, modl0 = su.get_givenl(l=0, osc=mod, osckey=modkey)
        # As mod is ordered (stacked in increasing n and l), the first entry
        # with matching n is the one to compare
        same_n = modkeyl0[1, :] == obskey[1, 0]
        cl0 = modl0[0, same_n]
        if cl0.size == 0:
            continue
        cl0 = cl0[0]
        mask[ind] = (cl0 >= obs[0, 0] - lower_tol) and (
            cl0 - obs[0, 0] <= upper_tol
        )
    return mask