diff --git a/bin/posydon-setup-grid b/bin/posydon-setup-grid index f6978f52d7..31c049baaf 100755 --- a/bin/posydon-setup-grid +++ b/bin/posydon-setup-grid @@ -83,6 +83,8 @@ def find_inlist_from_scenario(source, gitcommit, system_type): else: Pwarn("git repository is already there, using that", "OverwriteWarning") + Pwarn("git repository is already there, using that", + "OverwriteWarning") inlists_dir = '{0}/.posydon_mesa_inlists'.format(os.environ['HOME']) branch = gitcommit.split('-')[0] @@ -112,6 +114,8 @@ def find_inlist_from_scenario(source, gitcommit, system_type): else: Pwarn("git repository is already there, using that", "OverwriteWarning") + Pwarn("git repository is already there, using that", + "OverwriteWarning") inlists_dir = '{0}/.user_mesa_inlists'.format(os.environ['HOME']) branch = gitcommit.split('-')[0] @@ -146,6 +150,7 @@ def find_inlist_from_scenario(source, gitcommit, system_type): stderr = subprocess.PIPE ) proc.wait() + proc.wait() # if this is looking at posydon defaults, all posydon defaults build from default common inlists if source == 'posydon': diff --git a/posydon/active_learning/psy_cris/run_params/run_psycris_sequence.py b/posydon/active_learning/psy_cris/run_params/run_psycris_sequence.py index 43e82a201f..80cb05b5e2 100644 --- a/posydon/active_learning/psy_cris/run_params/run_psycris_sequence.py +++ b/posydon/active_learning/psy_cris/run_params/run_psycris_sequence.py @@ -144,3 +144,6 @@ def main(): if __name__ == "__main__": main() + +if __name__ == "__main__": + main() diff --git a/posydon/binary_evol/MESA/step_mesa.py b/posydon/binary_evol/MESA/step_mesa.py index 1b324c3b84..5a5d924d07 100644 --- a/posydon/binary_evol/MESA/step_mesa.py +++ b/posydon/binary_evol/MESA/step_mesa.py @@ -280,10 +280,10 @@ def load_Interp(self, filename): # Check if interpolation files exist if not os.path.exists(filename): - data_download() + data_download() #TODO: specify dataset # Load interpolator - self._Interp = IFInterpolator() + self._Interp = 
IFInterpolator(load = True) self._Interp.load(filename=filename) def close(self): diff --git a/posydon/binary_evol/SN/profile_collapse.py b/posydon/binary_evol/SN/profile_collapse.py index 23abc1bdce..c71934a1e9 100644 --- a/posydon/binary_evol/SN/profile_collapse.py +++ b/posydon/binary_evol/SN/profile_collapse.py @@ -393,7 +393,6 @@ def do_core_collapse_BH(star, The mass of the disk radiated away in M_sun. 'BZ_jet_power_total' : float The total Blandford-Znajek jet power in erg/s. - # Additional keys that are not used in the current implementation: # 'BZ_jet_power_array' : np.array(BZ_jet_power_array), # Blandford-Znajek jet power at each shell collapse in erg/s @@ -645,6 +644,26 @@ def do_core_collapse_BH(star, BZ_jet_power_array.append(BZ_power) BZ_jet_power_total += BZ_power + # calculate the potential BZ jet power at this moment of the collapse + # We assume full efficiency for the magnetic flux and a BH spin + # dependence of a^2 for the BH spin efficiency. + # just an energy total per collapse step. + BZ_power = BZ_jet_power(M_dot=dm_disk, + eta_phi=1, + eta_a=a_BH**2) + BZ_jet_power_array.append(BZ_power) + BZ_jet_power_total += BZ_power + + # calculate the potential BZ jet power at this moment of the collapse + # We assume full efficiency for the magnetic flux and a BH spin + # dependence of a^2 for the BH spin efficiency. + # just an energy total per collapse step. 
+ BZ_power = BZ_jet_power(M_dot=dm_disk, + eta_phi=1, + eta_a=a_BH**2) + BZ_jet_power_array.append(BZ_power) + BZ_jet_power_total += BZ_power + # Append all quantities to the arrays J_accreted_array.append(J_BH) M_BH_array.append(M_BH) diff --git a/posydon/interpolation/IF_interpolation.py b/posydon/interpolation/IF_interpolation.py index c46e13c171..c4976d8145 100644 --- a/posydon/interpolation/IF_interpolation.py +++ b/posydon/interpolation/IF_interpolation.py @@ -196,9 +196,6 @@ class relies on the BaseIFInterpolator class to perform the interpolation find_constraints_to_apply, sanitize_interpolated_quantities, ) -from posydon.interpolation.data_scaling import DataScaler -from posydon.utils.posydonwarning import Pwarn -from posydon.visualization.plot_defaults import DEFAULT_LABELS # INITIAL-FINAL INTERPOLATOR @@ -273,13 +270,16 @@ def evaluate(self, binary, sanitization_verbose=False): """ ynums = {} ycats = {} - + # s = time.time() for interpolator in self.interpolators: ynum, ycat = interpolator.evaluate(binary, sanitization_verbose) ynums = {**ynums, **ynum} ycats = {**ycats, **ycat} + # e = time.time() + # print(f"Iterated over {len(self.interpolators)} interpolators in {e - s}") + return ynums, ycats diff --git a/posydon/interpolation/constraints.py b/posydon/interpolation/constraints.py index 45fbf75de0..0ee3a662b0 100644 --- a/posydon/interpolation/constraints.py +++ b/posydon/interpolation/constraints.py @@ -49,10 +49,14 @@ import numpy as np -from posydon.utils.common_functions import ( - orbital_separation_from_period, - stefan_boltzmann_law, -) +CLASSIFICATION_KEYS = [ + "S<*>_state", + "mt_hist", + "S<*>_MOD_SN_type", + "S<*>_MOD_CO_type" +] + +N_MODELS = 11 # how many super nova models are there? 
# toggle this flag to enable/disable constraints (used for debugging)
INTERPOLATION_CONSTRAINTS_ON = True


def mt_constraint(classes):
    """Force ``mt_hist`` to agree with the predicted interpolation class.

    Parameters
    ----------
    classes : dict
        Classified quantities; must contain ``"interpolation_class"``.
        Modified in place.

    Notes
    -----
    Only ``"initial_MT"`` and ``"no_MT"`` pin ``mt_hist``. The classes
    ``"stable_MT"``, ``"unstable_MT"`` and ``"stable_reverse_MT"`` impose no
    constraint.
    """
    interpolation_class = classes["interpolation_class"]

    if interpolation_class == "initial_MT":
        # This used to be a comparison (``==``), which silently did nothing.
        classes["mt_hist"] = "ini_RLO"
    elif interpolation_class == "no_MT":
        classes["mt_hist"] = "no_RLO"


# Constraint callables, keyed by the *template* key name; None = no constraint.
CLASS_CONSTRAINTS = {
    "S<*>_state": None,
    "mt_hist": mt_constraint,
    "S<*>_MOD_SN_type": None,
    "S<*>_MOD_CO_type": None
}

# NOTE(review): the previous code substituted the model index with
# ``key_name.replace("", ...)`` (empty search string), so the intended
# placeholder is not visible. It is assumed to be the "MOD" token that the
# surrounding guard tests for -- TODO confirm against the real key names.
_MODEL_PLACEHOLDER = "MOD"


def apply_class_constraint(key_name, classes, template=None):
    """Apply the constraint registered for a key, if the key is present.

    Parameters
    ----------
    key_name : str
        Concrete key to look for in ``classes``.
    classes : dict
        Classified quantities, modified in place by the constraint.
    template : str, optional
        Key under which the constraint is registered in
        ``CLASS_CONSTRAINTS``. Defaults to ``key_name``.
    """
    if key_name not in classes:
        return

    constraint = CLASS_CONSTRAINTS.get(key_name if template is None else template)
    # Entries registered as None (or not registered at all) are no-ops.
    if constraint is not None:
        constraint(classes)


def sanitize_classes(classes, keys=None, n_models=None):
    """Enforce the classification constraints on a dict of predicted classes.

    Parameters
    ----------
    classes : dict
        Classified quantities; must contain ``"interpolation_class"``.
        Modified in place.
    keys : iterable of str, optional
        Template keys to process. ``"<*>"`` is expanded over the star
        indices. Defaults to the module-level ``CLASSIFICATION_KEYS``.
    n_models : int, optional
        Number of model indices to expand for keys containing the model
        placeholder. Defaults to the module-level ``N_MODELS``.

    Raises
    ------
    ValueError
        If ``"interpolation_class"`` is missing from ``classes``.
    """
    assert isinstance(classes, dict)

    if "interpolation_class" not in classes:
        raise ValueError(
            "Interpolation class must be present as a classified quantity to enforce classification constraints!"
        )

    if keys is None:
        keys = CLASSIFICATION_KEYS
    if n_models is None:
        n_models = N_MODELS

    for key in keys:
        if "<*>" not in key:
            apply_class_constraint(key, classes, template=key)
            continue

        # NOTE(review): star indices 0 and 1 are kept from the original code;
        # keys elsewhere look 1-based ("S1_...") -- TODO confirm.
        for star in range(2):
            star_key = key.replace("<*>", f"{star}")

            if _MODEL_PLACEHOLDER not in star_key:
                apply_class_constraint(star_key, classes, template=key)
                continue

            # Build each model key from the per-star template, so that
            # substitutions do not compound across iterations.
            for model in range(n_models):
                model_key = star_key.replace(_MODEL_PLACEHOLDER, f"{model}")
                apply_class_constraint(model_key, classes, template=key)
@@ -68,27 +87,28 @@ def fit(self, x, method='none', lower=-1.0, upper=1.0): if method == 'min_max': assert upper > lower, "upper must be greater than lower" self.lower, self.upper = lower, upper - self.params = [x.min(axis=0), x.max(axis=0)] + self.params = [np.nanmin(x, axis=0), np.nanmax(x, axis=0)] elif method == 'log_min_max': assert upper > lower, "upper must be greater than lower" self.lower, self.upper = lower, upper - self.params = [np.log10(x.min(axis=0)), np.log10(x.max(axis=0))] + self.params = [self.log(np.nanmin(x, axis=0)), self.log(np.nanmax(x, axis=0))] + elif method == 'neg_log_min_max': assert upper > lower, "upper must be greater than lower" self.lower, self.upper = lower, upper - self.params = [np.log10((-x).min(axis=0)), - np.log10((-x).max(axis=0))] + self.params = [self.log(np.nanmin(-x, axis=0)), + self.log(np.nanmax(-x, axis=0))] elif method == 'max_abs': - self.params = [np.abs(x).max(axis=0)] + self.params = [np.nanmax(np.abs(x), axis=0)] elif method == 'log_max_abs': - self.params = [np.abs(np.log10(x)).max(axis=0)] - elif method == 'standarize': - self.params = [x.mean(axis=0), x.std(axis=0)] - elif method == 'log_standarize': + self.params = [np.nanmax(np.abs(self.log(x)), axis=0)] + elif method == 'standardize': + self.params = [np.nanmean(x, axis=0), np.nanstd(x, axis=0)] + elif method == 'log_standardize': # log will be computed in transform again - self.params = [np.log10(x).mean(axis=0), np.log10(x).std(axis=0)] - elif method == 'neg_log_standarize': # log(-x) - self.params = [np.log10(-x).mean(axis=0), np.log10(-x).std(axis=0)] + self.params = [np.nanmean(self.log(x), axis=0), np.nanstd(self.log(x), axis=0)] + elif method == 'neg_log_standardize': # log(-x) + self.params = [np.nanmean(self.log(-x), axis=0), np.nanstd(self.log(-x), axis=0)] elif method == 'log': self.params = [] elif method == 'none': # no transformation @@ -124,26 +144,26 @@ def transform(self, x): x_t = ((x - self.params[0]) / (self.params[1] - self.params[0]) * 
(self.upper - self.lower) + self.lower) elif self.method == 'log_min_max': - x_t = ((np.log10(x) - self.params[0]) + x_t = ((self.log(x) - self.params[0]) / (self.params[1] - self.params[0]) * (self.upper - self.lower) + self.lower) elif self.method == 'neg_log_min_max': - x_t = ((np.log10(-x) - self.params[0]) + x_t = ((self.log(-x) - self.params[0]) / (self.params[1] - self.params[0]) * (self.upper - self.lower) + self.lower) elif self.method == 'max_abs': x_t = x / self.params[0] elif self.method == 'log_max_abs': - x_t = np.log10(x) / self.params[0] - elif self.method == 'standarize': + x_t = self.log(x) / self.params[0] + elif self.method == 'standardize': x_t = (x - self.params[0]) / self.params[1] - elif self.method == 'log_standarize': + elif self.method == 'log_standardize': # log will be computed in transform again - x_t = (np.log10(x) - self.params[0]) / self.params[1] - elif self.method == 'neg_log_standarize': - x_t = (np.log10(-x) - self.params[0]) / self.params[1] + x_t = (self.log(x) - self.params[0]) / self.params[1] + elif self.method == 'neg_log_standardize': + x_t = (self.log(-x) - self.params[0]) / self.params[1] elif self.method == 'log': - x_t = np.log10(x) + x_t = self.log(x) else: # no transformation x_t = x @@ -201,24 +221,38 @@ def inv_transform(self, x_t): / (self.upper - self.lower) * (self.params[1] - self.params[0]) + self.params[0]) elif self.method == 'log_min_max': - x = 10 ** ((x_t - self.lower) / (self.upper - self.lower) + x = self.unlog((x_t - self.lower) / (self.upper - self.lower) * (self.params[1] - self.params[0]) + self.params[0]) elif self.method == 'neg_log_min_max': - x = -10 ** ((x_t - self.lower) / (self.upper - self.lower) + x = -self.unlog((x_t - self.lower) / (self.upper - self.lower) * (self.params[1] - self.params[0]) + self.params[0]) elif self.method == 'max_abs': x = x_t * self.params[0] elif self.method == 'log_max_abs': - x = 10 ** (x_t * self.params[0]) + x = self.unlog(x_t * self.params[0]) elif 
def log(self, x):
    """Return ``log10(x + eps)``, failing loudly when it is undefined.

    Parameters
    ----------
    x : array_like
        Values to transform.

    Returns
    -------
    numpy.ndarray or float
        ``np.log10(x + eps)``.

    Raises
    ------
    ValueError
        If NumPy emits a ``RuntimeWarning`` that the active warnings filter
        escalates to an exception (e.g. for non-positive input). Previously
        this case printed debug output and returned ``None``, which only
        surfaced later as an unrelated ``TypeError`` in the caller.
    """
    try:
        return np.log10(x + eps)
    except RuntimeWarning as err:
        # Collect the diagnostics without tripping further warnings.
        with np.errstate(invalid="ignore"):
            values = np.asarray(x, dtype=float)
            has_inf = bool(np.isinf(values).any())
            has_nan = bool(np.isnan(values).any())
            has_negative = bool((values < 0).any())
        raise ValueError(
            f"cannot take log10 for scaling method {self.method!r}: "
            f"has_inf={has_inf}, has_nan={has_nan}, "
            f"has_negative={has_negative}"
        ) from err


def unlog(self, x):
    """Invert :meth:`log`: return ``10 ** x - eps``."""
    return (10 ** x) - eps
Familiarity with the over all system, which can + be gained by referencing section 3 of 2411.02376, is required to understand the documentation + """ + + def __init__(self, grids = None, in_keys = None, out_keys = None, max_k = None, load = False): + """ Class constructor + + Parameters + ---------- + grids : list of PSyGrid + Contains both the training grid and validation grid in the first and second postions, respectively + in_keys: list of strings + Contains the names of the parameters that define the input space + out_keys: dict + Keys correspond to classifiers to be trained, and values are a list of strings specifying which parameters + are to be interpolated using the respective classifier + max_k: int + The maximum number of k that is considered when optimizing k for each classifier + """ + + if type(grids) != list and not load: + sys.exit("Please provide a list of PSyGrids containing both a training and validation grid to train the interpolator") + elif load: + print("Constructed in Loading Mode") + else: + + self.in_keys = in_keys + + self.out_key_dict = out_keys + self.continuous_out_keys = sum(list(out_keys.values()), []) # keys to be interpolated which correspond to numerical quantities + self.discrete_out_keys = list(out_keys.keys()) # keys to be interpolated which correspond to discrete quantities + self.constraints = find_constraints_to_apply(self.continuous_out_keys) + + # ============= checks ============= + if "interpolation_class" not in self.discrete_out_keys: + sys.exit("The key \"interpolation_class\" needs to be provided as one of the interpolation keys") + + self.max_k = max_k + + self.training_grid = self.preprocess_grid(grids[0], training_grid = True) + self.validation_grid = self.preprocess_grid(grids[1]) + + self.triangulate(self.training_grid) + # =============== usage statistics variables ============ + self.outside_convex_hull = dict(zip(self.discrete_out_keys, [0] * len(self.discrete_out_keys))) + self.inside_convex_hull = 
def stats(self, _print = False):
    """Report how often interpolation fell outside the convex hull.

    Parameters
    ----------
    _print : bool
        If True, also print the mean fraction over all schemes.

    Returns
    -------
    dict
        Maps each key in ``self.discrete_out_keys`` to the fraction (0-1) of
        samples counted in ``self.outside_convex_hull``, relative to all
        counted samples. A scheme with no counted samples reports 0.0
        instead of raising ``ZeroDivisionError``.
    """
    fractions = []

    for key in self.discrete_out_keys:
        outside = self.outside_convex_hull[key]
        total = outside + self.inside_convex_hull[key]
        fractions.append(outside / total if total else 0.0)

    if _print and fractions:
        print(f"Total of {sum(fractions) / len(fractions):.2f} outside of hull")

    return dict(zip(self.discrete_out_keys, fractions))


def train(self):
    """Fit one classifier and one normalization search per discrete key.

    Fills ``self.classifiers`` via ``find_hyperparameters`` and then
    ``self.out_scalers`` via ``optimize_normalization``.
    ``self.is_training`` is True for the duration and is reset even if one
    of the searches raises.
    """
    self.is_training = True
    try:
        # Finding classification hyperparameters
        self.classifiers = {
            key: self.find_hyperparameters(key) for key in self.discrete_out_keys
        }
        # Finding interpolation normalization schemes
        self.out_scalers = {
            key: self.optimize_normalization(key) for key in self.discrete_out_keys
        }
    finally:
        self.is_training = False


def interpolate(self, iv, klass, sn_model):
    """Interpolate final values for one initial value and its classes.

    The interpolation is done twice, once within the predicted
    ``"interpolation_class"`` and once within the predicted ``sn_model``
    class. The two results are concatenated in that order.

    Parameters
    ----------
    iv : np.ndarray
        One initial value, in the coordinates the triangulations were built
        in.
    klass : sequence of str
        Predicted class labels, ordered like ``self.discrete_out_keys``.
    sn_model : str
        Key of the supernova classification scheme to use.

    Returns
    -------
    interpolated : list
        Interpolated quantities for both schemes, concatenated.
    meta_data : dict
        ``"weights"`` and ``"ics"`` (per scheme, only when a simplex was
        found), ``"ic"`` (the input) and ``"interpolated"``.

    Notes
    -----
    If the class has no triangulation (marker ``"1NN"``) or ``iv`` lies
    outside every simplex, the nearest-neighbour values are used.

    NaN handling per output dimension, based on the simplex vertices:
    exactly one NaN vertex -> interpolate from the remaining vertices with
    renormalized weights; exactly two NaN vertices -> a fair coin decides
    between NaN and the renormalized interpolation; more -> NaN.
    """
    interpolated = []
    ics = {}
    weights = {}

    schemes = ["interpolation_class", sn_model]
    labels = [klass[self.discrete_out_keys.index(scheme)] for scheme in schemes]

    # interpolating based on mass transfer type and supernova outcome separately
    for key, label in zip(schemes, labels):

        triangulation = self.training_grid["triangulations"][key][label]

        simplex = -1 if triangulation == "1NN" else triangulation.find_simplex(iv)

        if simplex == -1:
            interpolated.extend(self.get_nearest_neighbor(iv, key))
            self.outside_convex_hull[key] += 1
            continue

        self.inside_convex_hull[key] += 1

        vertices = triangulation.simplices[simplex]
        vertex_coords = triangulation.points[vertices]
        ics[key] = vertex_coords

        # The triangulation indexes into the per-class subset of the grid.
        class_inds = self.training_grid["class_inds"][key][label]
        final_values = np.array(
            self.training_grid["final_values"][key][class_inds][vertices].tolist()
        )

        # ============= handling cases where nans exist in final values ==========
        nans = np.isnan(final_values)
        num_nans = nans.sum(axis = 0)
        # dimensions that should not be nan despite one nan neighbour
        recoverable = np.where((num_nans > 0) & (num_nans < 2))[0]
        # dimensions with a 2-vs-2 tie: a coin flip decides
        ties = np.where(num_nans == 2)[0]

        if self.debug_mode:
            print(final_values)

        # np.append returns a new array; the original code discarded it, so
        # ties were never recovered. Keep the winners explicitly.
        keep = np.random.rand(ties.size) > 0.5
        recoverable = np.concatenate((recoverable, ties[keep]))

        # =============== performing general interpolation ==================
        # (The unused ``label_dict`` lookup into ``self.out_scalers`` was
        # removed; output normalization is currently disabled.)
        barycentric_weights = self.compute_barycentric_coordinates(
            iv, vertex_coords
        )[..., np.newaxis]

        weights[key] = barycentric_weights

        i_values = np.sum(final_values * barycentric_weights, axis = 0)

        # Repair dimensions that became nan only because of an acceptable
        # number of nan neighbours.
        for dim in recoverable:
            valid = np.where(~nans[:, dim])
            valid_weights = barycentric_weights[valid]
            valid_weights = valid_weights / valid_weights.sum()

            i_values[dim] = (final_values[:, dim][valid] * valid_weights[:, 0]).sum()

        interpolated.extend(i_values)

    meta_data = {
        "weights": weights,
        "ics": ics,
        "ic": iv,
        "interpolated": interpolated
    }

    return interpolated, meta_data
The first is the mass transfer type and the second + is the compact object type which is used for the supernova + n: list of dicts + each dict containing meta data information about the interpolation such as + neighbors used and distances found + + """ + + if self.classifiers is None: + sys.exit("Please find classifier hyperparameters before using interpolator") + + interpolation_class_ind = self.discrete_out_keys.index("interpolation_class") + + classes = np.array([ + cl["classifier"].predict(cl["transform"].normalize(initial_values)) + for cl in self.classifiers.values()]).T + + interpolated_values = [] + n = [] + + for iv, klass in zip(initial_values, classes): + if klass[interpolation_class_ind] == "initial_MT": + continue + + interpolated, meta_data = self.interpolate(iv, klass, sn_model) + + interpolated = self.apply_continuous_constraints(interpolated, sn_model) + interpolated_values.append(interpolated) + n.append(meta_data) + + interpolated_values = np.array(interpolated_values) + + classes = np.array(classes) + + return interpolated_values, classes, n + + def find_hyperparameters(self, klass): + """ finds optimal k for a specified classifier + + Parameters + ---------- + klass: string + classifier to consider + + Return Values + ------------- + dict: dict + contains classifier information and more + """ + + input_matrix = [] + """ + matrix that considers different number of neighbors with different + in_scaling options + """ + + for k in range(1, self.max_k): + row = [] + for opt in IN_SCALING_OPTIONS: + row.append( + [k, opt] + ) + input_matrix.append(row) + + kwargs = { + "input_matrix": input_matrix, + "self": self, + "klass": klass + } + + def kwargs_fnc(**kwargs): + """ helper function handling passing of arguments for functions given + to the preprocessing modules + """ + + kwargs = { + "self": kwargs["kwargs"]["self"], + "k": kwargs["item"][0], + "scaling": kwargs["item"][1] + } + + return kwargs + + def eval_fnc(self, k, scaling): + """ the 
preprocessing module evaluates every point in input_matrix (specified above) + which considers different input_scalings and numbers of neighbors. This function gives + a score for each value of k paired with a normalization + + Parameters + ---------- + k: int + number of neighbors used + scaling: string + specifies normalization + + Return Values + ------------- + bacc: float + an accuracy score + stats: statistics used + """ + + validation_classifier = KNeighborsClassifier(n_neighbors = k, weights = "distance") + + training_initial_values = self.training_grid["initial_values"] + + transform = Transformer(training_initial_values, scaling) + training_initial_values = transform.normalize(training_initial_values) + + validation_classifier.fit( + training_initial_values, + self.training_grid["final_classes"][klass] + ) + + validation_initial_values = self.validation_grid["initial_values"] + validation_initial_values = transform.normalize(validation_initial_values) + predicted_classes = validation_classifier.predict(validation_initial_values) + + bacc = balanced_accuracy_score( + self.validation_grid["final_classes"][klass], + predicted_classes + ) + + return bacc, transform + + eval_matrix, stat_matrix = find_normalization_evaluation_matrix(eval_fnc, kwargs_fnc, kwargs) # getting matrix + + k_star = list(np.unravel_index(eval_matrix.argmax(), eval_matrix.shape)) # optimal number of neighbors + + classifier = KNeighborsClassifier(n_neighbors = k_star[0] + 1, weights = "distance") # defining classifier + + training_initial_values = self.training_grid["initial_values"] + + scaling = IN_SCALING_OPTIONS[k_star[1]] + + transform = Transformer(training_initial_values, scaling) + training_initial_values = transform.normalize(training_initial_values) # taking care of normalization + + classifier.fit( + training_initial_values, + self.training_grid["final_classes"][klass] + ) # training classifier + + return { + "classifier": classifier, + "transform": stat_matrix[*k_star], + 
"log": "log" in IN_SCALING_OPTIONS[k_star[1]], + "k_star": k_star, + "eval_matrix": eval_matrix + } + + def optimize_normalization(self, key): + """ method to find optimal normalization per class + + Parameters + ---------- + key: string + specifies parameter + + Return Values + ------------- + dict: dict + interpolator information and more + """ + + input_matrix = [] + + labels = np.unique(self.training_grid["final_classes"][key]) + labels = np.delete(labels, np.where(labels == "initial_MT")[0]) + + for label in labels: + row = [] + for opt in OUT_SCALING_OPTIONS: + row.append( + [label, opt] + ) + input_matrix.append(row) + + kwargs = { + "input_matrix": input_matrix, + "self": self, + "key": key + } + + def kwargs_fnc(**kwargs): + """ helper function handling passing of arguments for functions given + to the preprocessing modules + """ + + kwargs = { + "self": kwargs["kwargs"]["self"], + "key": kwargs["kwargs"]["key"], + "klass": kwargs["item"][0], + "scaling": kwargs["item"][1] + } + + return kwargs + + def eval_fnc(self, key, klass, scaling): + """ the preprocessing module evaluates every point in input_matrix (specified above) + which considers different classes and output scalings. 
This function gives + a score for each value of class label paired with a normalization + + Parameters + ---------- + key: string + parameter considered + klass: string + class label + scaling: string + specifies normalization + + Return Values + ------------- + errors: float + an accuracy score + stats: statistics used + """ + self.training = True + self.scaling = scaling + + klass_inds = np.where(self.validation_grid["final_classes"][key] == klass)[0] + + training_final_values = self.training_grid["final_values"][key][self.training_grid["class_inds"][key][klass]] + + self._transform = Transformer(training_final_values, scaling) + + interpolated, classes, _ = self.evaluate(self.validation_grid["initial_values"][klass_inds]) + + classes = classes[np.where(classes[:, 0] != "initial_MT")[0]] + predicted_klass_inds = np.where((classes[:, 0] == klass) | (classes[:, 1] == klass))[0] + + # needs to be fixed to include any arbitrary SN model but this will do for now + ground_truth = np.concatenate( + [self.validation_grid["final_values"]["interpolation_class"], self.validation_grid["final_values"]["S1_SN_MODEL_v2_01_SN_type"]], axis = 1 + ) + + # some interpolated and ground truth values will be NaN, e.g., for keys describing CE phenomnea are NaN if no CE is present, need to filter these NaNs out + nans_mask = ~(np.isnan(interpolated[predicted_klass_inds]) + np.isnan(ground_truth[klass_inds][predicted_klass_inds])) + + errors = np.abs( + (interpolated[predicted_klass_inds][nans_mask] - ground_truth[klass_inds][predicted_klass_inds][nans_mask]) / + (ground_truth[klass_inds][predicted_klass_inds][nans_mask] + eps) + ) + + self.training = False + + error_mean = errors.mean() if len(errors) > 0 else np.inf + return error_mean, Transformer(training_final_values, scaling) + + eval_matrix, stat_matrix = find_normalization_evaluation_matrix(eval_fnc, kwargs_fnc, kwargs) # finding normalization + + # opt = tuple(np.unravel_index(eval_matrix.argmin(axis = 0), eval_matrix.shape)) + 
opt = [eval_matrix.argmin(axis = 0), np.arange(eval_matrix.shape[1])] + + return { + "transform": stat_matrix[opt[0], opt[1]], + "eval_matrix": eval_matrix, + "label_dict": dict(zip(labels, np.arange(len(labels)))) + } + + # =================== helper methods below =========================== + + def preprocess_grid(self, grid, training_grid = False): + """ method that takes PSyGrid object and processes it into nice + numpy arrays and dictionaries + + Parameters + --------- + grid: PSyGrid + grid + training_grid: bool + specifies whether this is a training grid (regularly sampled) + + Return Values + ------------- + dict + PSyGrid processed into dict such that it contains only information + important to IF interpolation + """ + + final_values = np.array(grid.final_values[self.continuous_out_keys].tolist()) + + valid_inds = np.where( + (grid.final_values["interpolation_class"] != "not_converged") & + (grid.final_values["interpolation_class"] != "ignored_no_RLO") & + (grid.final_values["interpolation_class"] != "ignored_no_binary_history") + )[0] + + initial_values = np.array(grid.initial_values[self.in_keys][valid_inds].tolist()) + # determining if should interp in q + if training_grid: + self.interp_in_q = False + + initial_values = np.log10(initial_values + eps) + + if self.interp_in_q: + initial_values[:, 1] = (10**initial_values[:, 1] - eps) / (10**initial_values[:, 0] - eps) + + if training_grid: + self.iv_min = initial_values.min(axis = 0, keepdims = True) + self.iv_max = initial_values.max(axis = 0, keepdims = True) + + class_inds = {} + + for key in self.discrete_out_keys: + class_labels = np.unique(grid.final_values[valid_inds][key]) + class_inds[key] = dict(zip( + class_labels, + [np.where(grid.final_values[valid_inds][key] == label)[0] for label in class_labels] + )) + + return { + "initial_values": 10**initial_values, + "final_values": dict(zip(self.out_key_dict.keys(), [np.array(grid.final_values[valid_inds][keys].tolist()) for keys in 
def triangulate(self, grid_dict):
    """Build one Delaunay triangulation per class and store them on the grid.

    Parameters
    ----------
    grid_dict : dict
        Processed grid with ``"final_classes"``, ``"class_inds"`` and
        ``"initial_values"``. A ``"triangulations"`` entry is added in
        place: ``{label_name: {class_label: Delaunay or "1NN"}}``.

    Notes
    -----
    The class ``"initial_MT"`` is skipped. A class falls back to the marker
    ``"1NN"`` when it has fewer than 5 samples or when the triangulation
    cannot be constructed.
    """
    triangulations = {}

    for label_name in self.discrete_out_keys:
        classes = np.unique(grid_dict["final_classes"][label_name]).tolist()
        if "initial_MT" in classes:
            classes.remove("initial_MT")

        class_triangulations = {}

        for klass in classes:
            class_inds = grid_dict["class_inds"][label_name][klass]

            if class_inds.shape[0] < 5:
                print(f"too few training samples for {klass}")
                class_triangulations[klass] = "1NN"
                continue

            try:
                class_triangulations[klass] = Delaunay(grid_dict["initial_values"][class_inds])
            except Exception:
                # Degenerate geometry (e.g. coplanar points) makes Qhull
                # fail. Unlike a bare ``except``, this still lets
                # KeyboardInterrupt and SystemExit propagate.
                print(f"Geometry wrong for {klass}, using 1NN")
                class_triangulations[klass] = "1NN"

        triangulations[label_name] = class_triangulations

    grid_dict["triangulations"] = triangulations
weights = np.array(weights) / sum(weights) + + return weights + + def get_nearest_neighbor(self, iv, key): + """ finds the nearest neighbor in the training grid + + Parameters + ---------- + iv: np.ndarray + contains initial stellar masses and orbital period + key: string + specifies parameter + + Return Values + ------------- + np.ndarray: output parameters of nearest neighbors + """ + + dists = np.sqrt(np.square(self.training_grid["initial_values"] - iv).sum(axis = 1)) + sorted_inds = dists.argsort() + + return np.array(self.training_grid["final_values"][key][sorted_inds[0]].tolist()) + + def apply_continuous_constraints(self, interpolated, sn_model): + """ method that applies constraints to our outputs + + Parameters + ---------- + interpolated: np.ndarray + interpolated values + sn_model: string + specifies super nova model used + + Return Values + ------------- + np.ndarray: contains interpolated values with constraints applied + + """ + keys = self.out_key_dict["interpolation_class"] + self.out_key_dict[sn_model] + + sanitized = sanitize_interpolated_quantities( + dict(zip(keys, interpolated)), + self.constraints, verbose=False + ) + return np.array([sanitized[key] for key in keys]) + + def save(self, filename): + """ + Saves the IFInterpolator instance to a pickle file. + + Parameters + ---------- + filename : str + Path or filename where the object should be saved (e.g., 'interpolator.pkl'). + """ + try: + with open(filename, 'wb') as f: + pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Successfully saved interpolator to {filename}") + except Exception as e: + print(f"Error saving interpolator: {e}") + + def load(self, filename): + """ + Loads an IFInterpolator instance from a pickle file. 
class Transformer:
    """Scale data into a normalized space and back again.

    The shift and scale are computed once from the data passed to the
    constructor; ``normalize`` and ``unnormalize`` then apply them.

    Attributes set by the constructor: ``logged`` (boolean mask of the
    columns that contained a negative value), ``log`` (whether the scaling
    works on ``log10`` of the data), ``shift`` and ``scale``.
    """

    def __init__(self, data, scaling):
        """Compute the shift and scale for ``scaling`` from ``data``.

        If a dimension contains negative values we assume that it is in log
        space and unlog it. This is an assumption that we know doesn't hold
        since things like rates can be negative, but it simplifies the
        preprocessing code for now.

        Parameters
        ----------
        data : np.ndarray
            Two-dimensional array, one column per dimension. It is copied,
            the caller's array is not modified.
        scaling : str
            One of ``"none"``, ``"min-max"``, ``"standard"``,
            ``"log_min-max"`` or ``"log_standard"``.

        Raises
        ------
        KeyError
            If ``scaling`` is not one of the options above.

        """
        data = data.copy()
        self.logged = (data < 0.0).any(axis = 0)
        data[:, self.logged] = 10**data[:, self.logged]

        def min_max(values):
            lowest = values.min(axis = 0)
            return [lowest, values.max(axis = 0) - lowest]

        def standard(values):
            return [values.mean(axis = 0), values.std(axis = 0)]

        # Keyed by name so the table does not depend on the ordering of a
        # separate list of option names.
        compute = {
            "none": lambda values: [0, 1],
            "min-max": min_max,
            "standard": standard,
            "log_min-max": lambda values: min_max(np.log10(values + eps)),
            "log_standard": lambda values: standard(np.log10(values + eps)),
        }

        self.log = "log" in scaling
        self.shift, self.scale = compute[scaling](data)

    def normalize(self, data):
        """Map ``data`` into the normalized space.

        Columns flagged in ``self.logged`` are exponentiated first (only
        when the number of columns matches the mask), ``log10(data + eps)``
        is taken for the log scalings, then the shift and scale are applied.

        Parameters
        ----------
        data : np.ndarray
            Two-dimensional array; it is copied, not modified.

        Returns
        -------
        np.ndarray
            The normalized data.

        """
        data = data.copy()

        if self.logged.any() and data.shape[1] == self.logged.shape[0]:
            data[:, self.logged] = 10**data[:, self.logged]

        if self.log:
            data = np.log10(data + eps)

        return (data - self.shift) / (self.scale + eps)

    def unnormalize(self, data):
        """Invert ``normalize``.

        The steps of ``normalize`` are undone in reverse order: scale and
        shift, then ``10**x - eps`` for the log scalings, then ``log10`` on
        the columns flagged in ``self.logged``.

        Parameters
        ----------
        data : np.ndarray
            Normalized data; it is copied, not modified.

        Returns
        -------
        np.ndarray
            Data in the original space.

        """
        data = data.copy()
        data = (data * (self.scale + eps)) + self.shift

        if self.log:
            # Must be applied to every column, including the ones in
            # self.logged: normalize() took log10 of all of them after
            # exponentiating the logged ones.
            data = 10**data - eps

        if self.logged.any() and data.shape[1] == self.logged.shape[0]:
            # clamp so rounding below zero cannot produce NaN / -inf
            data[:, self.logged] = np.log10(
                np.maximum(data[:, self.logged], eps)
            )

        return data
class EvaluateIFInterpolator:
    """ Class that is helpful for evaluating interpolation performance

    On construction the interpolator is evaluated on the test grid and the
    results are stored on the instance: ``errs`` (relative and absolute
    errors plus the indices they belong to), ``matrices`` (one confusion
    matrix per classifier), ``c`` (predicted classes), ``ic`` (true
    interpolation classes) and ``cfv`` (final values of converged tracks).
    """

    def __init__(self, interpolator, test_grid, sn_method = "S1_SN_MODEL_v2_01_SN_type"):
        """ Initialize the EvaluateIFInterpolator class

        Parameters
        ----------
        interpolator : IFInterpolator
            Interpolator that user wants to test
        test_grid : PSyGrid
            Grid object containing testing tracks
        sn_method : str
            Key of the final values used as truth for every classifier
            column after the first one

        """
        self.interpolator = interpolator
        self.test_grid = test_grid

        # assuming that in_keys are the same for all interpolators
        self.in_keys = interpolator.in_keys
        self.out_keys = interpolator.continuous_out_keys
        self.sn_method = sn_method

        self.__compute_errs()

    def __compute_errs(self):
        """ Method that computes both interpolation and classification errors """

        iv = np.array(self.test_grid.initial_values[self.in_keys].tolist(),
                      dtype=float)  # initial values
        fv = np.array(self.test_grid.final_values[self.out_keys].tolist(),
                      dtype=float)  # final values
        ic = self.test_grid.final_values["interpolation_class"]

        interpolated, predicted, _ = self.interpolator.evaluate(iv)

        # errors are only computed for tracks not predicted as initial_MT
        ivalid_inds = np.where(predicted[:, 0] != "initial_MT")[0]
        fv = fv[ivalid_inds]

        self.errs = {}

        with np.errstate(divide='ignore', invalid='ignore'):
            self.errs["relative"] = np.abs((fv - interpolated) / (fv + eps))

        self.errs["absolute"] = np.abs(fv - interpolated)
        self.errs["valid_inds"] = ivalid_inds

        cvalid_inds = np.where(
            self.test_grid.final_values["interpolation_class"] != "not_converged"
        )[0]

        self.cfv = self.test_grid.final_values[cvalid_inds]

        _, predicted, _ = self.interpolator.evaluate(iv[cvalid_inds])  # classifying

        # computing confusion matrices, one per column of predicted classes
        self.matrices = {}

        for column, value in enumerate(predicted.T):

            key = "interpolation_class" if column == 0 else self.sn_method
            labels = self.cfv[key]
            classes = self.__find_labels(key)

            matrix = {}

            # catch cases where nothing is classified, e.g. S2_MODELXX_SN_type
            # when S2 is a compact object
            if len(classes) == 1 and classes[0] == 'None':
                matrix['None'] = 1.
            else:
                for _class in classes:
                    class_inds = np.where(labels == _class)[0]
                    pred_classes, counts = np.unique(value[class_inds],
                                                     return_counts = True)

                    # fraction of this true class predicted as each class
                    row = {name: 0 for name in classes}
                    for pred_class, count in zip(pred_classes, counts):
                        row[pred_class] = count / len(class_inds)

                    matrix[_class] = row

            self.matrices[key] = matrix

        # saving classes
        self.c = predicted
        self.ic = ic

    def __format(self, s, title = False):
        """ Method that formats keys for plots

        Parameters
        ----------
        s : str
            string to be formatted
        title : bool
            currently unused

        Returns
        -------
        str with underscores replaced by spaces, in title case

        """
        return s.replace("_", " ").title()

    def __find_labels(self, key):
        """ Method that finds labels in classifier

        Parameters
        ----------
        key : str
            name of the classifier

        Returns
        -------
        array of the unique class labels found in the training grid

        """
        return np.unique(self.interpolator.training_grid["final_classes"][key])

    def __clean_errs(self, errs):
        """ Method that drops every row containing a NaN or an infinity

        Parameters
        ----------
        errs : np.ndarray
            two-dimensional array of errors

        Returns
        -------
        np.ndarray with only the fully finite rows

        """
        has_nan = np.isnan(errs).any(axis = 1)
        has_inf = np.isinf(errs).any(axis = 1)

        return errs[(~has_nan) & (~has_inf)]

    def violin_plots(self, err_type = "relative", keys = None,
                     save_path = None, close_fig = False):
        """ Method that plots distribution of specified error for given keys and
        optionally saves it.

        Parameters
        ----------
        err_type : str
            Either relative or absolute, default is relative
        keys : list
            A list of keys for which the errors will be shown, by default is all of them
        save_path: str
            The path where the figure should be saved to
        close_fig: bool
            Flag whether figure should be closed

        """
        if keys is None:
            keys = self.out_keys

        k_inds = [self.out_keys.index(key) for key in keys]
        dirty_errs = self.errs[err_type][:, k_inds]

        # true interpolation class of every track that has an error row
        valid_classes = self.ic[self.errs["valid_inds"]]

        def log_errs(rows):
            return np.log10(self.__clean_errs(rows) + 1.0e-16)

        errs = log_errs(dirty_errs)
        stable_errs = log_errs(dirty_errs[np.where(valid_classes == "stable_MT")])
        no_errs = log_errs(dirty_errs[np.where(valid_classes == "no_MT")])
        unstable_errs = log_errs(dirty_errs[np.where(valid_classes == "unstable_MT")])
        #TODO: add interpolation class "stable_reverse_MT"

        plt.rcParams.update({"font.size": 32, "font.family": "stixgeneral"})

        fig, axs = plt.subplots(1, 1,
                                figsize = (24, 10),
                                tight_layout = True)

        rel_plot = axs.violinplot(errs, showmedians = True, points = 1000)
        stable_plot = axs.violinplot(stable_errs, showmedians = True, points = 1000)
        no_plot = axs.violinplot(no_errs, showmedians = True, points = 1000)
        unstable_plot = axs.violinplot(unstable_errs, showmedians = True, points = 1000)

        axs.set_title(f"Distribution of {err_type.capitalize()} Errors")
        axs.set_xticks(np.arange(1, len(keys) + 1),
                       labels = [
                           f"{self.__format(ec)} ({(med * 100):.2f}%)"
                           for ec, med in zip(keys, np.nanmedian(dirty_errs, axis = 0))
                       ], rotation = 20)
        axs.set_ylim(-4, 2)
        axs.set_ylabel("Errors in Log 10 Scale")
        axs.grid(axis = "y")

        def halve_paths(field, color, right = True):
            # keep only the half of each path on one side of its mean x
            for path in field.get_paths():
                m = np.mean(path.vertices[:, 0])

                first = m if right else -np.inf
                second = np.inf if right else m

                path.vertices[:, 0] = np.clip(path.vertices[:, 0], first, second)
                field.set_edgecolor(color)

        def customize_violinplot(plot, color, outlined = False, right = True):
            halve_paths(plot["cmins"], color, right = right)
            halve_paths(plot["cmaxes"], color, right = right)
            halve_paths(plot["cmedians"], color, right = right)

            for pc in plot["bodies"]:
                halve_paths(pc, color, right = right)

                pc.set_facecolor(color if not outlined else "None")
                pc.set_edgecolor(color)
                pc.set_linewidth(4)
                pc.set_alpha(0.75)

        customize_violinplot(rel_plot, "coral")
        customize_violinplot(stable_plot, "#1e90ff", True, False)
        customize_violinplot(no_plot, "crimson", True, False)
        customize_violinplot(unstable_plot, "olive", True, False)

        axs.legend(
            [rel_plot["bodies"][0], stable_plot["bodies"][0],
             no_plot["bodies"][0], unstable_plot["bodies"][0]],
            [f"{err_type.capitalize()} Error", "Stable MT Error",
             "No MT Error", "Unstable MT Error"],
            bbox_to_anchor = (0, 1.02, 1, 0.2),
            loc = "lower left",
            mode = "expand",
            ncol = 4
        )

        # save before showing: after show() returns the figure may already
        # have been released by the backend
        if save_path is not None:
            fig.savefig(save_path)

        plt.show()

        # close figure
        if close_fig:
            plt.close(fig)

    def confusion_matrix(self, key, params = None, save_path = None,
                         close_fig = False):
        """ Method that plots confusion matrices to evaluate classification

        Parameters
        ----------
        key : str
            The key for the classifier of interest
        params : dict
            Extra params for the plot: figsize (tuple), x_axis (list),
            y_axis (list), title (str)
        save_path : str
            The path where the figure should be saved to
        close_fig: bool
            Flag whether figure should be closed

        Raises
        ------
        Exception
            If ``key`` is not one of the computed matrices

        """
        if params is None:
            params = {}

        if key not in self.matrices:
            raise Exception("Key not in List of Matrices")

        matrix = self.matrices[key]

        # one row per true class, one column per predicted class
        arr_mat = [list(row.values()) for row in matrix.values()]

        figsize = params["figsize"] if "figsize" in params else (4, 8)

        fig, ax = plt.subplots(1, 1, figsize = figsize, constrained_layout = True)

        im = ax.imshow(arr_mat)

        if "x_axis" in params:
            x_axis = params["x_axis"]
        else:
            x_axis = [self.__format(x) for x in matrix.keys()]

        if "y_axis" in params:
            y_axis = params["y_axis"]
        else:
            first_row = matrix[list(matrix.keys())[0]]
            y_axis = [self.__format(y) for y in first_row.keys()]

        if "title" in params:
            title = params["title"]
        else:
            title = f"Confusion Matrix for {self.__format(key)}"

        # Show all ticks and label them with the respective list entries
        ax.set_xticks(np.arange(len(x_axis)), labels = x_axis)
        ax.set_yticks(np.arange(len(y_axis)), labels = y_axis)

        # Rotate the tick labels and set their alignment.
        plt.setp(ax.get_xticklabels(), rotation = 45, ha = "right",
                 rotation_mode = "anchor")

        # Loop over data dimensions and create text annotations.
        for i, row in enumerate(arr_mat):
            for j, fraction in enumerate(row):
                ax.text(j, i, f"{100 * fraction:.2f}",
                        ha = "center", va = "center",
                        color = "w" if fraction < 0.9 else "black")

        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title(title)

        cax = ax.inset_axes([1.1, 0.1, 0.05, 0.8])

        fig.colorbar(im, ax = ax, cax = cax, pad = 1)

        fig.tight_layout()

        if save_path is not None:  # saving
            fig.savefig(save_path)

        # close figure
        if close_fig:
            plt.close(fig)

    def classifiers(self):
        """ Method that lists classifiers available

        Returns
        -------
        list of the keys accepted by ``confusion_matrix``

        """
        return list(self.matrices.keys())

    def keys(self):
        """ Method that lists out keys available """

        return self.interpolator.continuous_out_keys

    def decision_boundaries(self):
        """ Placeholder, not implemented """
        pass

    def plot2D(self, key, slice_3D_var_str, slice_3D_var_range, PLOT_PROPERTIES):
        """ Method that plots the relative error of one key on a grid slice

        Parameters
        ----------
        key : str
            One of the continuous output keys
        slice_3D_var_str : str
            Either 'mass_ratio' or 'star_2_mass'
        slice_3D_var_range : sequence
            Lower and upper bound (both inclusive) of the slice
        PLOT_PROPERTIES : dict
            Extra keyword arguments forwarded to ``test_grid.plot2D``

        Raises
        ------
        ValueError
            If ``slice_3D_var_str`` is not one of the two supported names

        """
        k_ind = self.out_keys.index(key)

        if slice_3D_var_str == 'mass_ratio':
            var = (self.test_grid.initial_values["star_2_mass"]
                   / self.test_grid.initial_values["star_1_mass"])
        elif slice_3D_var_str == 'star_2_mass':
            var = self.test_grid.initial_values["star_2_mass"]
        else:
            raise ValueError("slice_3D_var_str must be either 'mass_ratio' or 'star_2_mass'")

        in_slice = (var >= slice_3D_var_range[0]) & (var <= slice_3D_var_range[1])

        # The error rows exist only for the tracks listed in valid_inds, so
        # scatter them back onto the full grid; everything else stays NaN.
        slice_errs = np.full(in_slice.shape[0], np.nan)
        slice_errs[self.errs["valid_inds"]] = self.errs["relative"][:, k_ind]
        slice_errs[~in_slice] = np.nan

        # find inf and assign large value else they are not plotted
        slice_errs[np.isinf(slice_errs)] = 1e99

        self.test_grid.plot2D('star_1_mass', 'period_days', slice_errs,
                              termination_flag='interpolation_class_errors',
                              grid_3D=True, slice_3D_var_str=slice_3D_var_str,
                              slice_3D_var_range=slice_3D_var_range,
                              verbose=False, **PLOT_PROPERTIES)
violinplot_with_nans(data, axs, **kwargs): + """ + Wrapper around plt.violinplot that handles NaNs by dropping them per column. + + data: 2D array-like of shape (n, d), or 1D array + """ + data = np.asarray(data, dtype=float) + + if data.ndim == 1: + clean = [data[~np.isnan(data)]] + else: + # Extract each column, dropping NaNs independently + clean = [data[:, i][~np.isnan(data[:, i])] for i in range(data.shape[1])] + + return axs.violinplot(clean, **kwargs)