diff --git a/src/quantem/__init__.py b/src/quantem/__init__.py index b9aa54b4..3c1bc876 100644 --- a/src/quantem/__init__.py +++ b/src/quantem/__init__.py @@ -8,6 +8,7 @@ from quantem.core import visualization as visualization from quantem import imaging as imaging +from quantem import spectroscopy as spectroscopy from quantem import diffractive_imaging as diffractive_imaging __version__ = version("quantem") diff --git a/src/quantem/core/datastructures/__init__.py b/src/quantem/core/datastructures/__init__.py index dfb5b47a..60cddd99 100644 --- a/src/quantem/core/datastructures/__init__.py +++ b/src/quantem/core/datastructures/__init__.py @@ -5,3 +5,4 @@ from quantem.core.datastructures.dataset4d import Dataset4d as Dataset4d from quantem.core.datastructures.dataset3d import Dataset3d as Dataset3d from quantem.core.datastructures.dataset2d import Dataset2d as Dataset2d +from quantem.core.datastructures.dataset1d import Dataset1d as Dataset1d diff --git a/src/quantem/core/datastructures/dataset1d.py b/src/quantem/core/datastructures/dataset1d.py new file mode 100644 index 00000000..79d4d957 --- /dev/null +++ b/src/quantem/core/datastructures/dataset1d.py @@ -0,0 +1,216 @@ +from typing import Self + +import matplotlib.pyplot as plt +import numpy as np +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset import Dataset +from quantem.core.utils.validators import ensure_valid_array + + +@Dataset.register_dimension(1) +class Dataset1d(Dataset): + """1D dataset class that inherits from Dataset. + + This class represents a 1D dataset, such as spectra from EDS or EELS. + + Attributes + ---------- + None beyond base Dataset. + """ + + def __init__( + self, + array: NDArray, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + metadata: dict = {}, + _token: object | None = None, + ): + """Initialize a 1D dataset. + + Parameters + ---------- + array : NDArray + The underlying 1D array data + name : str + A descriptive name for the dataset + origin : NDArray | tuple | list | float | int + The origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int + The sampling rate/spacing for each dimension + units : list[str] | tuple | list + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + _token : object | None, optional + Token to prevent direct instantiation, by default None + """ + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + metadata=metadata, + _token=_token, + ) + + @classmethod + def from_array( + cls, + array: NDArray, + name: str | None = None, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a Dataset1d from a 1D array. + + Parameters + ---------- + array : NDArray + 1D array with shape (length) + name : str | None + Dataset name. Default: "1D dataset" + origin : NDArray | tuple | list | float | int | None + Origin for each dimension. Default: [0] + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension. Default: [1] + units : list[str] | tuple | list | None + Units for each dimension. Default: ["pixels"] + signal_units : str + Units for array values. Default: "arb. units" + + Returns + ------- + Dataset1d + + Examples + -------- + >>> import numpy as np + >>> from quantem.core.datastructures import Dataset1d + >>> arr = np.random.rand(10) + >>> data = Dataset3d.from_array(arr) + >>> data.shape + (10,) + + With calibration: + + >>> data = Dataset1d.from_array( + ... arr, + ... sampling=[0.15], + ... units=["eV"], + ... ) + + Visualize: + + >>> data.show() # all frames in grid + >>> data.show(index=0) # single frame + >>> data.show(ncols=2) # 2 columns + """ + array = ensure_valid_array(array, ndim=1) + return cls( + array=array, + name=name if name is not None else "1D dataset", + origin=origin if origin is not None else np.zeros(1), + sampling=sampling if sampling is not None else np.ones(1), + units=units if units is not None else ["pixels"], + signal_units=signal_units, + _token=cls._token, + ) + + @classmethod + def from_shape( + cls, + shape: tuple[int], + name: str = "constant 1D dataset", + fill_value: float = 0.0, + origin: NDArray | tuple | list | float | int | None = None, + sampling: NDArray | tuple | list | float | int | None = None, + units: list[str] | tuple | list | None = None, + signal_units: str = "arb. units", + ) -> Self: + """Create a Dataset1d filled with a constant value. + + Parameters + ---------- + shape : tuple[int, int, int] + Shape (n_frames, height, width) + name : str + Dataset name. Default: "constant 1D dataset" + fill_value : float + Value to fill array with. Default: 0.0 + origin : NDArray | tuple | list | float | int | None + Origin for each dimension + sampling : NDArray | tuple | list | float | int | None + Sampling for each dimension + units : list[str] | tuple | list | None + Units for each dimension + signal_units : str + Units for array values + + Returns + ------- + Dataset1d + + Examples + -------- + >>> data = Dataset1d.from_shape((10000)) + >>> data.shape + (10000,) + >>> data.array.max() + 0.0 + """ + array = np.full(shape, fill_value, dtype=np.float32) + return cls.from_array( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + ) + + def show( + self, + title: str | None = None, + returnfig: bool = False, + **kwargs, + ): + """ + Plots 1D dataset + + Parameters + ---------- + scalebar: ScalebarConfig or bool + If True, displays scalebar + title: str + Title of Dataset + **kwargs: dict + Keyword arguments for show_2d + """ + + if title is None: + title = self.name + + fig, (ax) = plt.subplots(1, 1, figsize=(4, 4)) + + ax.plot( + float(self.origin[0]) + float(self.sampling[0]) * np.arange(self.shape[0]), + self.array, + linewidth=1.5, + ) + ax.set_xlabel(self.units[0]) + ax.set_ylabel(self.signal_units) + ax.set_title(title) + + fig.tight_layout() + plt.show() + + return (fig, ax) if returnfig else None diff --git a/src/quantem/core/datastructures/dataset3d.py b/src/quantem/core/datastructures/dataset3d.py index 1af7b4e6..91cdeb2c 100644 --- a/src/quantem/core/datastructures/dataset3d.py +++ b/src/quantem/core/datastructures/dataset3d.py @@ -30,6 +30,7 @@ def __init__( sampling: NDArray | tuple | list | float | int, units: list[str] | tuple | list, signal_units: str = "arb. units", + metadata: dict = {}, _token: object | None = None, ): """Initialize a 3D dataset. @@ -58,6 +59,7 @@ def __init__( sampling=sampling, units=units, signal_units=signal_units, + metadata=metadata, _token=_token, ) diff --git a/src/quantem/core/io/__init__.py b/src/quantem/core/io/__init__.py index 2780eae4..de0df5f8 100644 --- a/src/quantem/core/io/__init__.py +++ b/src/quantem/core/io/__init__.py @@ -1,7 +1,7 @@ from quantem.core.io.file_readers import read_2d as read_2d from quantem.core.io.file_readers import read_4dstem as read_4dstem from quantem.core.io.file_readers import ( - read_emdfile_to_4dstem as read_emdfile_to_4dstem, + read_3d_spectroscopy as read_3d_spectroscopy, ) from quantem.core.io.serialize import AutoSerialize as AutoSerialize from quantem.core.io.serialize import load as load diff --git a/src/quantem/core/io/file_readers.py b/src/quantem/core/io/file_readers.py index 4fe72645..421fbf33 100644 --- a/src/quantem/core/io/file_readers.py +++ b/src/quantem/core/io/file_readers.py @@ -9,6 +9,13 @@ from quantem.core.datastructures import Dataset2d as Dataset2d from quantem.core.datastructures import Dataset3d as Dataset3d from quantem.core.datastructures import Dataset4dstem as Dataset4dstem +from quantem.spectroscopy import ( + Dataset3deds as Dataset3deds, +) +from quantem.spectroscopy import ( + Dataset3deels as Dataset3deels, +) +from quantem.spectroscopy import Dataset3dspectroscopy as Dataset3dspectroscopy def read_4dstem( @@ -147,6 +154,95 @@ def read_4dstem( return dataset +def read_3d_spectroscopy( + file_path: str, file_type: str, data_type: str, dataset_index: int | None = None +) -> Dataset3dspectroscopy: + """ + File reader for 3D spectroscopy data + + Parameters + ---------- + file_path: str + Path to data + file_type: str + The type of file reader needed. See rosettasciio for supported formats + https://hyperspy.org/rosettasciio/supported_formats/index.html + data_type: str + type of spectroscopy data 'EELS' or 'EDS' + Returns + -------- + Dataset3dspectroscopy + """ + file_reader = importlib.import_module(f"rsciio.{file_type}").file_reader # type: ignore + data_list = file_reader(file_path) + + # If specific index provided, use it + if dataset_index is not None: + imported_data = data_list[dataset_index] + if imported_data["data"].ndim != 3: + raise ValueError( + f"Dataset at index {dataset_index} has {imported_data['data'].ndim} dimensions, " + f"expected 4D. Shape: {imported_data['data'].shape}" + ) + else: + # Automatically find first 3D dataset + three_d_datasets = [(i, d) for i, d in enumerate(data_list) if d["data"].ndim == 3] + + if len(three_d_datasets) == 0: + print(f"No 3D datasets found in {file_path}. Available datasets:") + for i, d in enumerate(data_list): + print(f" Dataset {i}: shape {d['data'].shape}, ndim={d['data'].ndim}") + raise ValueError("No 3D dataset found in file") + + dataset_index, imported_data = three_d_datasets[0] + + dataset_indices = [] + for entry in three_d_datasets: + dataset_indices.append(entry[0]) + + if len(data_list) > 1: + print( + f"File contains {len(data_list)} dataset(s) and {len(three_d_datasets)} 3D dataset(s) at indices {', '.join(map(str, dataset_indices))}. Using dataset {dataset_index} with shape {imported_data['data'].shape}" + ) + + imported_axes = imported_data["axes"] + axis_order = (0, 1, 2) if file_type == "digitalmicrograph" else (2, 0, 1) + array = ( + imported_data["data"] + if file_type == "digitalmicrograph" + else imported_data["data"].transpose(axis_order) + ) + ordered_axes = [imported_axes[idx] for idx in axis_order] + sampling = [ax.get("scale", 1) for ax in ordered_axes] + origin = [ax.get("offset", 0) for ax in ordered_axes] + units = [ + "pixels" if ax.get("units", "1") == "1" else ax.get("units", "pixels") + for ax in ordered_axes + ] + + for i, unit in enumerate(units): + if unit == "eV" and data_type == "EDS": + sampling[i] = sampling[i] / 1000 + origin[i] = origin[i] / 1000 + units[i] = "keV" + + if data_type == "EELS": + dataset_cls = Dataset3deels + elif data_type == "EDS": + dataset_cls = Dataset3deds + else: + raise ValueError(f"`data_type` must be `EDS` or `EELS` not `{data_type}`") + + dataset = dataset_cls.from_array( + array=array, + sampling=sampling, + origin=origin, + units=units, + ) + + return dataset + + def read_2d( file_path: str | PathLike, file_type: str | None = None, diff --git a/src/quantem/spectroscopy/__init__.py b/src/quantem/spectroscopy/__init__.py index e69de29b..9b87e017 100644 --- a/src/quantem/spectroscopy/__init__.py +++ b/src/quantem/spectroscopy/__init__.py @@ -0,0 +1,10 @@ +from quantem.spectroscopy.dataset3dspectroscopy import ( + Dataset3dspectroscopy as Dataset3dspectroscopy, +) +from quantem.spectroscopy.dataset3deels import ( + Dataset3deels as Dataset3deels, +) + +from quantem.spectroscopy.dataset3deds import ( + Dataset3deds as Dataset3deds, +) diff --git a/src/quantem/spectroscopy/atomic_weights.csv b/src/quantem/spectroscopy/atomic_weights.csv new file mode 100644 index 00000000..3a952649 --- /dev/null +++ b/src/quantem/spectroscopy/atomic_weights.csv @@ -0,0 +1,118 @@ +H,1.01 +He,4.00 +Li,6.94 +Be,9.01 +B,10.81 +C,12.01 +N,14.01 +O,16.00 +F,19.00 +Ne,20.18 +Na,22.99 +Mg,24.31 +Al,26.98 +Si,28.09 +P,30.97 +S,32.06 +Cl,35.45 +Ar,39.95 +K,39.10 +Ca,40.08 +Sc,44.96 +Ti,47.87 +V,50.94 +Cr,52.00 +Mn,54.94 +Fe,55.85 +Co,58.93 +Ni,58.69 +Cu,63.55 +Zn,65.38 +Ga,69.72 +Ge,72.63 +As,74.92 +Se,78.97 +Br,79.90 +Kr,83.80 +Rb,85.47 +Sr,87.62 +Y,88.91 +Zr,91.22 +Nb,92.91 +Mo,95.95 +Tc,97.00 +Ru,101.07 +Rh,102.91 +Pd,106.42 +Ag,107.87 +Cd,112.41 +In,114.82 +Sn,118.71 +Sb,121.76 +Te,127.60 +I,126.90 +Xe,131.29 +Cs,132.91 +Ba,137.33 +La,138.91 +Ce,140.12 +Pr,140.91 +Nd,144.24 +Pm,145.00 +Sm,150.36 +Eu,151.96 +Gd,157.25 +Tb,158.93 +Dy,162.50 +Ho,164.93 +Er,167.26 +Tm,168.93 +Yb,173.05 +Lu,174.97 +Hf,178.49 +Ta,180.95 +W,183.84 +Re,186.21 +Os,190.23 +Ir,192.22 +Pt,195.08 +Au,196.97 +Hg,200.59 +Tl,204.38 +Pb,207.20 +Bi,208.98 +Po,209.00 +At,210.00 +Rn,222.00 +Fr,223.00 +Ra,226.00 +Ac,227.00 +Th,232.04 +Pa,231.04 +U,238.03 +Np,237.00 +Pu,244.00 +Am,243.00 +Cm,247.00 +Bk,247.00 +Cf,251.00 +Es,252.00 +Fm,257.00 +Md,258.00 +No,259.00 +Lr,262.00 +Rf,267.00 +Db,270.00 +Sg,269.00 +Bh,270.00 +Hs,270.00 +Mt,278.00 +Ds,281.00 +Rg,281.00 +Cn,285.00 +Nh,286.00 +Fl,289.00 +Mc,289.00 +Lv,293.00 +Ts,293.00 +Og,294.00 \ No newline at end of file diff --git a/src/quantem/spectroscopy/dataset3deds.py b/src/quantem/spectroscopy/dataset3deds.py new file mode 100644 index 00000000..18b5c774 --- /dev/null +++ b/src/quantem/spectroscopy/dataset3deds.py @@ -0,0 +1,3650 @@ +import re +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +from matplotlib.lines import Line2D +from numpy.typing import NDArray +from scipy.optimize import curve_fit +from scipy.signal import find_peaks, peak_prominences, peak_widths + +from quantem.core.visualization import show_2d +from quantem.spectroscopy import Dataset3dspectroscopy +from quantem.spectroscopy.spectroscopy_models import ( + EDSModel, + GaussianPeaks, + PolynomialBackground, + abundance_smoothness_l2, + build_element_basis, + eds_data_loss, + inverse_softplus, + polynomial_energy_basis, +) + + +class Dataset3deds(Dataset3dspectroscopy): + """An EDS dataset class that inherits from Dataset3dspectroscopy. + + This class represents a scanning transmission electron microscopy (STEM) dataset, + where the data consists of a 3D array with dimensions (energy, scan_y, scan_x). + The first dimension represents the energy, while the latter + two dimensions represent real space sampling. + + """ + + element_info = None + element_info_path = "x_ray_lines.csv" + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + _token: object | None = None, + ): + """Initialize a 3D EDS dataset.""" + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + _token=_token, + ) + self.dataset_type = "eds" + + @staticmethod + def _normalize_specs(specs, param_name="spec", allow_none=False): + """Parse specs into a flat list of stripped strings.""" + if specs is None: + if allow_none: + return None + raise TypeError(f"{param_name} must be a string or sequence of strings") + if isinstance(specs, str): + return [s.strip() for s in specs.split(",") if s.strip()] + if isinstance(specs, (list, tuple, set)): + return [s.strip() for item in specs for s in str(item).split(",") if s.strip()] + raise TypeError(f"{param_name} must be a string or sequence of strings") + + @staticmethod + def _normalize_token(text): + """Return a lowercase alphanumeric-only token for fuzzy matching.""" + return re.sub(r"[^a-z0-9]", "", str(text).lower()) + + @staticmethod + def _ordered_element_keys(all_info): + """Return element keys sorted longest-first for greedy prefix matching.""" + return sorted(map(str, all_info), key=lambda k: (-len(k), k)) + + @classmethod + def _resolve_element_from_label(cls, label, ordered_elements): + """Extract the element name from a line label like 'FeKa1'.""" + label = str(label) + for element in ordered_elements: + if label.startswith(element): + return element + m = re.match(r"^[A-Z][a-z]?", label) + return m.group(0) if m else None + + @classmethod + def _ensure_element_info(cls): + """Load element X-ray line data if not already cached.""" + if cls.element_info is None: + cls.load_element_info() + return cls.element_info or {} + + @classmethod + def _normalize_element_info(cls, combine_close_peaks=True, energy_threshold_ev=15): + """Normalize EDS X-ray lines and optionally merge unresolved line families.""" + if not isinstance(cls.element_info, dict): + return cls.element_info + + threshold_kev = float(energy_threshold_ev) / 1000.0 + + def line_family(line_name): + canonical = cls._canonical_line_name(line_name).strip() + match = re.match(r"^([A-Za-z]+)", canonical) + return match.group(1) if match else canonical + + def normalized_line_name(line_name): + canonical = cls._canonical_line_name(line_name).strip() + match = re.match(r"^([A-Za-z]+)\d+(?:,\d+)+$", canonical) + return match.group(1) if match else canonical + + def unique_name(lines, name): + if name not in lines: + return name + idx = 2 + while f"{name}__{idx}" in lines: + idx += 1 + return f"{name}__{idx}" + + def merged_info(entries): + weights = np.asarray([entry["weight"] for entry in entries], dtype=float) + energies = np.asarray([entry["energy"] for entry in entries], dtype=float) + weight_sum = float(np.sum(weights)) + if weight_sum > 0.0: + energy = float(np.sum(energies * weights) / weight_sum) + else: + energy = float(np.mean(energies)) + return {"energy (keV)": energy, "weight": weight_sum} + + normalized_info = {} + for element, lines in cls.element_info.items(): + if not isinstance(lines, dict): + normalized_info[element] = lines + continue + + entries_by_family = {} + normalized_lines = {} + for line_name, line_info in lines.items(): + if not isinstance(line_info, dict): + continue + try: + energy = float(line_info.get("energy (keV)", line_info.get("energy"))) + except (TypeError, ValueError): + continue + try: + weight = float(line_info.get("weight", 0.0)) + except (TypeError, ValueError): + weight = 0.0 + + entry = { + "line": normalized_line_name(line_name), + "family": line_family(line_name), + "energy": energy, + "weight": weight, + } + entries_by_family.setdefault(entry["family"], []).append(entry) + + for family, entries in entries_by_family.items(): + entries = sorted(entries, key=lambda entry: entry["energy"]) + if not combine_close_peaks: + for entry in entries: + name = unique_name(normalized_lines, entry["line"]) + normalized_lines[name] = { + "energy (keV)": entry["energy"], + "weight": entry["weight"], + } + continue + + clusters = [] + current = [] + for entry in entries: + if not current or entry["energy"] - current[0]["energy"] <= threshold_kev: + current.append(entry) + else: + clusters.append(current) + current = [entry] + if current: + clusters.append(current) + + for cluster in clusters: + name = family if len(cluster) > 1 else cluster[0]["line"] + name = unique_name(normalized_lines, name) + normalized_lines[name] = merged_info(cluster) + + normalized_info[element] = dict( + sorted( + normalized_lines.items(), + key=lambda item: (item[1]["energy (keV)"], item[0]), + ) + ) + + cls.element_info = normalized_info + return cls.element_info + + @classmethod + def _parse_element_selectors(cls, specs, *, allow_none=False, param_name="spec"): + """Parse element/line specifiers into a dict of {element: set_of_suffixes | None}.""" + tokens = cls._normalize_specs(specs, param_name=param_name, allow_none=allow_none) + if tokens is None: + return None + + ordered = cls._ordered_element_keys(cls._ensure_element_info()) + out: dict[str, set[str] | None] = {} + for raw in tokens: + compact = re.sub(r"[\s_-]+", "", str(raw).strip()) + if not compact: + continue + element = next((k for k in ordered if compact.lower().startswith(k.lower())), None) + if element is None: + raise ValueError(f"Could not resolve element from specifier '{raw}'") + suffix = compact[len(element) :] + out.setdefault(element, None if not suffix else set()) + if suffix and out[element] is not None: + out[element].add(suffix) + return out or None + + @staticmethod + def _canonical_line_name(line_name: str) -> str: + """Strip any suffix after '__' from a line name.""" + return str(line_name).split("__", 1)[0] + + @classmethod + def _iter_selected_lines(cls, element: str, suffix: str, *, raw_spec: str): + """Yield (line_name, line_info) pairs matching an element and optional suffix.""" + lines = cls._ensure_element_info().get(element) or {} + if not lines: + raise ValueError(f"No X-ray lines found for element '{element}'") + if not suffix: + yield from lines.items() + return + + suffix = cls._normalize_token(suffix) + exact, prefix = [], [] + for line_name, line_info in lines.items(): + token = cls._normalize_token(cls._canonical_line_name(line_name)) + if token == suffix: + exact.append((line_name, line_info)) + if token.startswith(suffix): + prefix.append((line_name, line_info)) + matches = exact or prefix + if not matches: + raise ValueError( + f"No X-ray lines matched specifier '{raw_spec}' for element '{element}'" + ) + yield from matches + + @classmethod + def _group_labels_by_element(cls, labels: list[str]): + """Group line labels by their parent element.""" + ordered = cls._ordered_element_keys(cls._ensure_element_info()) + grouped: dict[str, list[str]] = {} + for lbl in sorted(map(str, labels)): + element = cls._resolve_element_from_label(lbl, ordered) + if element: + grouped.setdefault(element, []).append(lbl) + return grouped + + @classmethod + def _select_labels( + cls, selector: str, *, labels: list[str], labels_by_element: dict[str, list[str]] + ): + """Return labels matching a selector string (exact, element, or prefix).""" + selector = str(selector).strip() + if not selector: + return [] + + lower_map = {lbl.lower(): lbl for lbl in labels} + if selector.lower() in lower_map: + return [lower_map[selector.lower()]] + + elem_map = {elem.lower(): elem for elem in labels_by_element} + if selector.lower() in elem_map: + return list(labels_by_element[elem_map[selector.lower()]]) + + token = cls._normalize_token(selector) + return [lbl for lbl in labels if cls._normalize_token(lbl).startswith(token)] + + @staticmethod + def _line_shell(line_name: str) -> str: + """Return the shell letter ('K', 'L', 'M', or '?') for a line name.""" + line_name = str(line_name).upper() + return ( + "K" + if line_name.startswith("K") + else "L" + if line_name.startswith("L") + else "M" + if line_name.startswith("M") + else "?" + ) + + @staticmethod + def _peak_confidence( + snr_value: float, line_weight: float, distance_value: float, tolerance: float + ) -> float: + """Compute a confidence score for a peak-to-line match.""" + sigma = max(float(tolerance) / 3.0, 1e-9) + return ( + np.log1p(max(float(snr_value), 0.0)) + * max(float(line_weight), 0.0) + * np.exp(-0.5 * (float(distance_value) / sigma) ** 2) + ) + + @staticmethod + def _line_matches_selector(line_name: str, selector: str) -> bool: + """Check whether a line name matches a shell or substring selector.""" + line = str(line_name).strip().lower() + selector = str(selector).strip().lower() + return line.startswith(selector) if selector in {"k", "l", "m"} else selector in line + + @classmethod + def _line_allowed_for_element( + cls, element_name: str, line_name: str, edge_filters=None + ) -> bool: + """Return True if the line passes the edge filter for its element.""" + selectors = None if edge_filters is None else edge_filters.get(str(element_name)) + return selectors is None or any( + cls._line_matches_selector(line_name, token) for token in selectors + ) + + def _get_spectrum_images(self, method="integration"): + """Retrieve cached spectrum images for the given method.""" + return { + "integration": getattr(self, "_spectrum_images", None), + "fit": getattr(self, "_spectrum_images_pytorch", None), + }.get(method) + + @staticmethod + def _shell_preference_factor(shell_name: str) -> float: + """Return a down-weighting factor for M-shell lines.""" + return 0.72 if shell_name == "M" else 1.0 + + @staticmethod + def _merge_edge_filters(requested, saved): + """Merge requested and saved edge filters, unioning selectors per element.""" + if requested and saved: + merged = dict(saved) + for element, selectors in requested.items(): + current = merged.get(element) + merged[element] = ( + None if current is None or selectors is None else set(current).union(selectors) + ) + return merged + return requested or saved + + @staticmethod + def _estimate_snr_thresholds(snr_values, floor=None, snr_threshold=None): + """Auto-estimate SNR floors/thresholds from the peak SNR distribution.""" + snr_values = np.asarray(snr_values, dtype=float) + snr_values = snr_values[np.isfinite(snr_values)] + + if floor is None: + if snr_values.size: + # Robust quantile floor: center near the middle/high-middle SNR + # band so the floor tracks "visible" peaks without being pulled + # down by noise tails or up by a few extreme peaks. + q30, q40, q50, q60 = np.percentile(snr_values, [30, 40, 50, 60]) + floor = 0.5 * float(q40 + q50) + floor = float(np.clip(floor, q30, q60)) + floor = max(0.0, floor) + else: + floor = 8.0 + else: + floor = float(floor) + + if snr_threshold is None: + if snr_values.size: + high = snr_values[snr_values >= floor] + high = high if high.size else snr_values + # Keep auto-threshold independent from the requested display + # count (peaks). `peaks` should control only how many detected + # peaks are shown, not which peaks are detected. + anchor = np.sort(high)[::-1][: min(high.size, 40)] + med, q75, q90 = np.percentile(anchor, [50, 75, 90]) + snr_threshold = float( + np.clip(max(med, 0.7 * q75, 2.5 * floor), max(2.5 * floor, floor), q90) + ) + else: + snr_threshold = max(4.0 * floor, 30.0) + else: + snr_threshold = float(snr_threshold) + + return floor, snr_threshold + + def x_ray_lookup( + self, spec: str | list[str] | tuple[str, ...] | set[str] + ) -> tuple[np.ndarray, np.ndarray, list[str]]: + """Look up X-ray line energies, weights, and labels. + + Parameters + ---------- + spec : str | sequence[str] + One or more element/line specifiers. Accepted formats include + element names (``'Fe'``), element + shell (``'Fe K'``), and + element + line (``'Fe Ka1'``). Comma-separated strings are + split automatically. + + Returns + ------- + energies : ndarray + 1-D array of line energies in keV, sorted by energy. + weights : ndarray + Corresponding tabulated line weights (0--1). + labels : list[str] + Human-readable labels such as ``'FeKa1'``. + + Raises + ------ + ValueError + If no lines match the specifier(s). + """ + info = type(self)._ensure_element_info() + ordered = type(self)._ordered_element_keys(info) + specs = type(self)._normalize_specs(spec, param_name="spec") + + rows: list[tuple[str, float, float]] = [] + for raw in specs: + compact = re.sub(r"[\s_-]+", "", str(raw).strip()) + if not compact: + continue + element = next((k for k in ordered if compact.lower().startswith(k.lower())), None) + if element is None: + raise ValueError(f"Could not resolve element from specifier '{raw}'") + suffix = compact[len(element) :] + for line_name, line_info in type(self)._iter_selected_lines( + element, suffix, raw_spec=str(raw) + ): + if not isinstance(line_info, dict): + continue + try: + energy = float(line_info.get("energy (keV)", line_info.get("energy"))) + except (TypeError, ValueError): + continue + try: + weight = float(line_info.get("weight", 0.0)) + except (TypeError, ValueError): + weight = 0.0 + rows.append( + (f"{element}{type(self)._canonical_line_name(line_name)}", energy, weight) + ) + + if not rows: + raise ValueError(f"No X-ray lines matched specifier(s): {specs}") + + unique = sorted( + {(lbl, round(float(e), 12), round(float(w), 12)) for lbl, e, w in rows}, + key=lambda t: (t[1], -t[2], t[0]), + ) + return ( + np.asarray([e for _, e, _ in unique], dtype=float), + np.asarray([w for _, _, w in unique], dtype=float), + [lbl for lbl, _, _ in unique], + ) + + def generage_spectrum_images(self, elements=None, width=0.15, return_maps=False): + """Generate spectrum images by integrating around X-ray line energies. + + For each matched X-ray line, sums the spectral intensity within an + energy window of ``line_energy +/- width`` at every spatial pixel. + Results are cached in ``self._spectrum_images`` for later use by + :meth:`show_spectrum_images` and :meth:`quantify_composition_cliff_lorimer`. + + Parameters + ---------- + elements : str | sequence[str] | None, optional + Element/line specifiers (see :meth:`x_ray_lookup`). If ``None``, + uses ``self.model_elements``. + width : float, optional + Half-width of the integration window in keV. + return_maps : bool, optional + If ``True``, return ``(maps, labels)``. + + Returns + ------- + tuple[ndarray, list[str]] | None + Only returned when *return_maps* is ``True``. + """ + if elements is None: + if self.model_elements is None: + raise ValueError("elements must be specified") + elements = list(self.model_elements) + + energies, _, labels = self.x_ray_lookup(elements) + keep = (energies > self.energy_axis.min()) & (energies < self.energy_axis.max()) + energies = energies[keep] + labels = [label for label, ok in zip(labels, keep) if ok] + + mask = (self.energy_axis[:, None] > energies[None, :] - width) & ( + self.energy_axis[:, None] < energies[None, :] + width + ) + n, h, w = self.array.shape + maps = (mask.astype(self.array.dtype).T @ self.array.reshape(n, -1)).reshape( + mask.shape[1], h, w + ) + + self._spectrum_images = { + **getattr(self, "_spectrum_images", {}), + **dict(zip(labels, maps)), + } + + images, titles = self.show_spectrum_images(x_ray_lines=elements, return_maps=True) + + if return_maps: + return images, titles + + def Integrate(self, spec, width=0.15, return_maps=False, show=True, **kwargs): + """Integrate the spectrum around specified X-ray lines. + + Sums spectral intensity within ``line_energy +/- width`` for each + selector. By default, displays the resulting map(s). + + Parameters + ---------- + spec : str | sequence[str] + Element/line specifiers (see :meth:`x_ray_lookup`), e.g. + ``'Fe Ka'`` or ``['Cu', 'Zn']``. + width : float, optional + Half-width of the integration window in keV. + return_maps : bool, optional + If ``True``, return the integrated maps. + show : bool, optional + If ``True``, display the maps. + **kwargs + Forwarded to the plotting function (e.g. ``cmap``, ``roi``). + + Returns + ------- + ndarray | dict[str, ndarray] + Single map when one selector is given, otherwise a dict keyed by + selector string. + """ + width = float(width) + specs = type(self)._normalize_specs(spec, param_name="spec") + arr = np.asarray(self.array, dtype=float) + energy_axis = np.asarray(self.energy_axis, dtype=float) + energy_min, energy_max = float(energy_axis.min()), float(energy_axis.max()) + + selector_masks, integrated_maps = {}, {} + for selector in map(str, specs): + line_energies, _, _ = self.x_ray_lookup(selector.strip()) + line_energies = line_energies[ + (line_energies >= energy_min) & (line_energies <= energy_max) + ] + if not len(line_energies): + raise ValueError( + f"No X-ray lines for selector '{selector}' are within the dataset energy range" + ) + + mask = np.any( + (energy_axis[:, None] >= line_energies[None, :] - width) + & (energy_axis[:, None] <= line_energies[None, :] + width), + axis=1, + ) + selector_masks[selector] = mask + integrated_maps[selector] = arr[mask].sum(axis=0) + + if show: + cmap = kwargs.pop("cmap", "magma") + if len(integrated_maps) == 1: + selector = next(iter(integrated_maps)) + self.show_energy_window_map( + energy_window=[energy_min, energy_max], + roi=kwargs.pop("roi", None), + roi_cal=kwargs.pop("roi_cal", None), + mask=selector_masks[selector], + data_type=kwargs.pop("data_type", "eds"), + cmap=cmap, + show=True, + ) + else: + show_2d( + list(integrated_maps.values()), + title=list(integrated_maps), + cmap=cmap, + scalebar={"sampling": self.sampling[1], "units": self.units[1]}, + **kwargs, + ) + + return ( + integrated_maps + if return_maps or len(integrated_maps) != 1 + else next(iter(integrated_maps.values())) + ) + + def integrate(self, spec, width=0.15, return_maps=False, show=True, **kwargs): + """Convenience wrapper for Integrate.""" + return self.Integrate(spec=spec, width=width, return_maps=return_maps, show=show, **kwargs) + + def show_spectrum_images( + self, x_ray_lines=None, return_fig=False, return_maps=False, method="integration", **kwargs + ): + """Display cached spectrum images. + + Parameters + ---------- + x_ray_lines : str | sequence[str] | None, optional + Selectors to filter which images are shown. If ``None``, one + panel per element is displayed. + return_fig : bool, optional + If ``True``, return ``(fig, ax)``. + method : {"integration", "fit"}, optional + Which cache to read from: integration-based maps or PyTorch + fit-based maps. + **kwargs + Forwarded to :func:`show_2d` (e.g. ``cmap``). + + Returns + ------- + tuple[Figure, Axes] | None + Only returned when *return_fig* is ``True``. + + Raises + ------ + ValueError + If no cached spectrum images exist for the chosen *method*. + """ + spectrum_images = self._get_spectrum_images(method) + if not spectrum_images: + raise ValueError("No spectrum images found. Run generage_spectrum_images(...) first.") + + line_map = {str(k): np.asarray(v) for k, v in spectrum_images.items()} + labels = list(line_map) + labels_by_element = type(self)._group_labels_by_element(labels) + + def sum_maps(lbls): + return np.sum([line_map[lbl] for lbl in lbls], axis=0) + + specs = type(self)._normalize_specs(x_ray_lines, param_name="x_ray_lines", allow_none=True) + if not specs: + titles = sorted(labels_by_element) + images = [sum_maps(labels_by_element[t]) for t in titles] + else: + selected = [ + type(self)._select_labels( + str(raw), labels=labels, labels_by_element=labels_by_element + ) + for raw in specs + ] + if any(not s for s in selected): + bad = next(raw for raw, s in zip(specs, selected) if not s) + raise ValueError(f"No spectrum images matched selector '{bad}'") + images = [line_map[s[0]] if len(s) == 1 else sum_maps(s) for s in selected] + titles = [s[0] if len(s) == 1 else str(raw).strip() for raw, s in zip(specs, selected)] + + fig, ax = show_2d( + images, + title=titles, + cmap=kwargs.pop("cmap", "magma"), + scalebar={"sampling": self.sampling[1], "units": self.units[1]}, + returnfig=True, + **kwargs, + ) + + if return_fig and return_maps: + return (fig, ax), (images, titles) + elif return_fig: + return fig, ax + elif return_maps: + return images, titles + + def _build_pytorch_spectrum_images( + self, abundance_maps: np.ndarray, element_names: list[str] | tuple[str, ...] + ) -> dict[str, np.ndarray]: + """Convert per-element abundance maps into per-line spectrum images using weights.""" + maps = np.asarray(abundance_maps) + if maps.ndim != 3: + return {} + + line_maps = {} + for i, element_name in enumerate(element_names): + if i >= maps.shape[0]: + break + try: + _, line_weights, line_labels = self.x_ray_lookup(str(element_name)) + except ValueError: + continue + element_map = np.asarray(maps[i], dtype=float) + for weight, label in zip(line_weights, line_labels): + line_maps[str(label)] = element_map * float(weight) + return line_maps + + def quantify_composition_cliff_lorimer( + self, k_factors, method="integration", return_maps=False, verbose=True + ): + """Quantify elemental composition using the Cliff-Lorimer thin-film method. + + Parameters + ---------- + k_factors : dict[str, float] + Mapping of element/line selectors to their k-factors, e.g. + ``{'Fe K': 1.0, 'Cu K': 1.45}``. At least two elements are + required. + method : {"integration", "fit"}, optional + Which cached spectrum images to use for intensity extraction. + return_maps : bool, optional + If ``True``, include per-pixel atomic-percent and weight-percent + maps in the returned dict. + verbose : bool, optional + If ``True``, print the quantification summary table. + + Returns + ------- + dict + Keys include ``atomic_percent``, ``weight_percent``, + ``intensities``, ``weighted_intensities``, and + ``summary_table``. When *return_maps* is ``True``, also + includes ``atomic_percent_maps`` and ``weight_percent_maps``. + + Raises + ------ + ValueError + If *k_factors* is empty, fewer than two elements are matched, or + spectrum images are missing. + """ + if not k_factors: + raise ValueError("k_factors must be a non-empty dict") + spectrum_images = self._get_spectrum_images(method) + if not spectrum_images: + raise ValueError("No spectrum images available for quantification") + + ordered_elements = type(self)._ordered_element_keys(type(self)._ensure_element_info()) + line_map = {str(k): np.asarray(v, dtype=float) for k, v in spectrum_images.items()} + labels = list(line_map) + labels_by_element = type(self)._group_labels_by_element(labels) + + def match(selector: str) -> list[str]: + return type(self)._select_labels( + selector, labels=labels, labels_by_element=labels_by_element + ) + + intensities, weighted_intensities = {}, {} + selector_maps = {} if return_maps else None + intensity_maps = {} if return_maps else None + weighted_intensity_maps = {} if return_maps else None + + for selector, k_raw in k_factors.items(): + k_val = float(k_raw) + sel_labels = match(str(selector).strip()) + if not sel_labels: + raise ValueError(f"No spectrum images matched selector {selector!r}") + + matched_elements = { + type(self)._resolve_element_from_label(lbl, ordered_elements) for lbl in sel_labels + } - {None} + if len(matched_elements) != 1: + raise ValueError( + f"Selector {selector!r} matched multiple elements: {sorted(matched_elements)}" + ) + element = next(iter(matched_elements)) + + grouped_map = np.sum([line_map[lbl] for lbl in sel_labels], axis=0) + intensity = float(grouped_map.sum()) + weighted = float(k_val * intensity) + intensities[element] = intensities.get(element, 0.0) + intensity + weighted_intensities[element] = weighted_intensities.get(element, 0.0) + weighted + + if return_maps: + weighted_map = grouped_map * k_val + selector_maps[str(selector)] = grouped_map + intensity_maps[element] = intensity_maps.get(element, 0) + grouped_map + weighted_intensity_maps[element] = ( + weighted_intensity_maps.get(element, 0) + weighted_map + ) + + if len(weighted_intensities) < 2: + raise ValueError("At least two elements are required for Cliff-Lorimer quantification") + + weighted_sum = sum(weighted_intensities.values()) + atomic_percent = { + el: 100.0 * val / weighted_sum if weighted_sum > 0 else 0.0 + for el, val in weighted_intensities.items() + } + + if type(self).atomic_weights is None: + type(self).load_atomic_weights() + atomic_weights = type(self).atomic_weights or {} + missing = [el for el in atomic_percent if el not in atomic_weights] + if missing: + raise ValueError(f"Atomic weights not found for elements: {missing}") + + weight_sum = sum( + (atomic_percent[el] / 100.0) * float(atomic_weights[el]) for el in atomic_percent + ) + weight_percent = { + el: (atomic_percent[el] / 100.0) * float(atomic_weights[el]) / weight_sum * 100.0 + if weight_sum > 0 + else 0.0 + for el in atomic_percent + } + + ordered = sorted(weighted_intensities, key=weighted_intensities.get, reverse=True) + table_text = "\n".join( + [ + "Element Intensity Weighted Intensity Atomic % Weight %", + "------- ------------- -------------------- ---------- ----------", + *[ + f"{el:<7} {intensities[el]:>13.3f} {weighted_intensities[el]:>20.3f} {atomic_percent[el]:>10.3f} {weight_percent[el]:>10.3f}" + for el in ordered + ], + ] + ) + result = { + "intensities": intensities, + "weighted_intensities": weighted_intensities, + "atomic_percent": atomic_percent, + "weight_percent": weight_percent, + "summary_table": table_text, + } + if verbose: + print(table_text) + + if return_maps: + weighted_stack = np.stack(list(weighted_intensity_maps.values()), axis=0) + weighted_sum_map = weighted_stack.sum(axis=0) + atomic_percent_maps = { + el: np.divide( + wmap * 100.0, + weighted_sum_map, + out=np.zeros_like(weighted_sum_map, dtype=float), + where=weighted_sum_map > 0, + ) + for el, wmap in weighted_intensity_maps.items() + } + mass_maps = { + el: atomic_percent_maps[el] / 100.0 * float(atomic_weights[el]) + for el in atomic_percent_maps + } + mass_sum_map = np.sum(np.stack(list(mass_maps.values()), axis=0), axis=0) + weight_percent_maps = { + el: np.divide( + mmap * 100.0, + mass_sum_map, + out=np.zeros_like(mass_sum_map, dtype=float), + where=mass_sum_map > 0, + ) + for el, mmap in mass_maps.items() + } + result.update( + { + "selector_maps": selector_maps, + "intensity_maps": intensity_maps, + "weighted_intensity_maps": weighted_intensity_maps, + "atomic_percent_maps": atomic_percent_maps, + "weight_percent_maps": weight_percent_maps, + } + ) + return result + + def clear_spectrum_images(self): + """Clear cached integration-based spectrum images.""" + self._spectrum_images = {} + + def clear_spectrum_images_pytorch(self): + """Clear cached PyTorch fit-based spectrum images.""" + self._spectrum_images_pytorch = {} + + def peak_autoid( + self, + roi=None, + roi_cal=None, + energy_range=None, + elements=None, + ignore_elements=None, + ignore_range=None, + tolerance=0.15, + min_line_weight=0.0, + mask=None, + show_text=True, + floor=None, + snr_quantile_floor=None, + snr_min=None, + snr_threshold=None, + distance_threshold_for_sample=0.05, + grid_peaks=None, + peaks=15, + mode=None, + line=None, + return_details=False, + ): + """Automatically identify elements from EDS peaks in the mean spectrum. + + Finds peaks in the spatially-averaged spectrum, matches them against a + database of known X-ray line energies, and classifies elements as + *detected* (high confidence) or *possible* (lower confidence). Results + are printed and overlaid on an interactive spectrum plot. + + Parameters + ---------- + roi : sequence[int] | None, optional + Pixel-coordinate ROI ``[y0, y1, x0, x1]`` used when computing the + mean spectrum. If ``None``, the full spatial extent is used. + roi_cal : sequence[float] | None, optional + Calibrated-coordinate ROI (same layout as *roi* but in physical + units). + energy_range : sequence[float] | None, optional + Two-element energy window ``[emin, emax]`` in keV. Peaks outside + this range are ignored. + elements : str | sequence[str] | None, optional + Element or element-line specifiers to search for, e.g. + ``'Fe'``, ``'Fe Ka'``, or ``['Cu', 'Zn K']``. When provided, + behaviour depends on *mode*. + ignore_elements : str | sequence[str] | None, optional + Elements to exclude from autodetection. + ignore_range : sequence[float] | None, optional + Energy range ``[emin, emax]`` whose peaks are ignored. Defaults to + ``[0, 0.25]`` keV to skip the noise floor. + tolerance : float, optional + Maximum energy difference in keV between a detected peak and a + tabulated X-ray line for them to be considered a match. + M-shell minor lines use ``tolerance * 0.5``. + min_line_weight : float, optional + Minimum tabulated line weight (0--1) for a line to be considered. + mask : ndarray | None, optional + Boolean spatial mask; only pixels where ``mask`` is ``True`` + contribute to the mean spectrum. + show_text : bool, optional + If ``True``, annotate matched peaks on the plot. + floor : float | None, optional + Minimum signal-to-noise ratio for a peak to be displayed. If + ``None``, estimated from robust middle quantiles (roughly between + the 30th and 60th percentile of peak SNRs). + snr_quantile_floor : float | None, optional + Deprecated alias for *floor*. + snr_min : float | None, optional + Deprecated alias for *floor*. + snr_threshold : float | None, optional + SNR above which a peak match counts as "strong" evidence for an + element. If ``None``, estimated automatically. + distance_threshold_for_sample : float, optional + Maximum energy distance (keV) for a match to qualify as a strong + match (used together with *snr_threshold*). + grid_peaks : dict | None, optional + Mapping of ``{label: energy}`` for known grid/artifact peaks that + should be flagged in the output. + peaks : int, optional + Maximum number of peaks to display. + mode : {"elements_only", "elements_preferred", "autofill"} | None, optional + Search strategy. ``"elements_only"`` restricts matching to + *elements*; ``"elements_preferred"`` boosts them but allows others; + ``"autofill"`` (default when *elements* is ``None``) searches all + elements. + line : float | sequence[float] | None, optional + Energy value(s) in keV for reference lines to draw on the spectrum + plot, e.g. ``3.692`` or ``[3.692, 4.510]``. Lines are drawn as + dashed black vertical lines. + return_details : bool, optional + If ``True``, return a dict with detection details instead of the + figure. + + Returns + ------- + tuple[Figure, tuple[Axes, Axes]] | dict + By default returns ``(fig, (ax_img, ax_spec))``. When + *return_details* is ``True``, returns a dict containing + ``detected_elements``, ``element_confidence``, ``display_peaks``, + ``peak_matches``, ``floor``, ``snr_threshold``, and the figure. + """ + type(self)._ensure_element_info() + all_info = type(self).element_info or {} + grid_peaks = grid_peaks or {} + ignore_range = [0, 0.25] if ignore_range is None else ignore_range + ignored_elements = set( + map(str, type(self)._normalize_specs(ignore_elements, allow_none=True) or []) + ) + min_line_weight = max(float(min_line_weight), 0.0) + + requested = type(self)._parse_element_selectors( + elements, allow_none=True, param_name="elements" + ) + saved = { + str(k): (set(map(str, v.keys())) if isinstance(v, dict) and v else None) + for k, v in (getattr(self, "model_elements", {}) or {}).items() + } or None + edge_filters = requested if requested is not None else saved + requested_elements = set(edge_filters) if edge_filters else None + + mode = (str(mode).strip().lower() if mode is not None else None) or ( + "elements_only" if requested_elements else "autofill" + ) + search_elements = requested_elements if mode == "elements_only" else None + preferred_elements = ( + set(map(str, requested_elements or [])) if mode == "elements_preferred" else set() + ) + reference_elements = requested_elements + + fig, (ax_img, ax_spec) = self.show_mean_spectrum( + roi=roi, + roi_cal=roi_cal, + energy_range=energy_range, + mask=mask, + data_type="eds", + show=False, + ) + spec = self.calculate_mean_spectrum( + roi=roi, + roi_cal=roi_cal, + energy_range=energy_range, + mask=mask, + ) + E = float(self.origin[0]) + float(self.sampling[0]) * np.arange(self.shape[0]) + + # Keep the energy axis aligned with calculate_mean_spectrum filtering. + if mask is not None: + mask_arr = np.asarray(mask, dtype=bool) + if mask_arr.shape != E.shape: + raise ValueError( + f"Mask shape {mask_arr.shape} does not match energy axis shape {E.shape}." + ) + E = E[mask_arr] + + if energy_range is not None: + keep = (energy_range[0] <= E) & (E <= energy_range[1]) + E = E[keep] + + if len(spec) != len(E): + raise ValueError( + "Energy axis length does not match mean spectrum length after filtering. " + f"Got len(E)={len(E)} and len(spec)={len(spec)}." + ) + + def in_ignore(energy): + return len(ignore_range) == 2 and ignore_range[0] <= float(energy) <= ignore_range[1] + + peak_indices, props = find_peaks(spec, height=0, distance=5) + peak_heights = props["peak_heights"] + peak_proms = ( + peak_prominences(spec, peak_indices)[0] + if len(peak_indices) + else np.asarray([], dtype=float) + ) + peak_width_samples = ( + peak_widths(spec, peak_indices, rel_height=0.5)[0] + if len(peak_indices) + else np.asarray([], dtype=float) + ) + background_std = np.nanstd(spec[spec <= np.nanpercentile(spec, 50)]) + if not np.isfinite(background_std) or background_std <= 0: + background_std = np.nanstd(spec) + if not np.isfinite(background_std) or background_std <= 0: + background_std = 1.0 + + if floor is None and snr_quantile_floor is not None: + floor = snr_quantile_floor + if floor is None and snr_min is not None: + floor = snr_min + + # Collapse shoulder peaks before SNR filtering. + # Two adjacent peaks are treated as one if they are very close in energy + # and the valley between them is shallow relative to the smaller peak. + # This removes split-peak artifacts that tend to over-label broad peaks. + def collapse_shoulder_peaks(indices, heights, prominences, widths): + if len(indices) <= 1: + return ( + np.asarray(indices, dtype=int), + np.asarray(heights, dtype=float), + np.asarray(prominences, dtype=float), + np.asarray(widths, dtype=float), + ) + + energy_gap_limit = max(6.0 * float(self.sampling[0]), 0.14) + min_valley_relief = 0.35 + min_height_ratio = 0.45 + + keep = [] + i = 0 + while i < len(indices): + best_idx = int(indices[i]) + best_h = float(heights[i]) + best_p = float(prominences[i]) + best_w = float(widths[i]) + j = i + 1 + + while j < len(indices): + cand_idx = int(indices[j]) + cand_h = float(heights[j]) + cand_p = float(prominences[j]) + cand_w = float(widths[j]) + if float(E[cand_idx] - E[best_idx]) > energy_gap_limit: + break + + lo, hi = sorted((best_idx, cand_idx)) + if hi - lo <= 1: + valley = float(min(spec[lo], spec[hi])) + else: + valley = float(np.min(spec[lo : hi + 1])) + + smaller = max(min(best_h, cand_h), 1e-12) + valley_relief = (smaller - valley) / smaller + height_ratio = min(best_h, cand_h) / max(best_h, cand_h) + + # Not a clearly separated doublet -> merge shoulders. + if valley_relief < min_valley_relief or height_ratio < min_height_ratio: + if (cand_p > best_p) or (cand_p == best_p and cand_h > best_h): + best_idx, best_h, best_p, best_w = cand_idx, cand_h, cand_p, cand_w + j += 1 + continue + + break + + keep.append((best_idx, best_h, best_p, best_w)) + i = j + + out_idx = np.asarray([pk for pk, _, _, _ in keep], dtype=int) + out_h = np.asarray([h for _, h, _, _ in keep], dtype=float) + out_p = np.asarray([p for _, _, p, _ in keep], dtype=float) + out_w = np.asarray([w for _, _, _, w in keep], dtype=float) + order = np.argsort(out_idx) + return out_idx[order], out_h[order], out_p[order], out_w[order] + + peak_indices, peak_heights, peak_proms, peak_width_samples = collapse_shoulder_peaks( + peak_indices, + peak_heights, + peak_proms, + peak_width_samples, + ) + + snr_values = np.asarray([height / background_std for height in peak_heights], dtype=float) + floor, snr_threshold = type(self)._estimate_snr_thresholds( + snr_values, + floor, + snr_threshold, + ) + + # Prominence filter in SNR units: suppress shoulder/noise artifacts that + # may have acceptable height but do not form a distinct peak. + prominence_snr = np.asarray( + [float(p) / max(float(background_std), 1e-12) for p in peak_proms], dtype=float + ) + + def _local_noise_std(pk_idx): + # Use local baseline variability so narrow doublets are not lost + # when a wide energy range inflates global noise estimates. + local_window = max(0.24, 12.0 * float(self.sampling[0])) + mask_local = np.abs(E - float(E[int(pk_idx)])) <= local_window + if int(np.count_nonzero(mask_local)) < 9: + return float(background_std) + + y_local = np.asarray(spec[mask_local], dtype=float) + if y_local.size < 9 or not np.all(np.isfinite(y_local)): + return float(background_std) + + local_cut = float(np.nanpercentile(y_local, 70)) + base_local = y_local[y_local <= local_cut] + if base_local.size < 5: + base_local = y_local + + local_std = float(np.nanstd(base_local)) + if not np.isfinite(local_std) or local_std <= 0: + local_std = float(background_std) + return max(local_std, 1e-12) + + local_noise = np.asarray([_local_noise_std(int(i)) for i in peak_indices], dtype=float) + local_snr_values = np.asarray( + [float(h) / max(float(n), 1e-12) for h, n in zip(peak_heights, local_noise)], + dtype=float, + ) + local_prominence_snr = np.asarray( + [float(p) / max(float(n), 1e-12) for p, n in zip(peak_proms, local_noise)], dtype=float + ) + + prominence_floor = max(2.2, 0.85 * float(floor)) + salience_snr = prominence_snr * np.sqrt(np.maximum(peak_width_samples, 1e-12)) + salience_floor = max(4.2, 2.0 * float(floor)) + local_salience_snr = local_prominence_snr * np.sqrt(np.maximum(peak_width_samples, 1e-12)) + + adaptive_floor = max(2.0, 0.62 * float(floor)) + adaptive_prominence_floor = max(1.6, 0.62 * float(prominence_floor)) + adaptive_salience_floor = max(2.6, 0.62 * float(salience_floor)) + + display_peaks_with_prom = [ + ( + int(i), + float(h), + float(E[i]), + float(max(float(h / background_std), float(local_snr))), + float(max(float(p_snr), float(local_p_snr))), + float(max(float(sal), float(local_sal))), + ) + for i, h, p_snr, sal, local_snr, local_p_snr, local_sal in zip( + peak_indices, + peak_heights, + prominence_snr, + salience_snr, + local_snr_values, + local_prominence_snr, + local_salience_snr, + ) + if ( + not in_ignore(E[i]) + and ( + ( + h / background_std >= floor + and p_snr >= prominence_floor + and sal >= salience_floor + ) + or ( + local_snr >= adaptive_floor + and local_p_snr >= adaptive_prominence_floor + and local_sal >= adaptive_salience_floor + ) + ) + ) + ] + + # Validate peaks as local Gaussian components (center/sigma/amplitude) + # rather than raw single-bin maxima, then merge overlapping components. + def _gauss_with_offset(x, amp, mu, sigma, offset): + sigma = max(float(sigma), 1e-12) + return float(offset) + float(amp) * np.exp(-0.5 * ((x - float(mu)) / sigma) ** 2) + + def _fit_local_gaussian(pk_idx): + window = max(0.18, 10.0 * float(self.sampling[0])) + x0 = float(E[pk_idx]) + mask_local = np.abs(E - x0) <= window + if int(np.count_nonzero(mask_local)) < 7: + return None + + x_local = np.asarray(E[mask_local], dtype=float) + y_local = np.asarray(spec[mask_local], dtype=float) + if not np.all(np.isfinite(y_local)): + return None + + baseline = float(np.percentile(y_local, 20)) + peak_val = float(spec[pk_idx]) + amp0 = max(peak_val - baseline, 1e-9) + sigma0 = max(0.04, 2.0 * float(self.sampling[0])) + + lo_sigma = max(1.5 * float(self.sampling[0]), 0.010) + hi_sigma = 0.18 + bounds = ( + [0.0, x0 - 0.06, lo_sigma, baseline - abs(amp0)], + [max(amp0 * 5.0, 1e-6), x0 + 0.06, hi_sigma, baseline + abs(amp0)], + ) + + try: + popt, _ = curve_fit( + _gauss_with_offset, + x_local, + y_local, + p0=[amp0, x0, sigma0, baseline], + bounds=bounds, + maxfev=4000, + ) + except Exception: + return None + + amp, mu, sigma, offset = map(float, popt) + if amp <= 0 or not np.isfinite(mu) or not np.isfinite(sigma): + return None + + y_hat = _gauss_with_offset(x_local, amp, mu, sigma, offset) + ss_res = float(np.sum((y_local - y_hat) ** 2)) + ss_tot = float(np.sum((y_local - float(np.mean(y_local))) ** 2)) + r2 = 1.0 - ss_res / max(ss_tot, 1e-12) + amp_snr = amp / max(float(background_std), 1e-12) + + return { + "idx": int(pk_idx), + "mu": float(mu), + "sigma": float(sigma), + "amp": float(amp), + "amp_snr": float(amp_snr), + "r2": float(r2), + "area": float(amp * sigma), + } + + gaussian_validation_gate = max(2.2 * float(floor), 0.25 * float(snr_threshold)) + strong_keep_idx = { + int(pk_idx) + for pk_idx, _, _, snr, _, _ in display_peaks_with_prom + if float(snr) >= gaussian_validation_gate + } + + gauss_components = [] + for pk_idx, _, _, snr, _, _ in display_peaks_with_prom: + if float(snr) >= gaussian_validation_gate: + continue + fit = _fit_local_gaussian(int(pk_idx)) + if fit is None: + continue + # Keep only physically plausible and sufficiently Gaussian components. + if fit["r2"] < 0.58: + continue + if fit["amp_snr"] < max(2.0, 0.75 * float(floor)): + continue + if fit["sigma"] < max(1.5 * float(self.sampling[0]), 0.010) or fit["sigma"] > 0.18: + continue + gauss_components.append(fit) + + gaussian_validated = bool(gauss_components) + if gauss_components: + # Merge overlapping Gaussian components and keep the stronger one. + gauss_components.sort(key=lambda comp: comp["mu"]) + merged = [] + for comp in gauss_components: + if not merged: + merged.append(comp) + continue + prev = merged[-1] + # Keep neighbouring components separate unless they are truly + # unresolved by both center spacing and valley separation. + center_gap = abs(float(comp["mu"]) - float(prev["mu"])) + overlap_thresh = 1.15 * min(float(prev["sigma"]), float(comp["sigma"])) + + prev_idx = int(prev["idx"]) + comp_idx = int(comp["idx"]) + lo, hi = sorted((prev_idx, comp_idx)) + if hi - lo <= 1: + valley = float(min(spec[lo], spec[hi])) + else: + valley = float(np.min(spec[lo : hi + 1])) + smaller_amp = max(min(float(prev["amp"]), float(comp["amp"])), 1e-12) + valley_relief = (smaller_amp - valley) / smaller_amp + + unresolved_pair = center_gap <= overlap_thresh and valley_relief < 0.22 + if unresolved_pair: + if (comp["area"] > prev["area"]) or ( + comp["area"] == prev["area"] and comp["amp_snr"] > prev["amp_snr"] + ): + merged[-1] = comp + else: + merged.append(comp) + + keep_idx = {int(comp["idx"]) for comp in merged} + keep_idx.update(strong_keep_idx) + display_peaks_with_prom = [ + item for item in display_peaks_with_prom if int(item[0]) in keep_idx + ] + else: + # If weak-peak Gaussian fitting did not validate any component, + # still keep strong visual peaks. + if strong_keep_idx: + display_peaks_with_prom = [ + item for item in display_peaks_with_prom if int(item[0]) in strong_keep_idx + ] + + # Prune weak shoulder-like bumps near a much stronger neighbouring peak. + # This prevents over-detecting pseudo-peaks on the flanks of broad peaks. + if len(display_peaks_with_prom) > 1 and not gaussian_validated: + by_energy = sorted(display_peaks_with_prom, key=lambda item: item[2]) + shoulder_window = max(8.0 * float(self.sampling[0]), 0.22) + weak_snr_ratio = 0.45 + weak_prom_ratio = 0.65 + local_prom_floor = max(3.5, 1.10 * float(floor)) + + pruned = [] + for idx, h, en, snr, p_snr, sal in by_energy: + strongest_neighbor = None + for o_idx, o_h, o_en, o_snr, o_p_snr, o_sal in by_energy: + if o_idx == idx: + continue + if abs(float(o_en) - float(en)) > shoulder_window: + continue + if strongest_neighbor is None or o_snr > strongest_neighbor[0]: + strongest_neighbor = (float(o_snr), float(o_p_snr), float(o_en)) + + if strongest_neighbor is None: + pruned.append((idx, h, en, snr, p_snr, sal)) + continue + + nbr_snr, nbr_prom, _ = strongest_neighbor + is_weak_shoulder = ( + float(snr) < weak_snr_ratio * max(nbr_snr, 1e-12) + and float(p_snr) < weak_prom_ratio * max(nbr_prom, 1e-12) + and float(p_snr) < local_prom_floor + ) + if not is_weak_shoulder: + pruned.append((idx, h, en, snr, p_snr, sal)) + + display_peaks_with_prom = pruned + + display_peaks = [(idx, h, en, snr) for idx, h, en, snr, _, _ in display_peaks_with_prom] + display_peaks.sort(key=lambda item: item[3], reverse=True) + + def candidate_matches(peak_energy, snr, allowed_elements=None): + matches = [] + for element_name, lines in all_info.items(): + if allowed_elements is not None and element_name not in allowed_elements: + continue + for line_name, line_info in lines.items(): + if not type(self)._line_allowed_for_element( + element_name, line_name, edge_filters + ): + continue + line_weight = float(line_info.get("weight", 0.5)) + line_energy = float(line_info["energy (keV)"]) + shell = type(self)._line_shell(line_name) + tol = ( + tolerance * 0.5 + if shell == "M" and ("Ma" not in line_name and "Mb" not in line_name) + else tolerance + ) + distance = abs(peak_energy - line_energy) + if line_weight < min_line_weight or distance > tol: + continue + score = type(self)._peak_confidence( + snr, line_weight, distance, tolerance + ) * type(self)._shell_preference_factor(shell) + matches.append( + { + "element": str(element_name), + "line": str(line_name), + "weight": line_weight, + "distance": distance, + "score": float(score), + "shell": shell, + } + ) + matches.sort(key=lambda m: m["score"], reverse=True) + return matches + + peak_matches = [] + for peak_idx, height, peak_energy, snr in display_peaks: + matches = candidate_matches(peak_energy, snr, search_elements) + if not matches: + continue + best = matches[0] + peak_matches.append( + ( + peak_idx, + height, + peak_energy, + snr, + best["element"], + f"{best['element']} {best['line']}", + best["distance"], + best["line"], + best["weight"], + best["score"], + ) + ) + + energy_min = float(np.min(E)) if len(E) else float(self.origin[0]) + energy_max = float(np.max(E)) if len(E) else energy_min + + def observable_shells_for_element(element): + shells = set() + for line_name, line_info in (all_info.get(str(element), {}) or {}).items(): + if not type(self)._line_allowed_for_element(str(element), line_name, edge_filters): + continue + shell = type(self)._line_shell(line_name) + if shell not in {"K", "L", "M"}: + continue + try: + line_energy = float(line_info.get("energy (keV)", line_info.get("energy"))) + except (TypeError, ValueError): + continue + if energy_min <= line_energy <= energy_max: + shells.add(shell) + return shells + + def strongest_observable_line(element, shell_name): + candidates = [] + for line_name, line_info in (all_info.get(str(element), {}) or {}).items(): + if not type(self)._line_allowed_for_element(str(element), line_name, edge_filters): + continue + if type(self)._line_shell(line_name) != shell_name: + continue + try: + line_energy = float(line_info.get("energy (keV)", line_info.get("energy"))) + line_weight = float(line_info.get("weight", 0.0)) + except (TypeError, ValueError): + continue + if energy_min <= line_energy <= energy_max: + candidates.append((line_weight, line_energy, str(line_name))) + return max(candidates, default=None) + + def shell_has_observable_support(element, shell_name): + strongest = strongest_observable_line(element, shell_name) + if strongest is None: + return True + + _, target_energy, _ = strongest + support_window = max(float(tolerance), 3.0 * float(self.sampling[0]), 0.04) + + for _, _, peak_energy, _ in display_peaks: + dist_to_target = abs(float(peak_energy) - float(target_energy)) + if dist_to_target > support_window: + continue + # Nearby spectral support exists for this shell line. + return True + + local_idx = np.where(np.abs(E - float(target_energy)) <= support_window)[0] + if local_idx.size == 0: + return False + + local_snr = float(np.nanmax(spec[local_idx]) / max(float(background_std), 1e-9)) + weak_bump_threshold = max(2.5, 0.35 * float(snr_threshold)) + if local_snr < weak_bump_threshold: + return False + + return True + + element_stats, line_evidence = {}, {} + for ( + _, + _, + peak_energy, + snr, + element, + _, + distance, + line_name, + line_weight, + conf, + ) in peak_matches: + if search_elements is not None and element not in search_elements: + continue + shell = type(self)._line_shell(line_name) + stats = element_stats.setdefault( + element, + { + "raw_conf": 0.0, + "shells": set(), + "lines": set(), + "strong_matches": 0, + "match_count": 0, + "best_match_conf": 0.0, + "best_match_snr": 0.0, + "best_match_energy": 0.0, + "best_match_distance": float("inf"), + "best_match_weight": 0.0, + "best_match_shell": "?", + }, + ) + label = f"{element} {line_name}" + evidence = line_evidence.setdefault( + label, + { + "match_count": 0, + "strong_matches": 0, + "best_conf": 0.0, + "best_snr": 0.0, + "energies": [], + }, + ) + + stats["raw_conf"] += float(conf) + stats["shells"].add(shell) + stats["lines"].add(line_name) + stats["match_count"] += 1 + stats["strong_matches"] += int( + snr > snr_threshold and distance < distance_threshold_for_sample + ) + if conf > stats["best_match_conf"]: + stats.update( + { + "best_match_conf": float(conf), + "best_match_snr": float(snr), + "best_match_energy": float(peak_energy), + "best_match_distance": float(distance), + "best_match_weight": float(line_weight), + "best_match_shell": shell, + } + ) + + evidence["match_count"] += 1 + evidence["energies"].append(float(peak_energy)) + evidence["strong_matches"] += int( + snr > snr_threshold and distance < distance_threshold_for_sample + ) + if conf > evidence["best_conf"]: + evidence["best_conf"] = float(conf) + evidence["best_snr"] = float(snr) + + # Collect all candidate elements across every display peak (not just best-match winners) + all_candidate_shells: dict[str, set] = {} + for peak_idx, height, peak_energy, snr in display_peaks: + for m in candidate_matches(peak_energy, snr, search_elements): + shell = m["shell"] + if shell in {"K", "L", "M"}: + all_candidate_shells.setdefault(m["element"], set()).add(shell) + + shell_hierarchy = ["K", "L", "M"] # descending energy order + + demoted_elements = set() + for element, shells in all_candidate_shells.items(): + # Prefer shells that actually won first-pass matches for this element. + # Using all candidate shells can falsely trigger higher-shell checks + # (e.g. Cu candidate L-lines) even when the element is only evidenced by K-lines. + observed_shells = set((element_stats.get(element, {}) or {}).get("shells", set())) & { + "K", + "L", + "M", + } + matched_shells = observed_shells if observed_shells else (shells & {"K", "L", "M"}) + observable = observable_shells_for_element(element) + eliminate = False + for matched_shell in matched_shells: + shell_idx = shell_hierarchy.index(matched_shell) + # Every higher-energy shell that is observable must have spectral support. + # Also verify that the supporting shell is genuine by checking its own + # strong secondary lines — prevents a coincidental neighbouring peak + # (e.g. Cu Kb1,3 near Os La1) from falsely satisfying the L-shell check. + for higher_shell in shell_hierarchy[:shell_idx]: + if higher_shell not in observable: + continue + if not shell_has_observable_support(element, higher_shell): + eliminate = True + break + if eliminate: + break + if eliminate: + demoted_elements.add(str(element)) + + element_confidence = {} + # --- Intensity ratio check and multi-peak pattern boost --- + for element, stats in element_stats.items(): + valid_shells = {shell for shell in stats["shells"] if shell in {"K", "L", "M"}} + shell_bonus = float(np.sqrt(max(1, len(valid_shells)))) + line_bonus = 1.0 + 0.30 * float(np.log1p(max(0, len(stats["lines"]) - 1))) + strong_bonus = 1.0 + 0.40 * float(np.log1p(stats["strong_matches"])) + major_bonus = 1.20 if {"K", "L"} & valid_shells else 1.0 + + # Intensity ratio logic + element_peak_intensities = {} + for ( + _, + height, + peak_energy, + snr, + el, + _, + distance, + line_name, + line_weight, + conf, + ) in peak_matches: + if el == element: + element_peak_intensities.setdefault(line_name, []).append(float(height)) + # Only consider if at least 2 lines detected + if len(element_peak_intensities) >= 2: + observed = [] + expected = [] + for line_name, intensities in element_peak_intensities.items(): + observed.append(max(intensities)) + weight = all_info.get(element, {}).get(line_name, {}).get("weight", None) + try: + expected.append(float(weight) if weight is not None else 0.0) + except Exception: + expected.append(0.0) + obs_sum = sum(observed) + exp_sum = sum(expected) + if obs_sum > 0 and exp_sum > 0: + observed_norm = [x / obs_sum for x in observed] + expected_norm = [x / exp_sum for x in expected] + ratio_score = 1.0 - ( + sum(abs(o - e) for o, e in zip(observed_norm, expected_norm)) / 2.0 + ) + ratio_factor = 1.0 + if ratio_score > 0.7: + ratio_factor = 1.15 + 0.25 * (ratio_score - 0.7) + elif ratio_score < 0.4: + ratio_factor = 0.7 + 0.5 * ratio_score + else: + ratio_factor = 1.0 + else: + ratio_factor = 1.0 + + # --- Strong pattern boost: if both main lines for K, L, or M are matched, multiply confidence by 3 (dominates score) --- + matched_lines = set(element_peak_intensities.keys()) + k_lines = {"Ka1", "Kb1"} + l_lines = {"La1", "Lb1"} + m_lines = {"Ma1", "Mb1"} + pattern_factor = 1.0 + if k_lines.issubset(matched_lines): + pattern_factor = 3.0 + elif l_lines.issubset(matched_lines): + pattern_factor = 2.5 + elif m_lines.issubset(matched_lines): + pattern_factor = 2.0 + + element_confidence[element] = ( + stats["raw_conf"] + * shell_bonus + * line_bonus + * strong_bonus + * major_bonus + * ratio_factor + * pattern_factor + ) + + detected_elements = set() + if element_confidence: + conf_values = np.asarray(list(element_confidence.values()), dtype=float) + poisson_mdl_snr = 3.0 + cutoff = max(float(np.percentile(conf_values, 45)), 0.30 * float(conf_values.max())) + for element, confidence in element_confidence.items(): + stats = element_stats[element] + lines = set(stats["lines"]) + # Criterion 1: Both main lines matched (pattern match) → always autodetect + strong_pattern = ( + {"Ka1", "Kb1"}.issubset(lines) + or {"La1", "Lb1"}.issubset(lines) + or {"Ma1", "Mb1"}.issubset(lines) + ) and confidence > 0 + # Criterion 2: High confidence above cutoff and sufficient SNR + high_confidence = ( + confidence >= cutoff and stats["best_match_snr"] >= poisson_mdl_snr + ) + if strong_pattern or high_confidence: + detected_elements.add(element) + + dominant_elements = set() + if element_confidence: + conf_values = np.asarray(list(element_confidence.values()), dtype=float) + conf_floor = max(float(np.median(conf_values)) if conf_values.size else 0.0, 1e-9) + conf_p80 = float(np.percentile(conf_values, 80)) if conf_values.size > 1 else 0.0 + for element, confidence in element_confidence.items(): + stats = element_stats.get(element, {}) + repeat_support = ( + int(stats.get("match_count", 0)) >= 2 + or int(stats.get("strong_matches", 0)) >= 1 + ) + if confidence >= conf_p80 and confidence >= 1.8 * conf_floor and repeat_support: + dominant_elements.add(element) + + anchor_elements = { + element + for element in detected_elements + if element in element_stats + and element_stats[element].get("best_match_energy", 0.0) >= 6.0 + and element_stats[element].get("best_match_weight", 0.0) >= 0.8 + } + max_detected_conf = max( + [element_confidence.get(el, 0.0) for el in detected_elements], default=0.0 + ) + + def prior_boost(element): + prior = float(element_confidence.get(element, 0.0)) / max( + float(max_detected_conf), 1e-9 + ) + factor = 1.0 + 0.5 * prior + if prior >= 0.90: + factor *= 1.9 + elif prior >= 0.75: + factor *= 1.5 + elif prior >= 0.55: + factor *= 1.2 + return prior, factor + + def consistency_boost(element, line_name, peak_energy): + is_detected = element in detected_elements + is_dominant = element in dominant_elements + if is_dominant: + scale = 1.0 + elif is_detected: + scale = 0.80 + else: + scale = 0.65 + # First, check evidence for this exact line + evidence = line_evidence.get(f"{element} {line_name}") + if evidence and any( + abs(float(peak_energy) - float(prev)) <= 0.04 + for prev in evidence.get("energies", []) + ): + best_conf = float(evidence.get("best_conf", 0.0)) + best_snr = float(evidence.get("best_snr", 0.0)) + strong = int(evidence.get("strong_matches", 0)) + line_weight = float( + (all_info.get(element, {}).get(line_name, {}) or {}).get("weight", 0.5) + ) + tier = 1.0 + 0.7 * max(0.0, line_weight - 0.35) + if strong >= 1 and best_conf >= 1.4: + return min(3.2, scale * 2.4 * tier) + if best_conf >= 1.1 and best_snr >= max(floor, 0.75 * snr_threshold): + return min(2.6, scale * 1.9 * tier) + if best_conf >= 0.8: + return min(2.0, scale * 1.5 * tier) + return min(1.5, scale * 1.2 * tier) + # Element was matched via a different line — boost secondary lines of this element + stats = element_stats.get(element, {}) + elem_conf = float(element_confidence.get(element, 0.0)) + elem_strong = int(stats.get("strong_matches", 0)) + line_weight = float( + (all_info.get(element, {}).get(line_name, {}) or {}).get("weight", 0.5) + ) + tier = 1.0 + 0.5 * max(0.0, line_weight - 0.35) + if elem_strong >= 1 and elem_conf >= 1.4: + return min(2.4, scale * 1.8 * tier) + if elem_conf >= 1.1: + return min(2.0, scale * 1.5 * tier) + if elem_conf >= 0.8: + return min(1.6, scale * 1.2 * tier) + return min(1.3, scale * 1.1 * tier) + + def dominant_boost(element): + if element not in dominant_elements: + return 1.0 + prior, _ = prior_boost(element) + stats = element_stats.get(element, {}) + repeat_support = max( + int(stats.get("strong_matches", 0)), max(0, int(stats.get("match_count", 0)) - 1) + ) + base = 2.30 if prior >= 0.90 else 1.85 if prior >= 0.75 else 1.45 + if repeat_support >= 2: + base *= 1.10 + return min(base, 2.60) + + def reranked_matches(peak_energy, snr, allowed_elements=None, top_k=None): + # Compute which elements have both main lines matched (pattern boost) + element_to_lines = {} + for _, _, _, _, el, _, _, ln, _, _ in peak_matches: + element_to_lines.setdefault(el, set()).add(ln) + + def _has_main_line(lines, target): + # Accept compact aliases from x_ray_lines.csv such as La1,2 or Kb1,3. + for ln in lines: + name = str(ln) + if name == target: + return True + if target in {"Ka1", "Kb1", "La1", "Lb1", "Ma1", "Mb1"} and name.startswith( + target + "," + ): + return True + return False + + def _canonical_aliases(target): + if target == "La1": + return ("La1", "La1,2") + if target == "Lb1": + return ("Lb1",) + if target == "Ka1": + return ("Ka1",) + if target == "Kb1": + return ("Kb1", "Kb1,3") + return (target,) + + def _line_evidence_strength(element, target): + best = 0.0 + for alias in _canonical_aliases(target): + ev = line_evidence.get(f"{element} {alias}") + if not ev: + continue + best_conf = float(ev.get("best_conf", 0.0)) + strong = float(ev.get("strong_matches", 0)) + count = float(ev.get("match_count", 0)) + score = best_conf + 0.45 * strong + 0.15 * count + if score > best: + best = score + return best + + def _l_support_strength(element): + return _line_evidence_strength(element, "La1") + _line_evidence_strength( + element, "Lb1" + ) + + candidates = candidate_matches(peak_energy, snr, allowed_elements) + + # For weak peaks, also consider a relaxed-distance pass for already + # detected/dominant elements. This keeps context-consistent lines in + # play (e.g. Te continuation) even when local calibration/noise shifts + # push them slightly beyond the strict tolerance window. + weak_peak = float(snr) < max(2.5 * float(floor), 0.30 * float(snr_threshold)) + if weak_peak: + relaxed_tol = max(float(tolerance), 0.30) + context_elements = set(map(str, detected_elements | dominant_elements)) + if context_elements: + for element_name, lines in all_info.items(): + element_name = str(element_name) + if element_name not in context_elements: + continue + if allowed_elements is not None and element_name not in allowed_elements: + continue + for line_name, line_info in lines.items(): + if not type(self)._line_allowed_for_element( + element_name, line_name, edge_filters + ): + continue + line_weight = float(line_info.get("weight", 0.5)) + line_energy = float(line_info["energy (keV)"]) + shell = type(self)._line_shell(line_name) + tol = ( + relaxed_tol * 0.5 + if shell == "M" + and ("Ma" not in line_name and "Mb" not in line_name) + else relaxed_tol + ) + distance = abs(float(peak_energy) - line_energy) + if line_weight < min_line_weight or distance > tol: + continue + score = type(self)._peak_confidence( + snr, line_weight, distance, relaxed_tol + ) * type(self)._shell_preference_factor(shell) + candidates.append( + { + "element": element_name, + "line": str(line_name), + "weight": line_weight, + "distance": distance, + "score": float(score), + "shell": shell, + } + ) + + # De-duplicate exact element/line candidates, keeping highest score. + if candidates: + uniq = {} + for c in candidates: + key = (str(c["element"]), str(c["line"])) + prev = uniq.get(key) + if prev is None or float(c["score"]) > float(prev["score"]): + uniq[key] = c + candidates = list(uniq.values()) + candidates.sort(key=lambda m: m["score"], reverse=True) + + # Performance guard: the precedence logic below is O(n^2) over + # candidates. Keep the strongest candidates, but always retain + # context-important elements (confirmed/dominant/preferred). + max_candidates = 48 + if len(candidates) > max_candidates: + context_elements = set( + map(str, detected_elements | dominant_elements | preferred_elements) + ) + trimmed = list(candidates[:max_candidates]) + if context_elements: + kept_keys = {(str(c["element"]), str(c["line"])) for c in trimmed} + for c in candidates[max_candidates:]: + key = (str(c["element"]), str(c["line"])) + if key in kept_keys: + continue + if str(c["element"]) in context_elements: + trimmed.append(c) + kept_keys.add(key) + candidates = trimmed + + element_has_l_support = {} + element_has_l_pair = {} + for el, lines in element_to_lines.items(): + has_la = _has_main_line(lines, "La1") + has_lb = _has_main_line(lines, "Lb1") + element_has_l_support[str(el)] = has_la or has_lb + element_has_l_pair[str(el)] = has_la and has_lb + + # Guard against boosted confirmed elements stealing a peak from a + # much-closer strong K/L candidate (e.g. Cu Ka around 8 keV). + distance_anchor = None + for candidate in candidates: + if candidate["shell"] not in {"K", "L"}: + continue + if float(candidate["weight"]) < 0.30: + continue + if float(candidate["score"]) < 0.45: + continue + distance_anchor = max(float(candidate["distance"]), 1e-9) + break + + scored = [] + for match in candidates: + element, line_name, shell = match["element"], match["line"], match["shell"] + is_demoted = str(element) in demoted_elements + + minor_l_penalty = 1.0 + + # Logical guard for L-series continuation: do not let an orphan + # minor L-line assignment (Ll/Lg/Lb2) outrank a closer candidate + # from an element that already shows L-series support (La/Lb). + if ( + shell == "L" + and line_name in {"Ll", "Lg1", "Lb2,15"} + and not element_has_l_support.get(str(element), False) + ): + supported_closer_exists = False + for other in candidates: + other_el = str(other["element"]) + if other_el == str(element): + continue + if other["shell"] != "L": + continue + if not element_has_l_support.get(other_el, False): + continue + if float(other["distance"]) < float(match["distance"]): + supported_closer_exists = True + break + if supported_closer_exists: + minor_l_penalty *= 0.55 + + # Stricter logical precedence: an orphan minor L-line must not + # outrank a closer L-line from an element with an established + # La/Lb pair in this spectrum. + if ( + shell == "L" + and line_name in {"Ll", "Lg1", "Lb2,15"} + and not element_has_l_pair.get(str(element), False) + ): + paired_closer_exists = False + for other in candidates: + other_el = str(other["element"]) + if other_el == str(element): + continue + if other["shell"] != "L": + continue + if not element_has_l_pair.get(other_el, False): + continue + if float(other["distance"]) <= float(match["distance"]): + paired_closer_exists = True + break + if paired_closer_exists: + minor_l_penalty *= 0.70 + + # Evidence-strength precedence for minor L-lines: + # if another element has materially stronger La/Lb evidence and + # a comparable-or-better distance match, do not keep the weaker + # minor-L candidate as a possible winner. + if shell == "L" and line_name in {"Ll", "Lg1", "Lb2,15"}: + this_el = str(element) + this_dist = float(match["distance"]) + this_support = _l_support_strength(this_el) + beaten_by_stronger = False + for other in candidates: + other_el = str(other["element"]) + if other_el == this_el or other["shell"] != "L": + continue + other_support = _l_support_strength(other_el) + if other_support <= max(0.4, this_support + 0.30): + continue + # Accept up to 30 eV slack so support can break near ties. + if float(other["distance"]) <= this_dist + 0.03: + beaten_by_stronger = True + break + if beaten_by_stronger: + minor_l_penalty *= 0.60 + + prior, prior_factor = prior_boost(element) + pref = 1.35 if element in preferred_elements else 1.0 + anchor = 1.15 if element in anchor_elements and shell in {"K", "L"} else 1.0 + dom = dominant_boost(element) + # Pattern boost: if both main lines for K, L, or M are matched by detected peaks, boost candidate score + lines_matched = element_to_lines.get(element, set()) + has_k_pair = _has_main_line(lines_matched, "Ka1") and _has_main_line( + lines_matched, "Kb1" + ) + has_l_pair = _has_main_line(lines_matched, "La1") and _has_main_line( + lines_matched, "Lb1" + ) + has_m_pair = _has_main_line(lines_matched, "Ma1") and _has_main_line( + lines_matched, "Mb1" + ) + pattern_factor = 1.0 + if has_k_pair: + pattern_factor = 3.0 + elif has_l_pair: + pattern_factor = 2.5 + elif has_m_pair: + pattern_factor = 2.0 + + if shell == "M": + prior_factor = 1.0 + 0.3 * prior + dom = min(dom, 1.30) + + # Guard against introducing new singleton elements on weak peaks. + # If an element is not already detected/dominant and only appears + # as an isolated line, require a very tight distance match. + if ( + weak_peak + and element not in detected_elements + and element not in dominant_elements + ): + matched_lines_for_el = element_to_lines.get(element, set()) + # Consider an element supported if it already has any matched + # line in the current spectrum, or strong line evidence from + # first-pass matching. This avoids dropping context-consistent + # secondary lines (e.g. Cu Kb1,3 after Cu Ka1 is matched). + element_line_strength = 0.0 + for ln in matched_lines_for_el: + ev = line_evidence.get(f"{element} {ln}") + if not ev: + continue + best_conf = float(ev.get("best_conf", 0.0)) + strong = float(ev.get("strong_matches", 0)) + count = float(ev.get("match_count", 0)) + element_line_strength = max( + element_line_strength, best_conf + 0.4 * strong + 0.1 * count + ) + + has_support = len(matched_lines_for_el) >= 1 or element_line_strength >= 0.8 + if not has_support: + if float(match["distance"]) > 0.035: + continue + prior_factor *= 0.65 + # For confirmed elements (detected or dominant), the line_weight prior is irrelevant — + # we already know the element is present. Use weight=1.0 and score purely on distance + # so that e.g. Cu Kb1 (weight=0.17) beats Os La1 (weight=1.0) when Cu is confirmed + # and Cu Kb1 is closer to the measured peak. + confirmed = element in detected_elements or element in dominant_elements + if confirmed: + sigma = max(float(tolerance) / 3.0, 1e-9) + distance_factor = np.exp(-0.5 * (float(match["distance"]) / sigma) ** 2) + base_score = ( + np.log1p(max(float(snr), 0.0)) + * 1.0 + * distance_factor + * type(self)._shell_preference_factor(shell) + ) + # Once an element is clearly present, prefer physically + # consistent continuation lines over introducing new + # elements for nearby ambiguous peaks. + continuation = consistency_boost(element, line_name, peak_energy) + base_score = base_score * max(1.0, min(float(continuation), 1.8)) + else: + base_score = match["score"] + consistency = consistency_boost(element, line_name, peak_energy) + # Non-confirmed elements should not gain an aggressive + # boost that steals peaks from already-confirmed elements. + base_score = base_score * min(1.0, float(consistency)) + score = ( + base_score + * prior_factor + * pref + * anchor + * dom + * pattern_factor + * minor_l_penalty + ) + + # If there is a strong nearby K/L anchor, damp long-distance + # takeovers that are caused mainly by cross-peak boosts. + if ( + confirmed + and distance_anchor is not None + and float(match["distance"]) > distance_anchor + ): + ratio = float(match["distance"]) / distance_anchor + if ratio >= 2.0: + score *= ratio**-1.6 + + scored.append({**match, "score": float(score), "demoted": bool(is_demoted)}) + + # Ranking-only policy: keep shell-inconsistent elements as options, + # but place them behind more plausible (non-demoted) candidates. + scored.sort(key=lambda m: (bool(m.get("demoted", False)), -float(m["score"]))) + if mode == "elements_preferred" and preferred_elements: + preferred = [m for m in scored if m["element"] in preferred_elements] + scored = ( + preferred + [m for m in scored if m["element"] not in preferred_elements] + if preferred + else scored + ) + + unique, seen = [], set() + for match in scored: + label = f"{match['element']} {match['line']}" + if label in seen: + continue + seen.add(label) + unique.append(match) + + if top_k is None or len(unique) <= 1: + return unique + selected = [unique[0]] + used_elements = {unique[0]["element"]} + for match in unique[1:]: + if match["element"] in used_elements: + continue + selected.append(match) + used_elements.add(match["element"]) + if len(selected) >= int(top_k): + return selected + for match in unique[1:]: + if match not in selected: + selected.append(match) + if len(selected) >= int(top_k): + break + return selected + + rematch_allowed = { + str(match[4]) for match in peak_matches if str(match[4]) not in ignored_elements + } + rematch_allowed.update(map(str, detected_elements)) + rematch_allowed.update(preferred_elements) + + refined_peak_matches = [] + for peak_idx, height, peak_energy, snr in display_peaks: + best = reranked_matches(peak_energy, snr, search_elements, top_k=1) + best = best[0] if best else None + if best is None: + continue + refined_peak_matches.append( + ( + peak_idx, + height, + peak_energy, + snr, + best["element"], + f"{best['element']} {best['line']}", + best["distance"], + best["line"], + best["weight"], + best["score"], + ) + ) + peak_matches = refined_peak_matches + + # Backfill element_confidence for elements that only appear after the + # unrestricted re-rank (e.g. not in search_elements so never entered + # element_stats in the first pass). Use the same base formula as + # _peak_confidence so the displayed value is meaningful. + for ( + _, + height, + peak_energy, + snr, + element, + _, + distance, + line_name, + line_weight, + _, + ) in peak_matches: + if element in element_confidence: + continue + sigma = max(float(tolerance) / 3.0, 1e-9) + dist_factor = float(np.exp(-0.5 * (float(distance) / sigma) ** 2)) + raw = float( + np.log1p(max(float(snr), 0.0)) * max(float(line_weight), 0.0) * dist_factor + ) + shell = type(self)._line_shell(str(line_name)) + valid_shells = {shell} & {"K", "L", "M"} + major_bonus = 1.20 if {"K", "L"} & valid_shells else 1.0 + element_confidence[element] = raw * major_bonus + + matched_elements = {str(match[4]) for match in peak_matches} + detected_elements = { + str(el) + for el in detected_elements + if str(el) in matched_elements and str(el) not in ignored_elements + } + if mode == "elements_preferred": + detected_elements.update( + str(el) for el in preferred_elements if str(el) in matched_elements + ) + refined_match_by_idx = {int(match[0]): match for match in peak_matches} + plot_peaks = display_peaks[:peaks] + plot_peak_indices = {int(pk_idx) for pk_idx, _, _, _ in plot_peaks} + + final_matches_by_element: dict[str, set[str]] = {} + for _, _, _, _, element, _, _, line_name, _, _ in peak_matches: + if element not in ignored_elements: + final_matches_by_element.setdefault(element, set()).add(str(line_name)) + + # For display purposes (table + plot), restrict to elements/lines seen in plot_peaks + plot_matches_by_element: dict[str, set[str]] = {} + for pk_idx, _, _, _, element, _, _, line_name, _, _ in peak_matches: + if int(pk_idx) in plot_peak_indices and element not in ignored_elements: + plot_matches_by_element.setdefault(element, set()).add(str(line_name)) + + candidate_elements = sorted( + str(el) for el in final_matches_by_element if str(el) not in detected_elements + ) + possible_elements = set(candidate_elements) + + plot_all_identified = set( + el + for el in (set(detected_elements) | set(candidate_elements)) + if el in plot_matches_by_element + ) + if plot_all_identified: + det_rows = [] + for element in sorted(map(str, plot_all_identified)): + conf = element_confidence.get(element, 0.0) + lines_matched = sorted(map(str, plot_matches_by_element.get(element, set()))) + if element in detected_elements: + status = "Dominant" if element in dominant_elements else "Detected" + else: + status = "Possible" + det_rows.append( + (element, status, conf, ", ".join(lines_matched) if lines_matched else "-") + ) + det_rows.sort( + key=lambda r: (0 if r[1] == "Dominant" else 1 if r[1] == "Detected" else 2, -r[2]) + ) + print(f"\n{'Element':<10} {'Confidence':<12} {'Matched Lines'}") + print("-" * 50) + for el, status, conf, lines_str in det_rows: + print(f"{el:<10} {conf:<12.3f} {lines_str}") + print("-" * 50) + else: + print("\nDetected: None") + + elements_for_color = set(detected_elements) | {str(match[4]) for match in peak_matches} + if search_elements is not None: + elements_for_color.update(map(str, search_elements)) + palette = [ + "#1f77b4", + "#d62728", + "#2ca02c", + "#9467bd", + "#ff7f0e", + "#8c564b", + "#e377c2", + "#17becf", + "#bcbd22", + "#7f7f7f", + "#003f5c", + "#7a5195", + "#ef5675", + "#ffa600", + "#2f4b7c", + ] + element_color_map = { + el: palette[i % len(palette)] for i, el in enumerate(sorted(elements_for_color)) + } + y_min = float(np.nanmin(spec)) if len(spec) else 0.0 + y_max = float(np.nanmax(spec)) if len(spec) else 1.0 + y_scale = max(max(1e-9, y_max - y_min), abs(y_max), abs(y_min), 1e-6) + y_dot = -0.04 * y_scale + + def infer_requested_color(peak_energy): + if reference_elements is None: + return None + best_element, best_distance = None, float("inf") + for element in reference_elements: + for line_name, line_info in (all_info.get(str(element), {}) or {}).items(): + if not type(self)._line_allowed_for_element( + str(element), line_name, edge_filters + ): + continue + try: + distance = abs(float(peak_energy) - float(line_info.get("energy (keV)"))) + except (TypeError, ValueError): + continue + if distance < best_distance: + best_distance, best_element = distance, str(element) + return best_element + + table_rows = [] + for peak_idx, height, peak_energy, snr in plot_peaks: + match = refined_match_by_idx.get(int(peak_idx)) + color = ( + element_color_map.get(match[4], "red") + if match is not None + else element_color_map.get(str(infer_requested_color(peak_energy)), "red") + ) + + if not in_ignore(peak_energy): + # Only plot solid lines for matched peaks (autodetected or requested elements) + if match is not None: + ax_spec.axvline( + peak_energy, color=color, linestyle="-", alpha=0.5, linewidth=1.5 + ) + else: + ax_spec.plot( + [peak_energy], + [y_dot], + marker="|", + markersize=4, + color="gray", + alpha=0.8, + linestyle="None", + ) + + if show_text and match is not None: + for grid_element, grid_energy in grid_peaks.items(): + if abs(peak_energy - grid_energy) < 0.1: + ax_spec.text( + peak_energy, + height * 0.7, + f"{grid_element}\n(grid)", + ha="center", + va="bottom", + fontsize=8, + color="gray", + style="italic", + ) + print(f"Peak at {peak_energy} keV may come from the grid.") + break + + def label_with_energy_and_ratio(label): + # label is like 'Fe Ka', want to append (energy, ratio) from all_info and observed/expected + if not label or label == "-" or label == "Unmatched" or label == "Unknown": + return label + parts = label.split() + if len(parts) < 2: + return label + element, line = parts[0], parts[1].replace("*", "") + line_info = all_info.get(element, {}).get(line, {}) + ref_energy = None + if isinstance(line_info, dict): + ref_energy = line_info.get("energy (keV)", line_info.get("energy")) + try: + ref_energy = float(ref_energy) + except (TypeError, ValueError): + ref_energy = None + label_core = label.rstrip("*") + star = "*" if label.endswith("*") else "" + if ref_energy is not None: + return f"{label_core} ({ref_energy:.3f}){star}" + else: + return label + + if match is None: + table_rows.append( + ( + peak_energy, + height, + snr, + "Unmatched" if search_elements is not None else "Unknown", + "-", + "-", + ) + ) + continue + + # Best match for the table MUST be the same element/line shown on the spectrum + # (from refined_match_by_idx). Preserve elements_only filtering for alternatives. + best_label = f"{match[4]} {match[7]}" + ranked = reranked_matches(peak_energy, snr, search_elements, top_k=3) + labels = [ + (f"{m['element']} {m['line']}", float(m["score"]), m["element"], m["line"]) + for m in ranked + ] + # If the spectrum winner appears in ranked, use that ordering; otherwise prepend it. + if not any(lbl.lower() == best_label.lower() for lbl, _, _, _ in labels): + labels = [(best_label, 0.0, match[4], match[7])] + labels + + def fmt(label): + label = ( + f"{label}*" + if requested_elements and str(label).split()[0] in requested_elements + else label + ) + return label + + remaining = [ + (label, score, elem, line) + for label, score, elem, line in labels + if label.lower() != best_label.lower() + ] + + table_rows.append( + ( + peak_energy, + height, + snr, + label_with_energy_and_ratio(fmt(best_label)), + label_with_energy_and_ratio(fmt(remaining[0][0])) + if len(remaining) > 0 + else "-", + label_with_energy_and_ratio(fmt(remaining[1][0])) + if len(remaining) > 1 + else "-", + ) + ) + + current_bottom, current_top = ax_spec.get_ylim() + padded_bottom = min(current_bottom, y_min - 0.10 * y_scale) + padded_top = max(current_top, y_max + 0.18 * y_scale) + ax_spec.set_ylim(bottom=padded_bottom, top=padded_top) + + label_candidates = [] + top_label_y = 0.99 + peak_label_y = 0.92 + # Plot reference lines (dotted) ONLY for explicitly requested elements, not for autodetected/possible + if requested_elements: + energy_min, energy_max = float(np.min(E)), float(np.max(E)) + matched_by_element = {} + for _, _, peak_energy, _, element, _, _, _, _, _ in peak_matches: + matched_by_element.setdefault(str(element), []).append(float(peak_energy)) + + for element in sorted(requested_elements): + candidates = [] + for line_name, line_info in (all_info.get(str(element), {}) or {}).items(): + if not type(self)._line_allowed_for_element( + str(element), line_name, edge_filters + ): + continue + try: + line_energy = float(line_info.get("energy (keV)")) + line_weight = float(line_info.get("weight", 0.0)) + except (TypeError, ValueError): + continue + if energy_min <= line_energy <= energy_max: + candidates.append((str(line_name), line_energy, line_weight)) + candidates = sorted( + [c for c in candidates if c[2] >= 0.05] or candidates, + key=lambda item: item[2], + reverse=True, + )[:6] + for line_name, line_energy, _ in candidates: + if in_ignore(line_energy): + continue + # Skip if already matched by a detected peak + if any( + abs(line_energy - matched_energy) <= max(0.05, 0.5 * tolerance) + for matched_energy in matched_by_element.get(str(element), []) + ): + continue + color = element_color_map.get(str(element), "gray") + style = "--" + alpha = 0.5 + ax_spec.axvline( + line_energy, color="gray", linestyle=style, alpha=alpha, linewidth=1.2 + ) + label_candidates.append( + ( + float(line_energy), + f"{element} {line_name}", + color, + style, + float(top_label_y), + "axes_top", + 8, + "normal", + 0.8, + ) + ) + + if show_text and peak_matches: + label_allowed = set(detected_elements) | possible_elements + if requested_elements: + label_allowed.update(str(el) for el in requested_elements) + for pk_idx, _height, peak_energy, _, element, match_str, _, _, _, _ in peak_matches: + if int(pk_idx) not in plot_peak_indices: + continue + is_requested = requested_elements is not None and element in requested_elements + if element not in label_allowed or in_ignore(peak_energy): + continue + label = f"{element} {match_str.split()[-1]}" + ("*" if is_requested else "") + label_candidates.append( + ( + float(peak_energy), + label, + element_color_map.get(element, "black"), + "-", + float(peak_label_y), + "axes_peak", + 10, + "bold", + 1.0, + ) + ) + + legend_handles, legend_labels = [], set() + if show_text and label_candidates: + label_candidates.sort(key=lambda item: item[0]) + drawn_texts = [] + for ( + peak_energy, + label_text, + color, + linestyle, + y_value, + y_mode, + font_size, + font_weight, + alpha_value, + ) in label_candidates: + common = dict( + ha="center", + fontsize=font_size, + color=color, + weight=font_weight, + rotation=90, + alpha=alpha_value, + ) + if y_mode in {"axes_top", "axes_peak"}: + txt = ax_spec.text( + peak_energy, + y_value, + label_text, + va="top", + transform=ax_spec.get_xaxis_transform(), + clip_on=True, + **common, + ) + else: + txt = ax_spec.text(peak_energy, y_value, label_text, va="bottom", **common) + # Prioritize data-peak labels over top reference labels if collisions occur. + priority = 1 if y_mode in {"data", "axes_peak"} else 0 + drawn_texts.append((txt, label_text, color, linestyle, priority)) + + if drawn_texts: + fig.canvas.draw() + ax_bbox = ax_spec.get_window_extent() + renderer = fig.canvas.get_renderer() + kept_bboxes = [] + # Keep higher-priority labels first, then by x-position for stable layout. + drawn_texts.sort(key=lambda item: (-item[4], item[0].get_position()[0])) + for txt, label_text, color, linestyle, _ in drawn_texts: + txt_bbox = txt.get_window_extent(renderer=renderer) + out_of_bounds = ( + txt_bbox.x0 < ax_bbox.x0 + or txt_bbox.x1 > ax_bbox.x1 + or txt_bbox.y0 < ax_bbox.y0 + or txt_bbox.y1 > ax_bbox.y1 + ) + overlaps_kept = any(txt_bbox.overlaps(prev_bbox) for prev_bbox in kept_bboxes) + + if out_of_bounds or overlaps_kept: + txt.remove() + key = (label_text, str(color), linestyle) + if key not in legend_labels: + legend_labels.add(key) + legend_handles.append( + Line2D( + [0], + [0], + color=color, + linestyle=linestyle, + linewidth=1.5, + label=label_text, + ) + ) + else: + kept_bboxes.append(txt_bbox) + + if legend_handles: + overlap_legend = ax_spec.legend( + handles=legend_handles, loc="upper right", fontsize=8, title="Overlapping Labels" + ) + ax_spec.add_artist(overlap_legend) + + if line is not None: + x_min, x_max = ax_spec.get_xlim() + _ref_energies = [line] if isinstance(line, (int, float)) else list(line) + for ref_energy in _ref_energies: + try: + ref_energy = float(ref_energy) + except (TypeError, ValueError): + continue + # Do not let out-of-window reference lines change autoscaled limits. + if x_min <= ref_energy <= x_max: + ax_spec.axvline( + ref_energy, color="black", linestyle="--", linewidth=1.2, zorder=3 + ) + ax_spec.set_xlim(x_min, x_max) + + fig.tight_layout() + plt.show() + + sorted_table_rows = sorted(table_rows, key=lambda item: item[0]) + print( + f"{'Energy (keV)':<12} {'Intensity':<12} {'SNR':<8} {'Best Match':<22} {'Alt 2':<22} {'Alt 3':<22}" + ) + print("-" * 105) + for peak_energy, height, snr, best_match, alt_2, alt_3 in sorted_table_rows: + print( + f"{peak_energy:<12.3f} {height:<12.2f} {snr:<8.1f} {best_match:<22} {alt_2:<22} {alt_3:<22}" + ) + print("-" * 105) + print( + f"{len(plot_peaks)} of {len(display_peaks)} peaks above " + f"floor={floor:.1f}, snr_threshold={snr_threshold:.1f} displayed.\n" + ) + + if return_details: + return { + "figure": fig, + "axes": (ax_img, ax_spec), + "detected_elements": sorted(detected_elements), + "element_confidence": element_confidence, + "display_peaks": display_peaks, + "peak_matches": peak_matches, + "floor": floor, + "snr_quantile_floor": floor, + "snr_min": floor, + "snr_threshold": snr_threshold, + } + return fig, (ax_img, ax_spec) + + def _fit_mean_model_pytorch( + self, + energy_axis, + spectrum_raw, + elements_to_fit, + peak_width, + polynomial_background_degree, + num_iters, + optimizer, + lr, + loss_name, + normalize_target, + default_lr_adam, + default_lr_lbfgs, + verbose=False, + ): + """Fit a single mean spectrum using the PyTorch EDS model.""" + target = spectrum_raw + spectrum_offset = torch.tensor(0.0, dtype=spectrum_raw.dtype, device=spectrum_raw.device) + spectrum_scale = torch.tensor(1.0, dtype=spectrum_raw.dtype, device=spectrum_raw.device) + if normalize_target: + spectrum_offset = spectrum_raw.min() + spectrum_scale = torch.clamp(spectrum_raw.max() - spectrum_offset, min=1e-8) + target = (spectrum_raw - spectrum_offset) / spectrum_scale + + background = PolynomialBackground( + energy_axis, + degree=polynomial_background_degree, + ) + peaks = GaussianPeaks( + energy_axis, + peak_width=peak_width, + elements_to_fit=elements_to_fit, + ) + model = EDSModel(peaks, background) + model = model.to(device=energy_axis.device, dtype=energy_axis.dtype) + if len(model.peak_model.element_names) == 0: + raise ValueError("No elements found in the selected energy range/elements_to_fit.") + + optimizer_name = optimizer.lower() + if optimizer_name == "adam": + if lr is None: + lr = default_lr_adam + optimizer_obj = torch.optim.Adam(model.parameters(), lr=lr) + elif optimizer_name == "lbfgs": + if lr is None: + lr = default_lr_lbfgs + optimizer_obj = torch.optim.LBFGS( + model.parameters(), + lr=lr, + line_search_fn="strong_wolfe", + ) + else: + raise ValueError("optimizer must be 'lbfgs' or 'adam'") + + loss_iter = [] + for i in range(num_iters): + if optimizer_name == "lbfgs": + + def closure(): + optimizer_obj.zero_grad() + predicted = model() + loss = eds_data_loss(predicted, target, loss=loss_name) + loss.backward() + return loss + + loss = optimizer_obj.step(closure) + if not torch.is_tensor(loss): + with torch.no_grad(): + loss = eds_data_loss(model(), target, loss=loss_name) + else: + optimizer_obj.zero_grad() + predicted = model() + loss = eds_data_loss(predicted, target, loss=loss_name) + loss.backward() + optimizer_obj.step() + + loss_iter.append(float(loss.detach().cpu().item())) + if verbose and ((i + 1) % max(1, num_iters // 10) == 0 or i == 0): + print(f"iter {i + 1:4d}/{num_iters}: loss={loss_iter[-1]:.6g}") + + with torch.no_grad(): + final_pred_target = model() + if normalize_target: + final_pred_raw = final_pred_target * spectrum_scale + spectrum_offset + else: + final_pred_raw = final_pred_target + + return { + "model": model, + "loss_history": np.asarray(loss_iter), + "final_pred_raw": final_pred_raw.detach(), + "spectrum_offset": spectrum_offset.detach(), + "spectrum_scale": spectrum_scale.detach(), + } + + def fit_spectrum_mean_pytorch( + self, + energy_range=None, + elements_to_fit=None, + peak_width=0.1, + num_iters=1000, + lr=None, + polynomial_background_degree=3, + optimizer="lbfgs", + device=None, + ): + """Fit the spatially-summed mean EDS spectrum and display results. + + A convenience wrapper around :meth:`_fit_mean_model_pytorch` that + handles device selection, energy windowing, and result visualization. + + Parameters + ---------- + energy_range : sequence[float] | None, optional + Two-element energy interval ``[emin, emax]`` in keV. If ``None``, + the full energy axis is used. + elements_to_fit : sequence[str] | None, optional + Element symbols to include in the fit. If ``None``, uses keys + from ``self.model_elements``. + peak_width : float, optional + Initial FWHM-like peak width in keV. + num_iters : int, optional + Number of optimization iterations. + lr : float | None, optional + Learning rate. If ``None``, an optimizer-specific default is used. + polynomial_background_degree : int, optional + Degree of the polynomial background basis. + optimizer : {"adam", "lbfgs"}, optional + Optimizer to use. + device : str | torch.device | None, optional + Torch device. If ``None``, uses CUDA when available. + + Returns + ------- + dict + Keys include ``loss_history``, ``fitted_spectrum``, + ``input_spectrum``, ``background_spectrum``, ``concentrations``, + ``element_names``, ``peak_widths``, ``energy_axis``, and + ``fit_range``. + """ + optimizer_name = str(optimizer).lower() + if optimizer_name not in {"adam", "lbfgs"}: + raise ValueError("optimizer must be 'lbfgs' or 'adam'") + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device) + if device.type == "cuda" and not torch.cuda.is_available(): + raise ValueError("CUDA device requested but torch.cuda.is_available() is False.") + + if elements_to_fit is None: + if not self.model_elements: + raise ValueError("elements_to_fit must be specified") + elements_to_fit = list(self.model_elements.keys()) + print(f"using model_elements {elements_to_fit}") + + energy_axis_np = self.energy_axis.copy() + energy_axis = torch.tensor(energy_axis_np, dtype=torch.float32, device=device) + spectra = torch.tensor(self.array, dtype=torch.float32, device=device) + + if energy_range is not None: + ind = (energy_axis >= energy_range[0]) & (energy_axis <= energy_range[1]) + energy_axis = energy_axis[ind] + spectra = spectra[ind] + else: + energy_range = [float(energy_axis.min().item()), float(energy_axis.max().item())] + + print("fitting spectrum globally") + spectrum_raw = spectra.sum((-1, -2)) + mean_fit = self._fit_mean_model_pytorch( + energy_axis=energy_axis, + spectrum_raw=spectrum_raw, + elements_to_fit=elements_to_fit, + peak_width=peak_width, + polynomial_background_degree=polynomial_background_degree, + num_iters=num_iters, + optimizer=optimizer_name, + lr=lr, + loss_name="mse", + normalize_target=True, + default_lr_adam=1e-3, + default_lr_lbfgs=1.0, + verbose=True, + ) + + model = mean_fit["model"] + loss_history = mean_fit["loss_history"] + spectrum_offset = mean_fit["spectrum_offset"] + spectrum_scale = mean_fit["spectrum_scale"] + with torch.no_grad(): + final_pred = mean_fit["final_pred_raw"].cpu().numpy() + shell_concs = ( + nn.functional.softplus(model.peak_model.concentrations).detach().cpu().numpy() + ) + shell_names = list(model.peak_model.shell_group_names) + shell_element_indices = ( + model.peak_model.shell_group_element_indices.detach().cpu().numpy() + ) + concs = np.zeros(len(model.peak_model.element_names), dtype=np.float32) + np.add.at(concs, shell_element_indices, shell_concs) + final_fwhm = ( + torch.nn.functional.softplus(model.peak_model.peak_width_by_peak) + .detach() + .cpu() + .numpy() + ) + background_fit = ( + (model.background_model().detach() * spectrum_scale + spectrum_offset) + .cpu() + .numpy() + ) + + print( + f"\nFinal: width median={np.median(final_fwhm):.3f} keV, " + f"min={final_fwhm.min():.3f}, max={final_fwhm.max():.3f}" + ) + + top_n = max(10, len(elements_to_fit) if elements_to_fit is not None else 0) + sorted_indices = np.argsort(concs)[::-1] + print("\nTop elements:") + for i, idx in enumerate(sorted_indices[:top_n], 1): + elem = model.peak_model.element_names[idx] + conc = concs[idx] + print(f"{i:2d}. {elem:2s}: {conc:.3f}") + + shell_top_n = max(10, min(len(shell_names), top_n)) + shell_sorted_indices = np.argsort(shell_concs)[::-1] + print("\nTop edge groups:") + for i, idx in enumerate(shell_sorted_indices[:shell_top_n], 1): + shell_name = shell_names[idx] + shell_conc = shell_concs[idx] + print(f"{i:2d}. {shell_name:>6s}: {shell_conc:.3f}") + + energy_axis_plot = energy_axis.detach().cpu().numpy() + spectrum_raw_plot = spectrum_raw.detach().cpu().numpy() + fig, ax = plt.subplots(2, 1, figsize=(10, 6)) + ax[0].plot(np.arange(loss_history.shape[0]), loss_history, color="k") + ax[0].set_title("loss") + ax[0].set_xlabel("iterations") + ax[0].set_ylabel("loss") + ax[0].set_yscale("log") + + ax[1].plot(energy_axis_plot, spectrum_raw_plot, "k-", label="Data", linewidth=1) + ax[1].plot(energy_axis_plot, final_pred, "r-", label="Fit", linewidth=2) + ax[1].plot( + energy_axis_plot, + background_fit, + "b--", + label="Background", + linewidth=1.5, + ) + ax[1].set_xlim(energy_range[0], energy_range[1]) + ax[1].legend() + ax[1].set_title("fit spectrum") + ax[1].set_xlabel("Energy (keV)") + ax[1].set_ylabel("Counts") + plt.tight_layout() + plt.show() + + return { + "loss_history": loss_history, + "fitted_spectrum": final_pred, + "input_spectrum": spectrum_raw_plot, + "background_spectrum": background_fit, + "concentrations": concs, + "element_names": model.peak_model.element_names, + "edge_concentrations": shell_concs, + "edge_names": shell_names, + "edge_element_indices": shell_element_indices, + "peak_widths": final_fwhm, + "energy_axis": energy_axis_plot, + "fit_range": energy_range, + } + + def fit_spectrum_pytorch( + self, + energy_range=None, + elements_to_fit=None, + peak_width=0.1, + num_iters=300, + num_iters_global=200, + polynomial_background_degree=3, + optimizer_global="lbfgs", + optimizer_local="lbfgs", + loss_global=None, + loss_local="poisson", + freeze_peak_width=True, + spatial_lambda=0.0, + min_total_counts=0.0, + verbose=True, + fit_mean_only=False, + show_plot=True, + lr_global=None, + lr_local=None, + device=None, + constrain_background=0.1, + ): + """Fit EDS spectra using a PyTorch model. + + Supports two workflows: + - Mean-only fitting (`fit_mean_only=True`): fit a single spectrum formed by + summing over all spatial pixels. + - Global + local fitting (`fit_mean_only=False`): fit a global mean model, + then refine concentrations/background per pixel across the full cube. + + Parameters + ---------- + energy_range : sequence[float] | None, optional + Two-element energy interval ``[emin, emax]`` in keV used for fitting. + If ``None``, the full energy axis is used. + elements_to_fit : sequence[str] | None, optional + Element symbols (or model-supported element labels) to include in the + fit. If ``None``, uses keys from ``self.model_elements``. + peak_width : float, optional + Initial peak width (FWHM-like parameter in keV) for model peaks. + num_iters : int, optional + Number of optimization iterations for mean-only mode, or local + per-pixel refinement iterations in full-cube mode. + num_iters_global : int, optional + Number of iterations for the global/mean stage in full-cube mode. + polynomial_background_degree : int, optional + Degree of polynomial background basis. + optimizer_global : {"adam", "lbfgs"}, optional + Optimizer for the global/mean stage. + optimizer_local : {"adam", "lbfgs"}, optional + Optimizer for per-pixel local fitting. + loss_global : {"poisson", "mse"} | None, optional + Global-stage data term. If ``None``, defaults to ``"mse"`` for + mean-only mode and ``"poisson"`` otherwise. + loss_local : {"poisson", "mse"}, optional + Local-stage data term (ignored when ``fit_mean_only=True``). + freeze_peak_width : bool, optional + If ``True``, keep peak widths fixed during local fitting. + spatial_lambda : float, optional + L2 spatial smoothness weight applied to abundance maps during local + fitting. Must be non-negative. + min_total_counts : float, optional + Minimum per-pixel integrated counts required for a pixel to + participate in local fitting. + verbose : bool, optional + If ``True``, print optimization progress. + fit_mean_only : bool, optional + If ``True``, run only the mean-spectrum fit and skip per-pixel + refinement. + show_plot : bool, optional + If ``True``, display diagnostic plots. + lr_global : float | None, optional + Learning rate for the global optimizer. If ``None``, an optimizer- + specific default is used. + lr_local : float | None, optional + Learning rate for the local optimizer. If ``None``, an optimizer- + specific default is used. + device : str | torch.device | None, optional + Torch device for fitting (for example ``"cpu"`` or ``"cuda"``). + If ``None``, uses CUDA when available, otherwise CPU. + constrain_background : float, optional + Background prior weight used in local fitting to keep per-pixel + background coefficients close to the globally optimized background. + Set to ``0`` to disable. This is only used when + ``fit_mean_only=False``. + + Returns + ------- + dict + Fit results. Contents depend on the selected mode. + + Mean-only mode (``fit_mean_only=True``) returns keys: + ``loss_history``, ``fitted_spectrum``, ``input_spectrum``, + ``background_spectrum``, ``concentrations``, ``element_names``, + ``edge_concentrations``, ``edge_names``, ``edge_element_indices``, + ``peak_widths``, ``energy_axis``, ``fit_range``. + + Full-cube mode (``fit_mean_only=False``) returns keys: + ``abundance_maps``, ``element_names``, ``peak_widths``, + ``loss_history``, ``global_loss_history``, ``valid_pixel_mask``, + ``energy_axis``, ``input_spectrum``, ``fitted_spectrum``, + ``background_spectrum``, ``input_spectrum_all_pixels``, + ``fitted_spectrum_all_pixels``, ``background_spectrum_all_pixels``, + ``fit_range``. + + Raises + ------ + TypeError + If ``constrain_background`` is not numeric (for example ``bool``). + ValueError + If optimizer/loss names are invalid, ``spatial_lambda < 0``, CUDA is + requested but unavailable, ``constrain_background < 0``, or no pixels + satisfy ``min_total_counts``. + """ + + def _normalize_choice(name, param_name, allowed_values): + name_norm = str(name).lower() + if name_norm not in allowed_values: + allowed_display = "', '".join(sorted(allowed_values)) + raise ValueError(f"{param_name} must be '{allowed_display}'") + return name_norm + + effective_optimizer_global = _normalize_choice( + optimizer_global, "optimizer_global", {"adam", "lbfgs"} + ) + effective_optimizer_local = _normalize_choice( + optimizer_local, "optimizer_local", {"adam", "lbfgs"} + ) + effective_loss_global = ( + _normalize_choice(loss_global, "loss_global", {"poisson", "mse"}) + if loss_global is not None + else ("mse" if fit_mean_only else "poisson") + ) + effective_loss_local = ( + _normalize_choice(loss_local, "loss_local", {"poisson", "mse"}) + if not fit_mean_only + else None + ) + + if spatial_lambda < 0: + raise ValueError("spatial_lambda must be >= 0") + + if isinstance(constrain_background, bool): + raise TypeError("constrain_background must be a non-negative float.") + try: + background_prior_lambda = float(constrain_background) + except (TypeError, ValueError) as exc: + raise TypeError("constrain_background must be a non-negative float.") from exc + if background_prior_lambda < 0: + raise ValueError("constrain_background must be >= 0") + + if elements_to_fit is None: + if not self.model_elements: + raise ValueError("elements_to_fit must be specified") + elements_to_fit = list(self.model_elements.keys()) + if verbose: + print(f"using model_elements {elements_to_fit}") + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + device = torch.device(device) + if device.type == "cuda" and not torch.cuda.is_available(): + raise ValueError("CUDA device requested but torch.cuda.is_available() is False.") + + effective_lr_global = lr_global + effective_lr_local = lr_local + + energy_axis_np = self.energy_axis.copy() + energy_axis = torch.tensor(energy_axis_np, dtype=torch.float32, device=device) + spectra = torch.tensor(self.array, dtype=torch.float32, device=device) + + if energy_range is not None: + ind = (energy_axis >= energy_range[0]) & (energy_axis <= energy_range[1]) + energy_axis = energy_axis[ind] + spectra = spectra[ind] + else: + energy_range = [float(energy_axis.min().item()), float(energy_axis.max().item())] + + if fit_mean_only: + if verbose: + print("fitting spectrum globally") + spectrum_raw = spectra.sum((-1, -2)) + mean_fit = self._fit_mean_model_pytorch( + energy_axis=energy_axis, + spectrum_raw=spectrum_raw, + elements_to_fit=elements_to_fit, + peak_width=peak_width, + polynomial_background_degree=polynomial_background_degree, + num_iters=num_iters, + optimizer=effective_optimizer_global, + lr=effective_lr_global, + loss_name=effective_loss_global, + normalize_target=True, + default_lr_adam=1e-3, + default_lr_lbfgs=1.0, + verbose=verbose, + ) + + model = mean_fit["model"] + loss_history = mean_fit["loss_history"] + spectrum_offset = mean_fit["spectrum_offset"] + spectrum_scale = mean_fit["spectrum_scale"] + with torch.no_grad(): + final_pred = mean_fit["final_pred_raw"].cpu().numpy() + shell_concs = ( + nn.functional.softplus(model.peak_model.concentrations).detach().cpu().numpy() + ) + shell_element_indices = ( + model.peak_model.shell_group_element_indices.detach().cpu().numpy() + ) + concs = np.zeros(len(model.peak_model.element_names), dtype=np.float32) + np.add.at(concs, shell_element_indices, shell_concs) + final_fwhm = ( + torch.nn.functional.softplus(model.peak_model.peak_width_by_peak) + .detach() + .cpu() + .numpy() + ) + background_fit = ( + (model.background_model().detach() * spectrum_scale + spectrum_offset) + .cpu() + .numpy() + ) + + print( + f"\nFinal: width median={np.median(final_fwhm):.3f} keV, " + f"min={final_fwhm.min():.3f}, max={final_fwhm.max():.3f}" + ) + + top_n = max(10, len(elements_to_fit) if elements_to_fit is not None else 0) + sorted_indices = np.argsort(concs)[::-1] + print("\nTop elements:") + for i, idx in enumerate(sorted_indices[:top_n], 1): + elem = model.peak_model.element_names[idx] + conc = concs[idx] + print(f"{i:2d}. {elem:2s}: {conc:.3f}") + + if show_plot: + energy_axis_plot = energy_axis.detach().cpu().numpy() + spectrum_raw_plot = spectrum_raw.detach().cpu().numpy() + fig, ax = plt.subplots(2, 1, figsize=(10, 6)) + ax[0].plot(np.arange(loss_history.shape[0]), loss_history, color="k") + ax[0].set_title("loss") + ax[0].set_xlabel("iterations") + ax[0].set_ylabel("loss") + ax[0].set_yscale("log") + + ax[1].plot(energy_axis_plot, spectrum_raw_plot, "k-", label="Data", linewidth=1) + ax[1].plot(energy_axis_plot, final_pred, "r-", label="Fit", linewidth=2) + ax[1].plot( + energy_axis_plot, + background_fit, + "b--", + label="Background", + linewidth=1.5, + ) + ax[1].set_xlim(energy_range[0], energy_range[1]) + ax[1].legend() + ax[1].set_title("fit spectrum") + ax[1].set_xlabel("Energy (keV)") + ax[1].set_ylabel("Counts") + plt.tight_layout() + plt.show() + + return { + "loss_history": loss_history, + "fitted_spectrum": final_pred, + "input_spectrum": spectrum_raw.detach().cpu().numpy(), + "background_spectrum": background_fit, + "concentrations": concs, + "element_names": model.peak_model.element_names, + "edge_concentrations": shell_concs, + "edge_names": list(model.peak_model.shell_group_names), + "edge_element_indices": shell_element_indices, + "peak_widths": final_fwhm, + "energy_axis": energy_axis.detach().cpu().numpy(), + "fit_range": energy_range, + } + + n_energy, n_y, n_x = spectra.shape + n_pixels = n_y * n_x + spectra_flat = spectra.permute(1, 2, 0).reshape(n_pixels, n_energy) + + total_counts = spectra_flat.sum(dim=1) + valid_pixel_mask = total_counts >= float(min_total_counts) + if not torch.any(valid_pixel_mask): + raise ValueError("No pixels satisfy min_total_counts. Lower threshold and retry.") + + mean_spectrum = spectra_flat[valid_pixel_mask].mean(dim=0) + + if verbose: + print("fitting spectrum globally") + global_fit = self._fit_mean_model_pytorch( + energy_axis=energy_axis, + spectrum_raw=mean_spectrum, + elements_to_fit=elements_to_fit, + peak_width=peak_width, + polynomial_background_degree=polynomial_background_degree, + num_iters=num_iters_global, + optimizer=effective_optimizer_global, + lr=effective_lr_global, + loss_name=effective_loss_global, + normalize_target=True, + default_lr_adam=1e-3, + default_lr_lbfgs=1.0, + verbose=verbose, + ) + global_model = global_fit["model"] + global_loss_history = global_fit["loss_history"] + global_scale = global_fit["spectrum_scale"].detach() + global_offset = global_fit["spectrum_offset"].detach() + global_fitted_spectrum = global_fit["final_pred_raw"].detach().cpu().numpy() + + n_elements = len(global_model.peak_model.element_names) + with torch.no_grad(): + global_conc_shell = ( + nn.functional.softplus(global_model.peak_model.concentrations).detach() + * global_scale + ) + shell_element_indices = global_model.peak_model.shell_group_element_indices + global_conc = torch.zeros( + n_elements, + dtype=global_conc_shell.dtype, + device=global_conc_shell.device, + ) + global_conc.index_add_(0, shell_element_indices, global_conc_shell) + global_bg_coeffs = global_model.background_model.coeffs.detach() * global_scale + if global_bg_coeffs.numel() > 0: + global_bg_coeffs = global_bg_coeffs.clone() + global_bg_coeffs[0] = global_bg_coeffs[0] + global_offset + global_peak_width_params = global_model.peak_model.peak_width_by_peak.detach().clone() + + peak_energies = global_model.peak_model.peak_energies + peak_weights = global_model.peak_model.peak_weights + peak_element_indices = global_model.peak_model.peak_element_indices + energy_step = float(global_model.peak_model.energy_step) + + background_basis = polynomial_energy_basis( + energy_axis, degree=polynomial_background_degree + ) + + mean_total = torch.clamp(mean_spectrum.sum(), min=1e-8) + pixel_scales = (total_counts / mean_total).unsqueeze(1) + conc_init = torch.clamp( + global_conc.unsqueeze(0) * pixel_scales, + min=1e-3, + ) + conc_init = torch.clamp( + conc_init * (1.0 + 0.02 * torch.randn_like(conc_init)), + min=1e-3, + ) + + conc_logits = nn.Parameter(inverse_softplus(conc_init)) + bg_coeffs_init = global_bg_coeffs.unsqueeze(0).repeat(n_pixels, 1) * pixel_scales + bg_coeffs = nn.Parameter(bg_coeffs_init.clone()) + + if freeze_peak_width: + peak_width_params = global_peak_width_params + else: + peak_width_params = nn.Parameter(global_peak_width_params.clone()) + + if freeze_peak_width: + element_basis = build_element_basis( + energy_axis=energy_axis, + peak_energies=peak_energies, + peak_weights=peak_weights, + peak_element_indices=peak_element_indices, + peak_width_by_peak=peak_width_params, + n_elements=n_elements, + energy_step=energy_step, + ) + + trainable_params = [conc_logits, bg_coeffs] + if not freeze_peak_width: + trainable_params.append(peak_width_params) + + local_lr = ( + effective_lr_local + if effective_lr_local is not None + else (0.05 if effective_optimizer_local == "adam" else 1.0) + ) + + if effective_optimizer_local == "adam": + adam_param_groups = [{"params": [conc_logits], "lr": local_lr}] + adam_param_groups.append({"params": [bg_coeffs], "lr": local_lr}) + if not freeze_peak_width: + adam_param_groups.append({"params": [peak_width_params], "lr": local_lr}) + local_opt = torch.optim.Adam(adam_param_groups) + else: + local_opt = torch.optim.LBFGS( + trainable_params, + lr=local_lr, + line_search_fn="strong_wolfe", + ) + + loss_history = [] + + def _forward_model(): + basis = ( + element_basis + if freeze_peak_width + else build_element_basis( + energy_axis=energy_axis, + peak_energies=peak_energies, + peak_weights=peak_weights, + peak_element_indices=peak_element_indices, + peak_width_by_peak=peak_width_params, + n_elements=n_elements, + energy_step=energy_step, + ) + ) + conc = nn.functional.softplus(conc_logits) # (P, n_elements) + peaks_pred = conc @ basis.t() + bg_pred = bg_coeffs @ background_basis + predicted = torch.clamp(peaks_pred + bg_pred, min=1e-8, max=1e8) + return predicted, conc, bg_pred + + def _background_regularization(): + if background_prior_lambda <= 0: + return bg_coeffs.new_tensor(0.0) + + coeff_init_eval = bg_coeffs_init[valid_pixel_mask] + coeff_eval = bg_coeffs[valid_pixel_mask] + coeff_scale = torch.clamp(coeff_init_eval.abs().mean(), min=1e-8) + reg_prior = ((coeff_eval - coeff_init_eval) / coeff_scale).pow(2).mean() + return background_prior_lambda * reg_prior + + def _local_loss(pred_local, conc_local): + local_scale = torch.clamp(global_scale, min=1e-8) + pred_eval = pred_local[valid_pixel_mask] / local_scale + target_eval = spectra_flat[valid_pixel_mask] / local_scale + + loss_data = eds_data_loss( + pred_eval, + target_eval, + loss=effective_loss_local, + ) + loss_total = loss_data + _background_regularization() + + if spatial_lambda <= 0: + return loss_total + + conc_maps = conc_local.view(n_y, n_x, n_elements).permute(2, 0, 1) + conc_maps = conc_maps / torch.clamp(global_scale, min=1e-8) + loss_smooth = abundance_smoothness_l2(conc_maps) + return loss_total + spatial_lambda * loss_smooth + + if verbose: + print("fitting spectrum position-by-position") + for i in range(num_iters): + if effective_optimizer_local == "lbfgs": + + def _local_closure(): + local_opt.zero_grad() + pred_local, conc_local, _bg_local = _forward_model() + loss_total = _local_loss(pred_local, conc_local) + loss_total.backward() + return loss_total + + loss_value = local_opt.step(_local_closure) + if not torch.is_tensor(loss_value): + with torch.no_grad(): + pred_local, conc_local, _bg_local = _forward_model() + loss_value = _local_loss(pred_local, conc_local) + else: + local_opt.zero_grad() + pred_local, conc_local, _bg_local = _forward_model() + loss_value = _local_loss(pred_local, conc_local) + loss_value.backward() + local_opt.step() + + loss_history.append(float(loss_value.detach().cpu().item())) + if verbose and ((i + 1) % max(1, num_iters // 10) == 0 or i == 0): + print(f"iter {i + 1:4d}/{num_iters}: loss={loss_history[-1]:.6g}") + + with torch.no_grad(): + pred_final, conc_final, bg_final = _forward_model() + mean_input_spectrum = spectra_flat[valid_pixel_mask].mean(dim=0).cpu().numpy() + mean_fitted_spectrum = pred_final[valid_pixel_mask].mean(dim=0).cpu().numpy() + mean_background_spectrum = bg_final[valid_pixel_mask].mean(dim=0).cpu().numpy() + mean_input_spectrum_all = spectra_flat.mean(dim=0).cpu().numpy() + mean_fitted_spectrum_all = pred_final.mean(dim=0).cpu().numpy() + mean_background_spectrum_all = bg_final.mean(dim=0).cpu().numpy() + + abundance_maps = conc_final.view(n_y, n_x, n_elements).permute(2, 0, 1).cpu().numpy() + peak_widths = nn.functional.softplus(peak_width_params).detach().cpu().numpy() + + pytorch_spectrum_images = self._build_pytorch_spectrum_images( + abundance_maps=abundance_maps, + element_names=list(global_model.peak_model.element_names), + ) + if hasattr(self, "_spectrum_images_pytorch"): + self._spectrum_images_pytorch = { + **self._spectrum_images_pytorch, + **pytorch_spectrum_images, + } + else: + self._spectrum_images_pytorch = {} + self._spectrum_images_pytorch = { + **self._spectrum_images_pytorch, + **pytorch_spectrum_images, + } + + loss_history_array = np.asarray(loss_history) + energy_axis_np = energy_axis.cpu().numpy() + + if show_plot: + fig, ax = plt.subplots(1, 1, figsize=(8, 4)) + global_x = np.arange(global_loss_history.shape[0]) + local_x = np.arange(loss_history_array.shape[0]) + global_loss_history.shape[0] + ax.plot( + global_x, + global_loss_history, + "b-", + label="global", + ) + ax.plot( + local_x, + loss_history_array, + "r-", + label="local", + ) + ax.axvline( + x=global_loss_history.shape[0] - 0.5, + color="gray", + linestyle="--", + linewidth=1.0, + label="switch", + ) + ax.set_title("loss") + ax.set_xlabel("iterations") + ax.set_ylabel("loss") + ax.set_yscale("log") + ax.legend() + plt.tight_layout() + plt.show() + + fig, ax = plt.subplots(1, 1, figsize=(10, 4)) + ax.plot(energy_axis_np, mean_input_spectrum, "k-", label="Data", linewidth=1) + ax.plot( + energy_axis_np, + global_fitted_spectrum, + color="cyan", + label="Global fit", + linewidth=2.5, + ) + ax.plot(energy_axis_np, mean_fitted_spectrum, "r-", label="Fit", linewidth=2.5) + ax.plot( + energy_axis_np, + mean_background_spectrum, + "b--", + label="Background", + linewidth=2.5, + ) + ax.set_xlim(energy_range[0], energy_range[1]) + ax.legend() + ax.set_title("fit spectrum after local fitting (valid-pixel averaged)") + ax.set_xlabel("Energy (keV)") + ax.set_ylabel("Counts") + plt.tight_layout() + plt.show() + + self.show_spectrum_images(method="fit") + + return { + "abundance_maps": abundance_maps, + "element_names": global_model.peak_model.element_names, + "peak_widths": peak_widths, + "loss_history": loss_history_array, + "global_loss_history": np.asarray(global_loss_history), + "valid_pixel_mask": valid_pixel_mask.view(n_y, n_x).cpu().numpy(), + "energy_axis": energy_axis_np, + "input_spectrum": mean_input_spectrum, + "fitted_spectrum": mean_fitted_spectrum, + "background_spectrum": mean_background_spectrum, + "input_spectrum_all_pixels": mean_input_spectrum_all, + "fitted_spectrum_all_pixels": mean_fitted_spectrum_all, + "background_spectrum_all_pixels": mean_background_spectrum_all, + "fit_range": energy_range, + "spectrum_images_pytorch": self._spectrum_images_pytorch, + } + + def calculate_background_polynomial( + self, + spectrum, + energy_axis=None, + degree=3, + percentile=10, + window_size=50, + ): + """ + Fit an EDS continuum background with a polynomial power series in energy. + + A rolling low-percentile envelope is used as the fit target so sharp + characteristic X-ray peaks do not dominate the continuum fit. + """ + + spectrum = np.asarray(spectrum, dtype=float) + if spectrum.ndim != 1: + raise ValueError("spectrum must be a 1D array") + if spectrum.size == 0: + raise ValueError("spectrum must contain at least one channel") + + if energy_axis is None: + energy_axis = np.asarray(self.energy_axis, dtype=float) + if energy_axis.shape != spectrum.shape: + energy_axis = float(self.origin[0]) + float(self.sampling[0]) * np.arange( + spectrum.size, dtype=float + ) + else: + energy_axis = np.asarray(energy_axis, dtype=float) + if energy_axis.shape != spectrum.shape: + raise ValueError("energy_axis must have the same shape as spectrum") + + if isinstance(degree, bool): + raise TypeError("degree must be a non-negative integer") + try: + degree = int(degree) + except (TypeError, ValueError) as exc: + raise TypeError("degree must be a non-negative integer") from exc + if degree < 0: + raise ValueError("degree must be >= 0") + + try: + percentile = float(percentile) + except (TypeError, ValueError) as exc: + raise TypeError("percentile must be a number between 0 and 100") from exc + if percentile < 0 or percentile > 100: + raise ValueError("percentile must be between 0 and 100") + + if isinstance(window_size, bool): + raise TypeError("window_size must be a positive integer") + try: + window_size = int(window_size) + except (TypeError, ValueError) as exc: + raise TypeError("window_size must be a positive integer") from exc + if window_size < 1: + raise ValueError("window_size must be >= 1") + window_size = min(window_size, spectrum.size) + + finite = np.isfinite(spectrum) & np.isfinite(energy_axis) + if np.count_nonzero(finite) < degree + 1: + raise ValueError("not enough finite spectrum points for the requested degree") + + half_window = window_size // 2 + envelope = np.full_like(spectrum, np.nan, dtype=float) + for channel in range(spectrum.size): + start = max(0, channel - half_window) + end = min(spectrum.size, channel + half_window + 1) + values = spectrum[start:end] + values = values[np.isfinite(values)] + if values.size: + envelope[channel] = np.percentile(values, percentile) + + fit_mask = finite & np.isfinite(envelope) + if np.count_nonzero(fit_mask) < degree + 1: + raise ValueError("not enough background fit points for the requested degree") + + fit_energy = energy_axis[fit_mask] + fit_counts = envelope[fit_mask] + energy_min = float(np.min(fit_energy)) + energy_span = float(np.max(fit_energy) - energy_min) + if energy_span <= 0: + if degree != 0: + raise ValueError("energy_axis must span more than one value for degree > 0") + return np.full_like(spectrum, max(float(np.median(fit_counts)), 0.0), dtype=float) + + # Scaling improves conditioning; this remains a polynomial in energy. + def scaled_energy(energy): + return 2.0 * (np.asarray(energy, dtype=float) - energy_min) / energy_span - 1.0 + + def polynomial_background(energy, *coefficients): + energy_scaled = scaled_energy(energy) + background = np.zeros_like(energy_scaled, dtype=float) + for power, coefficient in enumerate(coefficients): + background += coefficient * (energy_scaled**power) + return background + + scaled_fit_energy = scaled_energy(fit_energy) + initial_coefficients = np.polynomial.polynomial.polyfit( + scaled_fit_energy, + fit_counts, + deg=degree, + ) + try: + coefficients, _ = curve_fit( + polynomial_background, + fit_energy, + fit_counts, + p0=initial_coefficients, + maxfev=10000, + ) + except (RuntimeError, ValueError, FloatingPointError): + coefficients = initial_coefficients + + background = polynomial_background(energy_axis, *coefficients) + finite_counts = spectrum[finite] + max_count = max(float(np.max(finite_counts)), float(np.max(fit_counts)), 0.0) + background = np.nan_to_num(background, nan=0.0, posinf=max_count, neginf=0.0) + return np.maximum(background, 0.0) + + def calculate_background_powerlaw(self, spectrum, *args, **kwargs): + """Compatibility wrapper for the EDS polynomial background fit.""" + return self.calculate_background_polynomial(spectrum, *args, **kwargs) diff --git a/src/quantem/spectroscopy/dataset3deels.py b/src/quantem/spectroscopy/dataset3deels.py new file mode 100644 index 00000000..f0f61aba --- /dev/null +++ b/src/quantem/spectroscopy/dataset3deels.py @@ -0,0 +1,990 @@ +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +from numpy.typing import NDArray +from scipy.interpolate import interp1d +from scipy.ndimage import median_filter +from scipy.optimize import curve_fit +from scipy.stats import norm + +from quantem.core.visualization import show_2d +from quantem.spectroscopy.dataset3dspectroscopy import Dataset3dspectroscopy + + +class Dataset3deels(Dataset3dspectroscopy): + """An EELS dataset class that inherits from Dataset3dspectroscopy. + + This class represents a scanning transmission electron microscopy (STEM) dataset, + where the data consists of a 3D array with dimensions (energy, scan_y, scan_x). + The first dimension represents the energy, while the latter + two dimensions represent real space sampling. + + """ + + element_info = None + element_info_path = "eels_edges.csv" + dataset_type = "EELS" + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + _token: object | None = None, + ): + """Initialize a 3D EELS dataset. + + Parameters + ---------- + array : NDArray | Any + The underlying 3D array data + name : str + A descriptive name for the dataset + origin : NDArray | tuple | list | float | int + The origin coordinates for each dimension + sampling : NDArray | tuple | list | float | int + The sampling rate/spacing for each dimension + units : list[str] | tuple | list + Units for each dimension + signal_units : str, optional + Units for the array values, by default "arb. units" + _token : object | None, optional + Token to prevent direct instantiation, by default None + """ + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + _token=_token, + ) + self._virtual_images = {} + self.dataset_type = "eels" + + def calculate_background_iterative(self, spectrum): + """ + Subtract background typical for EELS using iterative Gaussian fitting. + This method isolates the continuum background from the low-loss region. + + WARNING: Only use with EELS data! Will remove peaks if used with EDS. + + Parameters + ---------- + spectrum : ndarray + 1D EELS spectrum + energy_axis : ndarray + Energy axis corresponding to spectrum + + Returns + ------- + ndarray + Background-subtracted spectrum + """ + + from scipy.ndimage import gaussian_filter + from scipy.stats import norm + + # Smooth for better fitting + spec_smooth = gaussian_filter(spectrum, sigma=1.0) + pixel_vals = spec_smooth.copy() + + # Iteratively fit Gaussian to low-intensity values (the continuum) + # Remove outliers (edge peaks) iteratively + num_iterations = 10 + cutoff = 3 # +/- 3 sigma + + for _ in range(num_iterations): + mu, std = norm.fit(pixel_vals) + if std == 0: + break + # Keep only values within +/- 3 sigma (removes edge contributions) + lower = mu - cutoff * std + upper = mu + cutoff * std + pixel_vals = pixel_vals[(pixel_vals >= lower) & (pixel_vals <= upper)] + + # Subtract the estimated background level + background_fit = mu + + return background_fit + + def powerlaw_backgroundfit_eels(self, spectrum, energy_range, target_edge, window_size): + """ + Using a window of the energy axis preceding the target edge, fit a power law function to use for background subtraction. + The input window size should be 10-30% of the target edge energy. + """ + + energy_axis = self.energy_axis + + if energy_range is not None: + energy_range[0] = np.maximum(energy_range[0], energy_axis[0]) + energy_range[1] = np.minimum(energy_range[1], energy_axis[-1]) + + indices = np.where( + (energy_axis >= energy_range[0]) & (energy_axis <= energy_range[1]) + )[0] + energy_axis = energy_axis[indices] + else: + indices = np.arange(self.shape[0]) + + # Check that input window size is between 10% and 30% + + if window_size < 10 or window_size > 30: + raise ValueError("Invalid window size. Please input a value of between 10 and 30.") + + # Check that the target edge is within the energy range of the spectrum + # and that a pre-edge region of size at least 10% of the target edge, ending 5 eV before the target edge + # exists for pre-edge fitting. + + if target_edge < energy_axis[0] or target_edge > energy_axis[-1]: + raise ValueError("Target edge is outside of energy range.") + elif ((target_edge - 5) - target_edge * (window_size / 100)) < energy_axis[0]: + raise ValueError( + "Insufficient pre-edge background fitting region for this target edge and window size within given energy range." + ) + + # Fit power law function to spectrum within window region of the energy exis + + window_minE = (target_edge - 5) - target_edge * (window_size / 100) + window_maxE = target_edge - 5 + + window_indices = np.where((energy_axis >= window_minE) & (energy_axis <= window_maxE))[0] + + window_E = energy_axis[window_indices] + window_I = spectrum[window_indices] + + def powerlaw_function(E, A, r): + return A * (E ** (-r)) + + popt, _ = curve_fit(powerlaw_function, window_E, window_I, maxfev=2000) + background_fit = powerlaw_function(energy_axis, popt[0], popt[1]) + + # Plot the region of the spectrum between user-specified energy range, overlaid with the background fit curve, with background estimation + # window boundaries indicated + + fig, ax = plt.subplots() + ax.plot(energy_axis, spectrum, label="spectrum", color="b") + ax.plot(energy_axis, background_fit, label="background", color="r") + ax.vlines( + x=[window_minE, window_maxE], + ymin=0, + ymax=np.max(spectrum), + label="window limits", + color="k", + linestyle="dashed", + ) + ax.legend() + + return background_fit + + def smooth_eels_rollingaverage(self, roi=None, energy_range=None, mask=None, kernel_size=10): + energy_axis = self.energy_axis + + if energy_range is not None: + energy_range[0] = np.maximum(energy_range[0], energy_axis[0]) + energy_range[1] = np.minimum(energy_range[1], energy_axis[-1]) + + indices = np.where( + (energy_axis >= energy_range[0]) & (energy_axis <= energy_range[1]) + )[0] + energy_axis = energy_axis[indices] + else: + indices = np.arange(self.shape[0]) + + array3d_subrange = self.array[indices, :, :] + + kernel = np.ones(kernel_size) / kernel_size + + # For each probe position, convolve spectral data with smoothing kernel + + array3d_smoothed = np.zeros(array3d_subrange.shape) + + for kk in range(array3d_subrange.shape[1]): + for ll in range(array3d_subrange.shape[2]): + probe_spectrum = self.array[:, kk, ll] + spectrum_smoothed = np.convolve(probe_spectrum, kernel, mode="same") + array3d_smoothed[:, kk, ll] = spectrum_smoothed + + smoothed_data3d = Dataset3deels.from_array( + array=array3d_smoothed, + sampling=self.sampling, + origin=energy_axis[0], + units=self.units, + ) + + # Plot raw and smoothed mean spectra on the same set of axes + + mean_spectrum_raw = self.calculate_mean_spectrum( + roi=roi, + energy_range=energy_range, + mask=mask, + ) + mean_spectrum_smoothed = smoothed_data3d.calculate_mean_spectrum( + roi=roi, + energy_range=energy_range, + mask=mask, + ) + + fig, ax = plt.subplots() + ax.plot(energy_axis, mean_spectrum_raw, label="raw spectrum", color="b") + ax.plot(energy_axis, mean_spectrum_smoothed, label="kernel-smoothed spectrum", color="r") + ax.legend() + + return smoothed_data3d + + def measure_zlp_offset( + self, + zlp_guess_x=None, + fit_window=0.8, + fit_to_plane=False, + median_filter_pixels=3, + fit_zlp=True, + ): + """ + Measure ZLP offset at each pixel position by using a guess of ZLP posfitting each spectrum to a Gaussian + """ + + # Define Gaussian constraint to fit ZLP to + def _gaussian_fit(x, A, mu, sigma): + return A * np.exp(-0.5 * ((x - mu) / sigma) ** 2) + + def _plane_fit_2d(M, a, b, c): + x, y = M + return (a * x) + (b * y) + c + + _n_energy, n_y, n_x = self.array.shape + energy_axis = self.energy_axis + + # For each pixel, measure the zlp position by fitting a Gaussian to the measured zero-loss signal and taking its center as the zlp position. + + zlp_measured = np.zeros((n_y, n_x)) + + for iy in range(n_y): + for ix in range(n_x): + # Apply median filter to discount hot pixels that might spuriously produce the maximum intensity of the spectrum + if median_filter_pixels > 0: + spec_filt = median_filter(self.array[:, iy, ix], median_filter_pixels) + else: + spec_filt = self.array[:, iy, ix] + + if fit_zlp: + # Use initial guess for ZLP to define window for Gaussian fitting. If zlp_guess_x=None (default) use the maximum value of the spectrum + if zlp_guess_x is not None: + zlp_crude_idx = int(np.argmin(np.abs(energy_axis - zlp_guess_x))) + else: + zlp_crude_idx = int(np.argmax(spec_filt)) + + mu0 = float(energy_axis[zlp_crude_idx]) + + lo = mu0 - fit_window + hi = mu0 + fit_window + + x_mask = (energy_axis >= lo) & (energy_axis <= hi) + + xw = energy_axis[x_mask] + yw = spec_filt[x_mask] + + A0 = float(spec_filt[zlp_crude_idx]) + sigma0 = fit_window / 2 + + p0 = (A0, mu0, sigma0) + + bounds = ( + ( + 0.0, + lo, + 1e-12, + ), + ( + np.inf, + hi, + np.inf, + ), + ) + + popt, _ = curve_fit(_gaussian_fit, xw, yw, p0=p0, bounds=bounds) + + zlp_measured[iy, ix] = float(popt[1]) + else: + zlp_crude_idx = int(np.argmax(spec_filt)) + zlp_measured[iy, ix] = float(energy_axis[zlp_crude_idx]) + + if fit_to_plane: + # Fit a 2D plane to the array of measured ZLPs + xdata, ydata = np.meshgrid(np.arange(n_x), np.arange(n_y)) + + xdata_unpacked = np.vstack((xdata.ravel(), ydata.ravel())) + ydata_unpacked = zlp_measured.ravel() + + popt, _ = curve_fit(_plane_fit_2d, xdata_unpacked, ydata_unpacked) + + zlp_plane_1d = _plane_fit_2d(xdata_unpacked, popt[0], popt[1], popt[2]) + zlp_plane_2d = zlp_plane_1d.reshape(n_y, n_x) + + show_2d( + [zlp_measured, zlp_plane_2d], + cmap="magma", + title=["Measured ZLP (mean of Gaussian fit)", "ZLP plane fit"], + ) + return zlp_plane_2d + else: + show_2d( + [zlp_measured], + cmap="magma", + title=["Measured ZLP (mean of Gaussian fit)"], + ) + return zlp_measured + + def apply_zlp_correction( + self, + zlp_guess_x=None, + zlp_shifts_array=None, + fit_window=0.8, + measure_offset=True, + fit_to_plane=True, + fit_zlp=True, + return_3d_dataset=True, + ): + # Default behavior is to automatically call measure_zlp_offset to generate an array of ZLP shifts for each scan position. + # Alternatively, a 2D array matching the x and y dimensions of the 3D dataset can be supplied as the value of zlp_shifts_array to skip this step. + # If measure_offset is False and no 2D ZLP shifts array is provided, a scalar input for zlp_guess_x can be used to shift the energy axis at every scan position by that amount. + if measure_offset: + zlp_array = self.measure_zlp_offset( + zlp_guess_x=zlp_guess_x, + fit_window=fit_window, + fit_to_plane=fit_to_plane, + fit_zlp=fit_zlp, + ) + elif zlp_shifts_array is not None: + zlp_array = np.asarray(zlp_shifts_array, dtype=float) + if zlp_array.shape != self.array.shape[1:3]: + raise ValueError( + "Dimensions of input array for ZLP shifts do not match X and Y dimensions of 3D spectroscopy dataset." + ) + elif zlp_guess_x is not None: + zlp_array = np.ones(self.array.shape[1:3], dtype=float) * zlp_guess_x + else: + raise ValueError( + "measure_offset was set to False and no input argument for ZLP shifts was provided." + ) + + zlp_array = np.asarray(zlp_array, dtype=float) + if not np.all(np.isfinite(zlp_array)): + raise ValueError("ZLP shifts must contain only finite values.") + + # Initialize 3D array to populate with spectra aligned along the energy axis + corrected_array = np.empty(self.array.shape, dtype=np.result_type(self.array.dtype, float)) + + n_energy, n_y, n_x = self.array.shape + + energy_axis = self.energy_axis + if np.all((zlp_array >= 0) & (zlp_array <= n_energy - 1)) and ( + np.min(zlp_array) < energy_axis[0] or np.max(zlp_array) > energy_axis[-1] + ): + zlp_array = np.interp(zlp_array, np.arange(n_energy), energy_axis) + + # Apply sub-channel ZLP shifts using 1D linear interpolation along the energy axis. + for iy in range(n_y): + for ix in range(n_x): + spec = self.array[:, iy, ix] + corrected_array[:, iy, ix] = np.interp( + energy_axis + zlp_array[iy, ix], + energy_axis, + spec, + left=np.nan, + right=np.nan, + ) + + # Remove all planes along energy axis containing NaN, to equalize spectra lengths across all scan positions + mask = np.isnan(corrected_array).any(axis=(1, 2)) + aligned_data_3d = corrected_array[~mask] + new_Eaxis = energy_axis[~mask] + + if aligned_data_3d.shape[0] == 0: + raise ValueError( + "ZLP shifts leave no shared energy range after alignment. " + "Check that zlp_shifts_array is in energy units, not channel indices." + ) + + new_origin = float(new_Eaxis[0]) + + # Calculate mean spectra before and after correction for plotting + mean_spectrum_raw = self.array.mean(axis=(1, 2)) + mean_spectrum_corrected = aligned_data_3d.mean(axis=(1, 2)) + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) + ax1.plot(energy_axis, mean_spectrum_raw, label="Raw mean spectrum", color="r") + ax2.plot(new_Eaxis, mean_spectrum_corrected, label="ZLP-corrected spectrum", color="b") + ax1.set_xlabel("Energy (eV)") + ax1.set_ylabel("Intensity") + ax1.grid(True, alpha=0.1) + ax1.legend() + ax2.set_xlabel("Energy (eV)") + ax2.set_ylabel("Intensity") + ax2.grid(True, alpha=0.1) + ax2.legend() + + fig.tight_layout() + + if return_3d_dataset: + return Dataset3deels.from_array( + array=aligned_data_3d, + name=self.name, + sampling=self.sampling, + origin=new_origin, + units=self.units, + ) + return aligned_data_3d + + def calibrate_zero_loss_peak(self, center_guess=None, search_window=10): + """ + Calibrate the energy axis by centering the zero loss peak at 0 eV. + Finds the ZLP at every pixel, fits a 2D plane to the ZLP positions, + and shifts each spectrum individually so the ZLP sits at 0, while aligning + all ZLPs to the same channel index, allowing a single origin to correctly + calibrate the entire dataset. + + Parameters + ---------- + center_guess : float or None + Expected energy position of the ZLP in eV. If None, uses the + tallest peak in each spectrum as the ZLP. If provided, searches + for the tallest peak within the search window around that energy. + search_window : int + Number of channels to search on either side of center_guess. + Only used when center_guess is not None. Default is 10. + + Returns + ------- + Dataset3deels + New dataset with corrected energy calibration. + """ + + n_energy, n_y, n_x = self.array.shape + + dE = float(self.sampling[0]) + E0 = float(self.origin[0]) + energy_axis = self.energy_axis + + # --- Build ZLP position map --- + # For every pixel, find the energy where the ZLP sits. + # A median filter is applied to each spectrum first to remove + # hot pixels (cosmic rays, detector glitches) that could be + # brighter than the ZLP and fool the peak finder. + # If center_guess is provided, only look within a window + # of search_window channels around that energy. + # If center_guess is None, just find the tallest peak. + + zlp_map = np.zeros((n_y, n_x)) + + if center_guess is not None: + guess_index = int(round((center_guess - E0) / dE)) + lo = max(guess_index - search_window, 0) + hi = min(guess_index + search_window + 1, n_energy) + + for iy in range(n_y): + for ix in range(n_x): + spectrum = median_filter(self.array[:, iy, ix], size=3) + + if center_guess is None: + peak_index = np.argmax(spectrum) + else: + peak_index = lo + np.argmax(spectrum[lo:hi]) + + zlp_map[iy, ix] = E0 + peak_index * dE + + # --- Fit a 2D plane to the ZLP map --- + # The plane equation is: zlp_energy(y, x) = a*y + b*x + c + # This smooths out noisy per-pixel ZLP measurements by assuming + # the drift varies linearly across the scan area. + + y_coords, x_coords = np.meshgrid(np.arange(n_y), np.arange(n_x), indexing="ij") + y_flat = y_coords.ravel() + x_flat = x_coords.ravel() + z_flat = zlp_map.ravel() + + A = np.column_stack([y_flat, x_flat, np.ones(len(y_flat))]) + coeffs, _, _, _ = np.linalg.lstsq(A, z_flat, rcond=None) + a, b, c = coeffs + + zlp_plane = a * y_coords + b * x_coords + c + + # --- Shift each spectrum so the ZLP lands at 0 eV --- + # For each pixel, subtract its plane-predicted ZLP position from + # the energy axis, then interpolate the spectrum back onto the + # original energy grid. This physically moves the data so all + # ZLPs align at the same channel index. + + corrected_array = np.zeros_like(self.array) + + for iy in range(n_y): + for ix in range(n_x): + shift = zlp_plane[iy, ix] + shifted_energy = energy_axis - shift + interpolator = interp1d( + shifted_energy, + self.array[:, iy, ix], + kind="linear", + bounds_error=False, + fill_value=0.0, + ) + corrected_array[:, iy, ix] = interpolator(energy_axis) + + return Dataset3deels.from_array( + array=corrected_array, + name=self.name, + sampling=self.sampling, + origin=self.origin, + units=self.units, + ) + + def correct_zlp_shift(ll, hl): + """ + Aligns ZLP jitter across the spatial map and synchronizes Dual-EELS pairs. + """ + print(f"QuantEM: Aligning {ll.name} and syncing {hl.name}...") + + # 1. Map the drift via argmax + zlp_indices = np.argmax(ll.array, axis=0) + ref_idx = int(np.median(zlp_indices)) + shifts = zlp_indices - ref_idx + + # 2. Apply internal QuantEM calibration + ll.calibrate_zero_loss_peak() + + # 3. Synchronize High-Loss energy origin based on median shift + shift_ev = np.median(shifts) * ll.sampling[0] + hl.origin[0] -= shift_ev + + print("QuantEM: Alignment and Dual-EELS sync complete.") + return ll, hl, shifts + + def plot_absolute_zlp_shift(dataset, search_window=(-10, 10)): + """ + Calculates the ZLP shift per pixel and plots the absolute deviation from 0.0 eV. + """ + data = dataset.array + n_e = data.shape[0] + + # Generate energy axis + energies = dataset.origin[0] + np.arange(n_e) * dataset.sampling[0] + + # Mask energy window for peak finding + mask = (energies > search_window[0]) & (energies < search_window[1]) + search_energies = energies[mask] + + # Calculate peak map and absolute deviation + peak_indices = np.argmax(data[mask, :, :], axis=0) + zlp_map_ev = search_energies[peak_indices] + absolute_shift = np.abs(zlp_map_ev) + + # Visualization + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(absolute_shift, cmap="magma", origin="lower") + + plt.colorbar(im, ax=ax, label="Absolute Shift (eV)") + ax.set_title(f"Absolute ZLP Deviation: {dataset.name}") + ax.set_xlabel("X (pixels)") + ax.set_ylabel("Y (pixels)") + + plt.tight_layout() + plt.show() + + return absolute_shift + + def visualize_thickness_windows(dataset, zlp_window=(-3.0, 3.0), total_window=(-3.0, 75.0)): + """ + Visualizes integration windows for I0 (ZLP) and It (Total). + Returns a configuration dictionary for the calculation step. + """ + # 1. Extract Energy and Mean Spectrum + data = dataset.array + mean_spec = np.mean(data, axis=(1, 2)) + + # Use built-in energy axis if available, else generate from metadata + if hasattr(dataset, "energy_axis"): + energy = dataset.energy_axis + else: + energy = dataset.origin[0] + np.arange(dataset.shape[0]) * dataset.sampling[0] + + # 2. Find indices for the windows + zlp_idx = ( + np.argmin(np.abs(energy - zlp_window[0])), + np.argmin(np.abs(energy - zlp_window[1])), + ) + tot_idx = ( + np.argmin(np.abs(energy - total_window[0])), + np.argmin(np.abs(energy - total_window[1])), + ) + + # 3. Create the Visualization + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(energy, mean_spec, "k-", lw=1.5, label="Mean Spectrum", zorder=5) + + # Highlight Windows + z_mask = (energy >= zlp_window[0]) & (energy <= zlp_window[1]) + t_mask = (energy >= total_window[0]) & (energy <= total_window[1]) + + ax.fill_between( + energy[z_mask], 0, mean_spec[z_mask], color="red", alpha=0.3, label="$I_0$ (ZLP)" + ) + ax.fill_between( + energy[t_mask], 0, mean_spec[t_mask], color="blue", alpha=0.1, label="$I_t$ (Total)" + ) + + ax.axvline(0, color="green", lw=1.5, ls=":", label="0 eV") + ax.set_title(f"QuantEM: Integration Windows ({dataset.name})", fontweight="bold") + ax.set_xlabel("Energy Loss (eV)") + ax.set_ylabel("Intensity (counts)") + ax.set_xlim(energy[0], total_window[1] + 20) + ax.legend() + + plt.tight_layout() + plt.show() + + return { + "zlp_idx": zlp_idx, + "total_idx": tot_idx, + "zlp_val": zlp_window, + "total_val": total_window, + } + + def calculate_thickness_log_ratio(dataset, window_params, plot=True): + """ + Calculates the relative thickness map (t/lambda) using the Log-Ratio method. + """ + data = dataset.array + z_start, z_end = window_params["zlp_idx"] + t_start, t_end = window_params["total_idx"] + + print(f"QuantEM: Calculating thickness for {dataset.name}...") + + # 1. Vectorized Integration + I_zlp = np.sum(data[z_start : z_end + 1, :, :], axis=0) + I_total = np.sum(data[t_start : t_end + 1, :, :], axis=0) + + # 2. Log-Ratio Calculation (with epsilon to avoid log(0)) + epsilon = 1e-10 + t_over_lambda = np.log((I_total + epsilon) / (I_zlp + epsilon)) + + # 3. Data Cleaning + t_over_lambda = np.nan_to_num(t_over_lambda, nan=0.0, posinf=0.0, neginf=0.0) + t_over_lambda = np.clip(t_over_lambda, 0, 4.0) + + if plot: + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + fig.suptitle(f"Thickness Analysis: {dataset.name}", fontsize=14) + + im = ax1.imshow(t_over_lambda, cmap="viridis", origin="lower") + ax1.set_title(r"Relative Thickness Map ($t/\lambda$)") + plt.colorbar(im, ax=ax1, label=r"$t/\lambda$") + + ax2.hist(t_over_lambda.flatten(), bins=50, color="steelblue", alpha=0.7, ec="k") + ax2.axvline( + np.mean(t_over_lambda), + color="red", + ls="--", + label=f"Mean: {np.mean(t_over_lambda):.2f}", + ) + ax2.set_title("Thickness Distribution") + ax2.set_xlabel(r"$t/\lambda$") + ax2.legend() + + plt.tight_layout() + plt.show() + + return t_over_lambda + + def interpret_thickness_quality(t_over_lambda, a=0.3, b=1, c=2, dataset=None): + """ + Performs a scientific quality assessment on the calculated t/lambda map. + + The Physical Meaning of the ThresholdsThe t/lambda value represents the average number of inelastic scattering events + an electron undergoes. + Vacuum (< a): + (default a = 0.3) + In pure vacuum, t/lambda should be 0. In practice, values up to ~0.3 often indicate the presence of thin carbon support films, + surface contamination, or detector noise. Measurements in this regime are highly sensitive to ZLP (Zero Loss Peak) estimation errors. + + Thin (a c): + The "Multiple Scattering Regime. + " Most electrons have undergone three or more scattering events, resulting in a "spectral soup" + where fine-structure details and high-resolution chemical information are significantly broadened or lost. + """ + + name = dataset.name if dataset else "Dataset" + + # Classification Masks + vacuum = t_over_lambda < a + thin = (t_over_lambda >= a) & (t_over_lambda < b) + medium = (t_over_lambda >= b) & (t_over_lambda < c) + thick = t_over_lambda >= c + + print(f"\n{'=' * 20} QUANTEM INTERPRETATION: {name} {'=' * 20}") + for label, mask in [ + ("Vacuum (<0.3)", vacuum), + ("Thin (0.3-1.0)", thin), + ("Medium (1.0-2.0)", medium), + ("Thick (>2.0)", thick), + ]: + pct = 100 * np.sum(mask) / t_over_lambda.size + print(f" {label:20}: {pct:5.1f}%") + + # Plotting Classification + classified = np.zeros_like(t_over_lambda) + classified[thin] = 1 + classified[medium] = 2 + classified[thick] = 3 + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + + im1 = ax1.imshow(classified, cmap="RdYlGn_r", origin="lower") + ax1.set_title("Region Classification") + cbar = plt.colorbar(im1, ax=ax1, ticks=[0, 1, 2, 3]) + cbar.ax.set_yticklabels(["Vacuum", "Thin", "Medium", "Thick"]) + + t_masked = np.copy(t_over_lambda) + t_masked[vacuum] = np.nan + im2 = ax2.imshow(t_masked, cmap="viridis", origin="lower") + ax2.set_title("Sample-Only Thickness") + plt.colorbar(im2, ax=ax2, label=r"$t/\lambda$") + + plt.tight_layout() + plt.show() + + def plot_absolute_thickness(t_lambda_map, mfp_nm, dataset=None): + """ + Converts relative thickness to nanometers and visualizes the absolute map. + """ + thickness_nm = t_lambda_map * mfp_nm + name = dataset.name if dataset else "Sample" + + # Mask vacuum for better visualization contrast + display_map = np.copy(thickness_nm) + display_map[t_lambda_map < 0.1] = np.nan + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5)) + fig.suptitle(f"Physical Analysis: {name}", fontsize=14) + + im = ax1.imshow(display_map, cmap="magma", origin="lower") + ax1.set_title("Absolute Thickness (nm)") + plt.colorbar(im, ax=ax1, label="nm") + + valid_data = thickness_nm[t_lambda_map >= 0.1].flatten() + ax2.hist(valid_data, bins=50, color="firebrick", alpha=0.7, ec="k") + ax2.axvline( + np.nanmean(display_map), + color="blue", + ls="--", + label=f"Mean: {np.nanmean(display_map):.1f} nm", + ) + ax2.set_title("Physical Distribution") + ax2.set_xlabel("Thickness (nm)") + ax2.legend() + + plt.tight_layout() + plt.show() + + print( + f"\nQuantEM Absolute Report:\n Mean: {np.nanmean(display_map):.2f} nm\n MFP: {mfp_nm:.2f} nm" + ) + return thickness_nm + + def plot_dual_eels_picker(ll, hl, coords=None, title="QuantEM: Dual-EELS Analysis"): + """ + Dual-EELS Picker with starting coordinates. + """ + # 1. Setup Data + sum_ll = np.sum(ll.array, axis=0) + sum_hl = np.sum(hl.array, axis=0) + energy_ll = ll.origin[0] + np.arange(ll.shape[0]) * ll.sampling[0] + energy_hl = hl.origin[0] + np.arange(hl.shape[0]) * hl.sampling[0] + + # 2. Handle Initial Coordinates + if coords is not None: + cx, cy = coords + else: + cx, cy = ll.shape[2] // 2, ll.shape[1] // 2 + + # 3. Create Figure + fig, axes = plt.subplots(2, 2, figsize=(14, 9)) + fig.suptitle(f"{title}\n(Click on maps to update spectra)", fontsize=16) + ax_map_ll, ax_spec_ll = axes[0, 0], axes[0, 1] + ax_map_hl, ax_spec_hl = axes[1, 0], axes[1, 1] + + # Plot Maps & Markers + ax_map_ll.imshow(sum_ll, cmap="viridis", origin="lower") + (marker_ll,) = ax_map_ll.plot(cx, cy, "r+", ms=15, mew=2) + + ax_map_hl.imshow(sum_hl, cmap="magma", origin="lower") + (marker_hl,) = ax_map_hl.plot(cx, cy, "r+", ms=15, mew=2) + + # Plot Initial Spectra + (line_ll,) = ax_spec_ll.plot(energy_ll, ll.array[:, cy, cx], color="tab:blue") + (line_hl,) = ax_spec_hl.plot(energy_hl, hl.array[:, cy, cx], color="tab:red") + + def update_plots(x, y): + marker_ll.set_data([x], [y]) + marker_hl.set_data([x], [y]) + + new_ll = ll.array[:, y, x] + new_hl = hl.array[:, y, x] + line_ll.set_ydata(new_ll) + line_hl.set_ydata(new_hl) + + # Rescale + ax_spec_ll.set_ylim(0, np.max(new_ll) * 1.1) + ax_spec_hl.set_ylim(0, np.max(new_hl) * 1.1) + + ax_spec_ll.set_title(f"LL Spectrum at ({x}, {y})") + ax_spec_hl.set_title(f"HL Spectrum at ({x}, {y})") + fig.canvas.draw_idle() + + def on_click(event): + if event.inaxes in [ax_map_ll, ax_map_hl]: + ix, iy = int(round(event.xdata)), int(round(event.ydata)) + if 0 <= ix < ll.shape[2] and 0 <= iy < ll.shape[1]: + update_plots(ix, iy) + + fig.canvas.mpl_connect("button_press_event", on_click) + + ax_spec_ll.set_title(f"LL Spectrum at ({cx}, {cy})") + ax_spec_hl.set_title(f"HL Spectrum at ({cx}, {cy})") + + plt.tight_layout() + plt.close(fig) # Prevents double-plotting in VS Code + return fig + + def plot_quantem_diagnostic(dataset, zlp_window=5.0, title_suffix=""): + """ + QuantEM Diagnostic Dashboard: Visualizes mean spectra, spatial variation, + and Zero Loss Peak (ZLP) centering accuracy. + + 1. Global Average Spectrum (Top Left): Shows the mean intensity across the entire scan. + It is used to check the signal-to-noise ratio and see if the Zero Loss Peak (ZLP) is roughly centered at 0 eV. + 2. Spatial Variation (Top Right): Plots spectra from a 5x5 grid of pixels across your sample. + This helps you see if the energy shift or intensity changes drastically from one side of the scan to the other + (e.g., due to sample thickness changes or beam drift). + 3. Integrated Intensity Map (Bottom Left): A spatial image of the total counts. + This is your "search image" to help you correlate the spectral data with the physical structure of your sample. + 4. ZLP Alignment Detail (Bottom Right): A high-zoom view of the energy region around 0 eV of the Mean Spectrum. + It includes a dashed green line at the "Target 0" to show exactly how much residual calibration error remains + after your alignment. + + Parameters: + ----------- + dataset : QuantEM Object + The EELS dataset containing .array, .origin, and .sampling attributes. + zlp_window : float, optional + The energy range (± eV) to display in the ZLP zoom plot. Default is 5.0. + title_suffix : str, optional + Additional text to append to the figure title (e.g., "(RAW)" or "(Aligned)"). + + Returns: + -------- + fig : matplotlib.figure.Figure + The figure object for further manipulation or saving. + """ + data = dataset.array + energy = dataset.origin[0] + np.arange(data.shape[0]) * dataset.sampling[0] + + mean_spec = np.mean(data, axis=(1, 2)) + zlp_pos = energy[np.argmax(mean_spec)] + sum_img = np.sum(data, axis=0) + + fig = plt.figure(figsize=(14, 9)) + gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.2) + fig.suptitle(f"QuantEM Diagnostic: {dataset.name} {title_suffix}", fontsize=16) + + # 1. Mean Spectrum + ax1 = fig.add_subplot(gs[0, 0]) + ax1.plot(energy, mean_spec, color="black", label="Mean") + ax1.axvline(0, color="green", ls=":", label="Target") + ax1.set_title("Global Average Spectrum") + ax1.legend() + + # 2. Spatial Variability + ax2 = fig.add_subplot(gs[0, 1]) + # Take a 5x5 grid for better representation than 3x3 + yy, xx = np.meshgrid( + np.linspace(0, data.shape[1] - 1, 5, dtype=int), + np.linspace(0, data.shape[2] - 1, 5, dtype=int), + ) + for y, x in zip(yy.flatten(), xx.flatten()): + ax2.plot(energy, data[:, y, x], alpha=0.3, lw=0.5) + ax2.set_title("Spatial Variation (Grid Samples)") + + # 3. Map + ax3 = fig.add_subplot(gs[1, 0]) + im = ax3.imshow(sum_img, cmap="viridis", origin="lower") + plt.colorbar(im, ax=ax3) + ax3.set_title("Integrated Intensity") + + # 4. ZLP Zoom + ax4 = fig.add_subplot(gs[1, 1]) + mask = (energy > zlp_pos - zlp_window) & (energy < zlp_pos + zlp_window) + ax4.plot(energy[mask], mean_spec[mask], lw=2) + ax4.axvline(0, color="green", ls=":") + ax4.set_title("ZLP Alignment Detail") + plt.close(fig) + + return fig + + def plot_zlp_drift_diagnostics(dataset, title="ZLP Drift Analysis"): + """ + QuantEM Diagnostic: Maps the ZLP position and calculates the drift distribution. + Uses scipy.stats for Gaussian fitting. + """ + data = dataset.array + energy = dataset.origin[0] + np.arange(data.shape[0]) * dataset.sampling[0] + + # 1. Mask and find peak per pixel + search_mask = (energy > -2.0) & (energy < 2.0) + search_energies = energy[search_mask] + peak_indices = np.argmax(data[search_mask, :, :], axis=0) + zlp_map = search_energies[peak_indices] + + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6)) + fig.suptitle(f"QuantEM: {dataset.name} - {title}", fontsize=16) + + # Plot A: Map + im = ax1.imshow(zlp_map, cmap="RdYlBu_r", origin="lower") + plt.colorbar(im, ax=ax1, label="Energy Shift (eV)") + + # Plot B: Histogram + Scipy Fit + flat_pos = zlp_map.flatten() + mu, std = norm.fit(flat_pos) # Professional scipy fitting + + ax2.hist(flat_pos, bins=30, density=True, alpha=0.6, color="skyblue") + x_range = np.linspace(np.min(flat_pos), np.max(flat_pos), 100) + ax2.plot( + x_range, + norm.pdf(x_range, mu, std), + color="darkred", + lw=2, + label=f"Fit: μ={mu:.3f} eV, σ={std:.3f} eV", + ) + ax2.legend() + + plt.tight_layout() + + plt.close(fig) + + return fig diff --git a/src/quantem/spectroscopy/dataset3dspectroscopy.py b/src/quantem/spectroscopy/dataset3dspectroscopy.py new file mode 100644 index 00000000..8cd5ec5f --- /dev/null +++ b/src/quantem/spectroscopy/dataset3dspectroscopy.py @@ -0,0 +1,1369 @@ +import csv +import json +import os +from typing import Any, Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.patches import Rectangle +from numpy.typing import NDArray + +from quantem.core.datastructures.dataset3d import Dataset3d +from quantem.core.visualization import show_2d +from quantem.spectroscopy.utils import load_eels_edges_database, load_xray_lines_database + + +class _ModelElementsDict(dict): + """dict subclass for model_elements with a readable repr.""" + + def __repr__(self): + if not self: + return "Model Elements:\n None" + lines = ["Model Elements:"] + for element, line_info in self.items(): + if isinstance(line_info, dict) and line_info: + line_names = ", ".join(sorted(line_info.keys())) + lines.append(f" {element}: {line_names}") + else: + lines.append(f" {element}") + return "\n".join(lines) + + def _repr_html_(self): + if not self: + return "Model Elements:
  None" + rows = "".join( + f"{el}{', '.join(sorted(info.keys())) if isinstance(info, dict) and info else ''}" + for el, info in self.items() + ) + return f"Model Elements:{rows}
" + + +class Dataset3dspectroscopy(Dataset3d): + # stores the element line info so you don't need to reload each time + element_info = None + element_info_path = None + atomic_weights = None + + def __init__( + self, + array: NDArray | Any, + name: str, + origin: NDArray | tuple | list | float | int, + sampling: NDArray | tuple | list | float | int, + units: list[str] | tuple | list, + signal_units: str = "arb. units", + _token: object | None = None, + ): + super().__init__( + array=array, + name=name, + origin=origin, + sampling=sampling, + units=units, + signal_units=signal_units, + _token=type(self)._token if _token is None else _token, + ) + + self.model_elements = _ModelElementsDict() + self.attached_spectra = None + + # loads elemental information + @classmethod + def load_element_info(cls): + """Load element database for EDS X-ray lines or EELS edges.""" + if cls.element_info is not None: + return cls.element_info + + path = getattr(cls, "element_info_path", None) + if path is None: + raise NotImplementedError( + f"{cls.__name__} must define `element_info_path` to load element metadata." + ) + full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), path) + + dataset_type = str(getattr(cls, "dataset_type", "")).lower() + if path.lower().endswith(".csv"): + if dataset_type == "eels": + cls.element_info = load_eels_edges_database(full_path) + else: + cls.element_info = load_xray_lines_database(full_path) + else: + with open(full_path, "r", encoding="utf-8") as f: + cls.element_info = json.load(f)["elements"] + + if dataset_type == "eds": + cls._normalize_element_info() + + return cls.element_info + + @classmethod + def _ensure_element_info(cls): + """Load and return the cached element metadata.""" + return cls.load_element_info() or {} + + @classmethod + def load_atomic_weights(cls): + """Load atomic weights table from CSV once per class.""" + if cls.atomic_weights is not None: + return cls.atomic_weights + + atomic_weights_path = "atomic_weights.csv" + full_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), atomic_weights_path) + data = {} + with open(full_path, "r", newline="") as f: + reader = csv.reader(f) + for row_index, row in enumerate(reader, start=1): + if not row: + continue + if len(row) < 2: + raise ValueError( + f"{atomic_weights_path} row {row_index} must contain element symbol and weight" + ) + symbol = str(row[0]).strip() + weight_raw = str(row[1]).strip() + if not symbol: + continue + try: + weight = float(weight_raw) + except ValueError as exc: + raise ValueError( + f"{atomic_weights_path} row {row_index} has invalid weight: {weight_raw!r}" + ) from exc + data[symbol] = weight + + if not data: + raise ValueError(f"{atomic_weights_path} did not contain any atomic weights") + + cls.atomic_weights = data + return cls.atomic_weights + + @staticmethod + def _normalize_element_specs(specs): + if isinstance(specs, str): + return [s.strip() for s in specs.split(",") if s.strip()] + if isinstance(specs, (list, tuple, set)): + out = [] + for spec in specs: + out.extend([s.strip() for s in str(spec).split(",") if s.strip()]) + return out + raise TypeError("elements must be a string or a sequence of strings") + + @staticmethod + def _resolve_element_key(all_info, token): + token_norm = str(token).strip().lower() + return next((key for key in all_info if str(key).lower() == token_norm), None) + + @staticmethod + def _line_matches_selectors(line_name, selectors): + if not selectors: + return True + line_norm = str(line_name).strip().lower() + return any(line_norm == sel or line_norm.startswith(sel) for sel in selectors) + + @staticmethod + def _line_info_matches_selectors(line_info, selectors): + if not selectors or not isinstance(line_info, dict): + return False + edge_label = str(line_info.get("edge_label", "")).strip().lower() + return bool(edge_label) and any( + edge_label == sel or edge_label.startswith(sel) for sel in selectors + ) + + @classmethod + def _select_lines(cls, line_dict, selectors): + if not isinstance(line_dict, dict): + return {} + if not selectors: + return dict(line_dict) + + selector_norm = [str(sel).strip().lower() for sel in selectors if str(sel).strip()] + return { + line_name: line_info + for line_name, line_info in line_dict.items() + if cls._line_matches_selectors(line_name, selector_norm) + or cls._line_info_matches_selectors(line_info, selector_norm) + } + + def add_elements_to_model(self, elements): + """ + Add elements to the model for persistent use in show_mean_spectrum. + + Parameters + ---------- + elements : list or str + Element/line spec(s) to add. Examples: + - 'Al' (all lines for Al) + - 'Te La' (only Te La line) + - ['Au Ma', 'Te La', 'Si'] + """ + all_info = type(self)._ensure_element_info() + if all_info is None: + return + + specs = type(self)._normalize_element_specs(elements) + added_this_call = {} + + for spec in specs: + tokens = str(spec).split() + if not tokens: + continue + + element_key = type(self)._resolve_element_key(all_info, tokens[0]) + if element_key is None: + continue + + selectors = tokens[1:] + selected_lines = type(self)._select_lines(all_info[element_key], selectors) + if not selected_lines: + continue + + existing_before = self.model_elements.get(element_key) + if not isinstance(existing_before, dict): + existing_before = {} + existing_keys_before = set(existing_before.keys()) + if not selectors: + self.model_elements[element_key] = selected_lines + else: + existing = self.model_elements.get(element_key) + if not isinstance(existing, dict): + existing = {} + existing.update(selected_lines) + self.model_elements[element_key] = existing + + added_keys = [ + line_name + for line_name in selected_lines.keys() + if line_name not in existing_keys_before + ] + if added_keys: + if element_key not in added_this_call: + added_this_call[element_key] = [] + added_this_call[element_key].extend(added_keys) + if not self.model_elements: + self.model_elements = _ModelElementsDict() + + if added_this_call: + print("Added to model:") + for element_key in sorted(added_this_call.keys()): + unique_lines = sorted( + set(str(line_name) for line_name in added_this_call[element_key]) + ) + print(f" - {element_key}: {', '.join(unique_lines)}") + else: + print("Added to model: nothing new") + + def remove_elements_from_model(self, elements): + """ + Remove element(s) from the persistent model used in show_mean_spectrum. + + Parameters + ---------- + elements : list or str + Element/line spec(s) to remove. Examples: + - 'Al' (remove all Al lines) + - 'Te La' (remove only Te La line) + - ['Au Ma', 'Te La'] + """ + if not self.model_elements: + return + + specs = type(self)._normalize_element_specs(elements) + for spec in specs: + tokens = str(spec).split() + if not tokens: + continue + + element_key = type(self)._resolve_element_key(self.model_elements, tokens[0]) + if element_key is None: + continue + + selectors = [str(token).strip().lower() for token in tokens[1:] if str(token).strip()] + if not selectors: + self.model_elements.pop(element_key, None) + continue + + lines_info = self.model_elements.get(element_key) + if not isinstance(lines_info, dict): + self.model_elements.pop(element_key, None) + continue + + self.model_elements[element_key] = { + line_name: line_info + for line_name, line_info in lines_info.items() + if not type(self)._line_matches_selectors(line_name, selectors) + } + if not self.model_elements[element_key]: + self.model_elements.pop(element_key, None) + + if not self.model_elements: + self.model_elements = _ModelElementsDict() + + def clear_model_elements(self): + """Clear all elements from the model.""" + self.model_elements = _ModelElementsDict() + + # Storage of spectra alongside dataset + + def add_spectrum_to_data(self, spectrum, energy_axis): + """ + Store processed spectra in the 3D spectroscopy dataset structure, in a 1D array of 2D arrays. By default, calculate_mean_spectrum will store spectrum at first available index + """ + from quantem.core.datastructures.dataset1d import Dataset1d + + two_d_spectrum = Dataset1d.from_array( + array=spectrum, origin=energy_axis[0], sampling=self.sampling[0], units=self.units[0] + ) + + if self.attached_spectra is not None: + self.attached_spectra.append(two_d_spectrum) + else: + self.attached_spectra = [] + self.attached_spectra.append(two_d_spectrum) + + def clear_attached_spectra(self): + self.attached_spectra = None + + def plot_attached_spectrum(self, spectrum_index=0): + fig, (ax_spec) = plt.subplots(1, 1, figsize=(12, 4)) + + ax_spec.plot( + self.attached_spectra[spectrum_index][1], + self.attached_spectra[spectrum_index][0], + linewidth=1.5, + ) + if self.dataset_type == "eds": + ax_spec.set_xlabel("Energy (keV)") + elif self.dataset_type == "eels": + ax_spec.set_xlabel("Energy (eV)") + ax_spec.set_ylabel("Intensity") + ax_spec.set_title(f"Spectrum in index {spectrum_index}") + ax_spec.grid(True, alpha=0.1) + + fig.tight_layout() + plt.show() + + ## PCA ANALYSIS METHODS + def perform_pca( + self, + n_components: int = 10, + standardize: bool = True, + mask: Optional[NDArray] = None, + plot_results: bool = True, + return_results=False, + ) -> dict: + """ + Perform Principal Component Analysis (PCA) on the spectroscopy dataset. + + Parameters + ---------- + n_components : int + Number of principal components to compute + standardize : bool + If True, standardize the data before PCA (zero mean, unit variance) + mask : Optional[NDArray] + Optional spatial mask to select pixels for analysis. Accepts shape + (scan_y, scan_x) or a flattened spatial mask. + plot_results : bool + If True, plot the explained variance and first few components + + Returns + ------- + dict + Dictionary containing: + - 'pca': PCA result attributes + - 'components': principal component spectra (n_components x n_energy) + - 'loadings': spatial loadings (n_components x scan_y x scan_x) + - 'explained_variance_ratio': explained variance for each component + - 'reconstructed': reconstructed dataset (dataset3dspectroscopy) using n_components + """ + + from quantem.spectroscopy import Dataset3deds, Dataset3deels + + data = np.asarray(self.array, dtype=float) + n_energy, ny, nx = data.shape + n_pixels = ny * nx + + spectra = np.moveaxis(data, 0, -1).reshape(n_pixels, n_energy) + pixel_mask = np.ones(n_pixels, dtype=bool) + + if mask is not None: + mask_array = np.asarray(mask, dtype=bool) + if mask_array.shape == (ny, nx): + pixel_mask = mask_array.reshape(-1) + elif mask_array.shape == (n_pixels,): + pixel_mask = mask_array + else: + raise ValueError( + f"mask shape {mask_array.shape} must match spatial shape {(ny, nx)} " + f"or flattened shape {(n_pixels,)}" + ) + + if not np.any(pixel_mask): + raise ValueError("mask must select at least one spatial pixel") + + selected_spectra = spectra[pixel_mask] + + if standardize: + mean = np.mean(selected_spectra, axis=0) + std = np.std(selected_spectra, axis=0) + std[std == 0] = 1 # Avoid division by zero + pca_input = (selected_spectra - mean) / std + else: + mean = np.zeros(n_energy) + std = np.ones(n_energy) + pca_input = selected_spectra + + ( + components, + loadings, + explained_variance, + explained_variance_ratio, + reconstructed, + ) = self._run_pca(pca_input, n_components) + + reconstructed = reconstructed * std + mean + + loadings_flat = np.zeros((n_components, n_pixels), dtype=loadings.dtype) + loadings_flat[:, pixel_mask] = loadings.T + loadings_spatial = loadings_flat.reshape(n_components, ny, nx) + + if plot_results: + self._plot_pca_results( + components, + loadings_spatial, + explained_variance_ratio, + n_show=min(4, n_components), + ) + + reconstructed_spectra = spectra.copy() + reconstructed_spectra[pixel_mask] = reconstructed + reconstructed_array = reconstructed_spectra.reshape(ny, nx, n_energy).transpose(2, 0, 1) + + dataset_type = str(self.dataset_type).lower() + if dataset_type == "eds": + dataset_class = Dataset3deds + elif dataset_type == "eels": + dataset_class = Dataset3deels + else: + raise ValueError(f"Unsupported spectroscopy dataset_type {self.dataset_type!r}") + + reconstructed_data3d = dataset_class.from_array( + array=reconstructed_array, + sampling=self.sampling, + origin=self.origin, + units=self.units, + ) + + if return_results: + return { + "pca": { + "components_": components, + "explained_variance_": explained_variance, + "explained_variance_ratio_": explained_variance_ratio, + }, + "components": components, + "loadings": loadings_spatial, + "explained_variance_ratio": explained_variance_ratio, + "explained_variance": explained_variance, + "reconstructed": reconstructed_data3d, + } + + def _run_pca(self, data: NDArray | Any, n_components: int): + array = np.asarray(data, dtype=float) + n_samples, n_features = array.shape + max_components = min(n_samples, n_features) + if not 1 <= n_components <= max_components: + raise ValueError(f"n_components={n_components} must be between 1 and {max_components}") + + mean = np.mean(array, axis=0) + centered = torch.as_tensor(array - mean, dtype=torch.float64) + _, s, vh = torch.linalg.svd(centered, full_matrices=False) + + components = vh[:n_components].cpu().numpy() + loadings = (centered @ vh[:n_components].T).cpu().numpy() + + denom = max(n_samples - 1, 1) + explained_variance = ((s[:n_components] ** 2) / denom).cpu().numpy() + total_variance = torch.sum((s**2) / denom).item() + explained_variance_ratio = ( + explained_variance / total_variance + if total_variance > 0 + else np.zeros_like(explained_variance) + ) + reconstructed = loadings @ components + mean + + return ( + components, + loadings, + explained_variance, + explained_variance_ratio, + reconstructed, + ) + + def _plot_pca_results( + self, + components: NDArray, + loadings: NDArray, + explained_variance_ratio: NDArray, + n_show: int = 4, + ): + """ + Plot PCA results including scree plot, components, and loadings. + + Parameters + ---------- + components : NDArray + Principal component spectra + loadings : NDArray + Spatial loadings for each component + explained_variance_ratio : NDArray + Explained variance ratios + n_show : int + Number of components to show + """ + fig, (ax_scree, ax_components) = plt.subplots(1, 2, figsize=(12, 4)) + cumsum_var = np.cumsum(explained_variance_ratio) + component_numbers = np.arange(1, len(explained_variance_ratio) + 1) + + ax_scree.bar( + component_numbers, + explained_variance_ratio * 100, + alpha=0.6, + label="Individual", + ) + ax_scree.plot(component_numbers, cumsum_var * 100, "ro-", label="Cumulative") + ax_scree.set_xlabel("Component Number") + ax_scree.set_ylabel("Explained Variance (%)") + ax_scree.set_title("Scree Plot") + ax_scree.legend() + ax_scree.grid(True, alpha=0.3) + + energy_sampling = float(self.sampling[0]) + energy_origin = float(self.origin[0]) + energy_axis = energy_origin + energy_sampling * np.arange(components.shape[1]) + + for i in range(n_show): + ax_components.plot( + energy_axis, + components[i], + label=f"PC{i + 1} ({explained_variance_ratio[i] * 100:.1f}%)", + ) + ax_components.set_xlabel("Energy") + ax_components.set_ylabel("Component") + ax_components.set_title("Principal Component Spectra") + ax_components.legend() + ax_components.grid(True, alpha=0.3) + + fig.suptitle("PCA Analysis") + fig.tight_layout() + plt.show() + + show_2d( + [loadings[i] for i in range(n_show)], + title=[ + f"Loading {i + 1} ({explained_variance_ratio[i] * 100:.1f}%)" + for i in range(n_show) + ], + cmap="RdBu_r", + cbar=True, + scalebar={ + "sampling": float(self.sampling[1]), + "units": str(self.units[1]), + }, + ) + plt.show() + + def _calibrated_position_to_pixel(self, value, axis): + if value is None: + return None + + sampling = float(self.sampling[axis]) + if sampling == 0: + raise ValueError(f"Cannot convert calibrated ROI on axis {axis}: sampling is zero") + + origin = float(self.origin[axis]) if hasattr(self, "origin") else 0.0 + return int(np.round((float(value) - origin) / sampling)) + + def _calibrated_span_to_pixels(self, value, axis): + if value is None: + return None + + sampling = abs(float(self.sampling[axis])) + if sampling == 0: + raise ValueError( + f"Cannot convert calibrated ROI span on axis {axis}: sampling is zero" + ) + + pixels = int(np.round(float(value) / sampling)) + if pixels < 1: + raise ValueError( + f"Calibrated ROI span on axis {axis} converts to {pixels} pixels; expected >= 1" + ) + return pixels + + def _validate_roi_bounds(self, y, x, dy, dx): + errs = [] + ymax = int(self.shape[1]) + xmax = int(self.shape[2]) + + for name, val in (("y", y), ("x", x), ("dy", dy), ("dx", dx)): + if val is None: + errs.append(f"{name} is None (missing after normalization).") + + if errs: + raise ValueError("Invalid ROI:\n - " + "\n - ".join(errs)) + + if y < 0: + errs.append(f"y={y} < 0") + if x < 0: + errs.append(f"x={x} < 0") + if dy < 1: + errs.append(f"dy={dy} < 1") + if dx < 1: + errs.append(f"dx={dx} < 1") + + if y >= ymax: + errs.append(f"y start {y} out of bounds [0, {ymax - 1}]") + if x >= xmax: + errs.append(f"x start {x} out of bounds [0, {xmax - 1}]") + + end_y = y + dy + end_x = x + dx + if end_y > ymax: + errs.append(f"y+dy = {end_y} exceeds height {ymax}") + if end_x > xmax: + errs.append(f"x+dx = {end_x} exceeds width {xmax}") + + if errs: + raise ValueError("Invalid ROI:\n - " + "\n - ".join(errs)) + + def _resolve_roi(self, roi=None, roi_cal=None): + selector_count = int(roi is not None) + int(roi_cal is not None) + if selector_count > 1: + raise ValueError("Use only one ROI selector: roi or roi_cal") + + if roi is not None: + roi_spec = roi + elif roi_cal is not None: + if len(roi_cal) == 2: + y_cal, x_cal = roi_cal + roi_spec = [ + self._calibrated_position_to_pixel(y_cal, axis=1), + self._calibrated_position_to_pixel(x_cal, axis=2), + ] + elif len(roi_cal) == 4: + y_cal, x_cal, dy_cal, dx_cal = roi_cal + roi_spec = [ + self._calibrated_position_to_pixel(y_cal, axis=1), + self._calibrated_position_to_pixel(x_cal, axis=2), + self._calibrated_span_to_pixels(dy_cal, axis=1), + self._calibrated_span_to_pixels(dx_cal, axis=2), + ] + else: + raise ValueError("roi_cal must be [y, x] or [y, x, dy, dx]") + else: + roi_spec = None + + if roi_spec is None: + y, x, dy, dx = 0, 0, int(self.shape[1]), int(self.shape[2]) + elif len(roi_spec) == 2: + y, x = roi_spec + y, x, dy, dx = int(y), int(x), 1, 1 + elif len(roi_spec) == 4: + y_val, x_val, dy_val, dx_val = roi_spec + y = 0 if y_val is None else int(y_val) + x = 0 if x_val is None else int(x_val) + dy = int(self.shape[1]) - y if dy_val is None else int(dy_val) + dx = int(self.shape[2]) - x if dx_val is None else int(dx_val) + else: + raise ValueError( + "ROI must be None, [y, x], or [y, x, dy, dx]. Use one selector: roi or roi_cal" + ) + + self._validate_roi_bounds(y, x, dy, dx) + return y, x, dy, dx + + def calculate_mean_spectrum( + self, + roi=None, + energy_range=None, + mask=None, + attach_mean_spectrum=True, + roi_cal=None, + normalize=False, + ): + """Calculate a spectrum from a spatial ROI. + + Parameters + ---------- + normalize : bool, optional + If ``True``, scale the mean spectrum to the range [0, 1]. If + ``False``, return the mean spectrum in original intensity units. + """ + y, x, dy, dx = self._resolve_roi(roi=roi, roi_cal=roi_cal) + + # SPECTRUM CALCULATION -------------------------------------------------------------- + + dE = float(self.sampling[0]) + E0 = float(self.origin[0]) if hasattr(self, "origin") else 0.0 + E = E0 + dE * np.arange(self.shape[0]) + + # MASK HANDLING --------------------------------------------------------------------- + if mask is not None: + # Convert to ndarray and validate + mask = np.asarray(mask) + + # Check that it's a proper ndarray + if not isinstance(mask, np.ndarray): + raise TypeError(f"Mask must be a numpy ndarray, got {type(mask)}") + + # Check dimensions - must be 1D + if mask.ndim != 1: + raise ValueError( + f"Mask must be 1-dimensional, got {mask.ndim}D array with shape {mask.shape}" + ) + + # Convert to bool dtype and validate + if mask.dtype != bool: + try: + mask = mask.astype(bool) + except (ValueError, TypeError): + raise TypeError(f"Mask cannot be converted to boolean dtype from {mask.dtype}") + + # Check shape matches energy axis + arr = np.asarray(self.array, dtype=float) + if mask.shape != (arr.shape[0],): + raise ValueError( + f"Mask shape {mask.shape} does not match energy axis shape ({arr.shape[0]},)" + ) + + arr = arr[mask, y : y + dy, x : x + dx] # select masked energies and ROI + if arr.shape[0] > 0: + spec = arr.mean(axis=(1, 2)) + else: + spec = np.zeros(0) + E = E[mask] # Mask the energy axis as well + else: + spec = np.empty(self.shape[0], dtype=float) + for k in range(self.shape[0]): + img = np.asarray(self.array[k], dtype=float) + roi_data = img[y : y + dy, x : x + dx] + if roi_data.size == 0: + raise ValueError("ROI is empty; check y/x/dy/dx.") + spec[k] = roi_data.mean() + + # APPLY ENERGY RANGE --------------------------------------------------------------- + + if energy_range is not None: + # Check for errors in energy_range input + if energy_range[0] >= energy_range[1]: + raise ValueError("Invalid energy range parameter.") + + # If the entire energy range specified is outside the original energy range of the data, raise an error. + if energy_range[1] < E[0] or energy_range[0] > E[-1]: + raise ValueError("Energy range parameter is outside of data bounds.") + + # If either side of input energy_range is beyond the original energy range of the data, default to the limit of the data instead. + energy_range[0] = np.maximum(energy_range[0], E[0]) + energy_range[1] = np.minimum(energy_range[1], E[-1]) + + indices = np.where((E >= energy_range[0]) & (E <= energy_range[1]))[0] + spec = spec[indices] + E = E[indices] + + if normalize and spec.size > 0: + finite = np.isfinite(spec) + if np.any(finite): + spec_min = np.min(spec[finite]) + spec_max = np.max(spec[finite]) + if spec_max > spec_min: + spec = (spec - spec_min) / (spec_max - spec_min) + else: + spec = np.zeros_like(spec, dtype=float) + + if attach_mean_spectrum: + self.add_spectrum_to_data(spec, E) + + return spec + + def show_mean_spectrum( + self, + roi=None, + roi_cal=None, + energy_range=None, + mask=None, + intensity_range=None, + normalize=False, + **kwargs, + ): + """ + Plot the mean spectrum from a spatial ROI in a 3D spectroscopy cube (E, Y, X). + + Parameters + ---------- + roi : list or tuple, optional + Region of interest as [y, x, dy, dx] where: + - y, x: top-left pixel coordinates + - dy, dx: height and width of ROI + Use None for default values: + - [y, None, dy, None] = row y with height dy, full width + - [None, x, None, dx] = column x with width dx, full height + - [y, x, None, None] = from (y,x) to bottom-right corner + If roi=None, uses full image. Can also be [y, x] for single pixel. + energy_range : list or tuple, optional + Energy range to display as [min_energy, max_energy] in keV. + mask : array, optional + Boolean mask for pixel selection. + intensity_range : 2-tuple, None + If not None, sets intensity range on spectrum plot + normalize : bool, optional + If ``True``, scale the mean spectrum to the range [0, 1]. If + ``False``, plot the mean spectrum in original intensity units. + Returns + ------- + (fig, ax) : tuple + The Matplotlib Figure and Axes of the spectrum plot. + """ + + # ADJUST ROI BASED ON GIVEN FLAGS ----------------------------------------------- + # Parse ROI parameter + if roi is None: + # Full image + y, x, dy, dx = 0, 0, int(self.shape[1]), int(self.shape[2]) + elif len(roi) == 2: + # Single pixel [y, x] + y, x, dy, dx = int(roi[0]), int(roi[1]), 1, 1 + elif len(roi) == 4: + # Full ROI [y, x, dy, dx] with None support for defaults + y_val, x_val, dy_val, dx_val = roi + + # Handle None values with defaults + y = 0 if y_val is None else int(y_val) + x = 0 if x_val is None else int(x_val) + dy = int(self.shape[1]) - y if dy_val is None else int(dy_val) + dx = int(self.shape[2]) - x if dx_val is None else int(dx_val) + else: + raise ValueError( + "roi must be None, [y, x], or [y, x, dy, dx] (with None for defaults)" + ) + + # CALCULATE MEAN SPECTRUM FOR GIVEN ROI AND ENERGY RANGE -------------------------- + + y, x, dy, dx = self._resolve_roi(roi=roi, roi_cal=roi_cal) + + spec = self.calculate_mean_spectrum( + roi=roi, + roi_cal=roi_cal, + energy_range=energy_range, + mask=mask, + normalize=normalize, + ) + + dE = float(self.sampling[0]) + E0 = float(self.origin[0]) if hasattr(self, "origin") else 0.0 + E = E0 + dE * np.arange(self.shape[0]) + + if mask is not None: + E = E[np.asarray(mask, dtype=bool)] + + if energy_range is not None: + indices = np.where((E >= energy_range[0]) & (E <= energy_range[1]))[0] + E = E[indices] + + # PLOTTING --------------------------------------------------------------------------- + + # Create subplot layout: image on left, spectrum on right + fig, (ax_img, ax_spec) = plt.subplots(1, 2, figsize=(12, 4)) + + # LEFT PLOT: Show sum image with ROI highlighted + # Create sum image across all energy channels (or masked channels) + if mask is not None: + sum_img = np.asarray(self.array, dtype=float)[mask, :, :].sum(axis=0) + title_suffix = " (masked energies)" + else: + sum_img = np.asarray(self.array, dtype=float).sum(axis=0) + title_suffix = "" + + map_title = f"Integrated Intensity Map{title_suffix}" + show_2d( + sum_img, + figax=(fig, ax_img), + title=map_title, + cmap="viridis", + cbar=True, + show_ticks=True, + scalebar={ + "sampling": float(self.sampling[1]), + "units": str(self.units[1]), + }, + **kwargs, + ) + # Highlight the ROI with a rectangle + rect = Rectangle( + (x - 0.5, y - 0.5), dx, dy, linewidth=2, edgecolor="red", facecolor="none", alpha=0.8 + ) + ax_img.add_patch(rect) + + # RIGHT PLOT: Show spectrum + ax_spec.plot(E, spec, linewidth=1.5, color="k") + if self.dataset_type == "eds": + ax_spec.set_xlabel("Energy (keV)") + else: + ax_spec.set_xlabel("Energy (eV)") + ax_spec.set_ylabel("Normalized intensity" if normalize else "Intensity") + ax_spec.set_title(f"Spectrum from ROI [{y}:{y + dy}, {x}:{x + dx}]") + ax_spec.grid(True, alpha=0.1) + if intensity_range is not None: + ax_spec.set_ylim([intensity_range[0], intensity_range[1]]) + + fig.tight_layout() + return fig, (ax_img, ax_spec) + + def show_energy_window_map( + self, + energy_window=None, + roi=None, + roi_cal=None, + mask=None, + cmap="viridis", + show=True, + ): + """Show a spatial map integrated over a selected energy window. + + This is a complementary view to ``show_mean_spectrum``: + - ``show_mean_spectrum`` answers *what energies are present*. + - ``show_energy_window_map`` answers *where a chosen energy range is present*. + + Parameters + ---------- + energy_window : list[float] | tuple[float, float] | None + Energy interval [emin, emax] to integrate. If None, use the + full calibrated energy range of the dataset. + roi : list | tuple | None, optional + ROI as ``[y, x]`` or ``[y, x, dy, dx]`` (with ``None`` defaults), + used only for overlay rectangle. + mask : array-like | None, optional + Optional boolean mask over energy channels. If provided, it is + combined with ``energy_window``. + cmap : str, optional + Matplotlib colormap for the map. + show : bool, optional + If True, call ``plt.show()``. + + Returns + ------- + tuple + ``(fig, (ax_map, ax_spec), energy_map)`` where ``energy_map`` is the integrated 2D array. + """ + y, x, dy, dx = self._resolve_roi(roi=roi, roi_cal=roi_cal) + has_roi_overlay = any(val is not None for val in (roi, roi_cal)) + + dE = float(self.sampling[0]) + E0 = float(self.origin[0]) if hasattr(self, "origin") else 0.0 + E = E0 + dE * np.arange(self.shape[0]) + + if energy_window is None: + emin = float(np.min(E)) + emax = float(np.max(E)) + else: + if len(energy_window) != 2: + raise ValueError("energy_window must be [min_energy, max_energy]") + + emin = float(energy_window[0]) + emax = float(energy_window[1]) + if not np.isfinite(emin) or not np.isfinite(emax) or emin >= emax: + raise ValueError( + "Invalid energy_window. Expected [min_energy, max_energy] with min < max" + ) + + window_mask = (E >= emin) & (E <= emax) + if mask is not None: + mask = np.asarray(mask, dtype=bool) + if mask.shape != (self.shape[0],): + raise ValueError( + f"Mask shape {mask.shape} does not match energy axis shape ({self.shape[0]},)" + ) + window_mask = window_mask & mask + + if not np.any(window_mask): + raise ValueError("No energy channels selected. Adjust energy_window or mask") + + arr = np.asarray(self.array, dtype=float) + energy_map = arr[window_mask, :, :].sum(axis=0) + + spec = self.calculate_mean_spectrum( + roi=roi, + roi_cal=roi_cal, + mask=mask, + attach_mean_spectrum=False, + ) + if mask is not None: + E_spec = E[mask] + else: + E_spec = E + + unit_label = "keV" if str(self.dataset_type).lower() == "eds" else "eV" + fig, (ax_map, ax_spec) = plt.subplots(1, 2, figsize=(12, 4)) + show_2d( + energy_map, + figax=(fig, ax_map), + title=f"Energy-Window Map [{emin:.3f}, {emax:.3f}] {unit_label}", + cmap=cmap, + cbar=True, + show_ticks=True, + scalebar={ + "sampling": float(self.sampling[1]), + "units": str(self.units[1]), + }, + ) + + if has_roi_overlay: + rect = Rectangle( + (x - 0.5, y - 0.5), + dx, + dy, + linewidth=2, + edgecolor="red", + facecolor="none", + alpha=0.8, + ) + ax_map.add_patch(rect) + + ax_spec.plot(E_spec, spec, linewidth=1.5, color="k") + ax_spec.axvspan(emin, emax, color="orange", alpha=0.2, label="Selected window") + ax_spec.set_xlabel(f"Energy ({unit_label})") + ax_spec.set_ylabel("Intensity") + ax_spec.set_title(f"Spectrum from ROI [{y}:{y + dy}, {x}:{x + dx}]") + ax_spec.grid(True, alpha=0.1) + ax_spec.legend(loc="best") + + fig.tight_layout() + + if show: + plt.show() + + return fig, (ax_map, ax_spec), energy_map + + # BACKGROND SUBTRACTION + + def subtract_background( + self, + roi=None, + energy_range=None, + mask=None, + target_edge=None, + window_size=None, + method="powerlaw", + polynomial_degree=3, + return_dataset=True, + attach_spectrum=True, + fit_mode="global", + kernel_width=1, + show=True, + show_subtracted=True, + return_background=False, + ): + """ + Subtract fitted background from a 3D spectroscopy dataset. + + Parameters + ---------- + fit_mode : {"global", "local"}, optional + ``"global"`` fits one background to the ROI mean spectrum and subtracts + it from every probe position. ``"local"`` fits a background at each + probe position from the average spectrum of its nearest spatial + neighbors. + kernel_width : int, optional + Number of nearest spatial neighbors to average for each local + background fit. The current pixel is included. Used only when + ``fit_mode="local"``. + window_size : int, optional + For EDS, number of spectral channels in the rolling low-percentile + envelope used before polynomial fitting. For EELS power-law fitting, + percent of ``target_edge`` used for the pre-edge fit window. Defaults + to 50 channels for EDS and 10 percent for EELS. + show : bool, optional + If True, plot the mean raw spectrum, fitted background, and + background-subtracted spectrum. + polynomial_degree : int, optional + Degree of the polynomial power-series background used for EDS data. + Ignored for EELS data. + return_background : bool, optional + If True, return ``(dataset, background_cube)`` when ``return_dataset`` + is True, otherwise return the background cube. + + Returns + ------- + Dataset3dspectroscopy or tuple or ndarray or None + Background-subtracted dataset by default. If ``return_background`` is + True, also returns the fitted background cube. + """ + + from quantem.spectroscopy import Dataset3deds, Dataset3deels + + fit_mode = str(fit_mode).lower() + if fit_mode not in {"global", "local"}: + raise ValueError("fit_mode must be 'global' or 'local'") + + E, indices = self._background_energy_axis_and_indices(energy_range, mask) + array3d = np.asarray(self.array, dtype=float)[indices, :, :] + y, x, dy, dx = self._resolve_roi(roi=roi) + + if fit_mode == "global": + input_spectrum = array3d[:, y : y + dy, x : x + dx].mean(axis=(1, 2)) + background = self._fit_background_spectrum( + input_spectrum, + E, + method=method, + target_edge=target_edge, + window_size=window_size, + polynomial_degree=polynomial_degree, + ) + background_cube = np.broadcast_to(background[:, None, None], array3d.shape) + else: + background_cube = self._fit_local_background_cube( + array3d, + E, + method=method, + target_edge=target_edge, + window_size=window_size, + polynomial_degree=polynomial_degree, + kernel_width=kernel_width, + ) + + spec3D_subtracted = np.maximum(array3d - background_cube, 0) + input_mean_spectrum = array3d[:, y : y + dy, x : x + dx].mean(axis=(1, 2)) + background_mean_spectrum = background_cube[:, y : y + dy, x : x + dx].mean(axis=(1, 2)) + subtracted_mean_spectrum = spec3D_subtracted[:, y : y + dy, x : x + dx].mean(axis=(1, 2)) + + if attach_spectrum: + self.add_spectrum_to_data(subtracted_mean_spectrum, E) + + if show: + self._plot_background_subtraction( + E, + input_mean_spectrum, + background_mean_spectrum, + subtracted_mean_spectrum, + fit_mode=fit_mode, + show_subtracted=show_subtracted, + ) + + dataset_type = str(self.dataset_type).lower() + if dataset_type == "eds": + dataset_class = Dataset3deds + elif dataset_type == "eels": + dataset_class = Dataset3deels + else: + raise ValueError(f"Unsupported spectroscopy dataset_type {self.dataset_type!r}") + + output_origin = np.array(self.origin, dtype=float, copy=True) + output_origin[0] = E[0] + + if return_dataset: + subtracted_dataset = dataset_class.from_array( + array=spec3D_subtracted, + sampling=self.sampling, + origin=output_origin, + units=self.units, + ) + if return_background: + background_dataset = dataset_class.from_array( + array=np.array(background_cube, copy=True), + sampling=self.sampling, + origin=output_origin, + units=self.units, + ) + return subtracted_dataset, background_dataset + return subtracted_dataset + + if return_background: + return background_cube + + print("Notice: no 3D dataset was returned") + + def _background_energy_axis_and_indices(self, energy_range, mask): + E = np.asarray(self.energy_axis, dtype=float) + selected = np.ones(E.shape, dtype=bool) + + if energy_range is not None: + if len(energy_range) != 2: + raise ValueError("energy_range must be [min_energy, max_energy]") + e_min = float(energy_range[0]) + e_max = float(energy_range[1]) + if e_min >= e_max: + raise ValueError("Invalid energy range parameter.") + if e_max < E[0] or e_min > E[-1]: + raise ValueError("Energy range parameter is outside of data bounds.") + e_min = max(e_min, float(E[0])) + e_max = min(e_max, float(E[-1])) + selected &= (E >= e_min) & (E <= e_max) + + if mask is not None: + mask = np.asarray(mask, dtype=bool) + if mask.shape != E.shape: + raise ValueError( + f"Mask shape {mask.shape} does not match energy axis shape {E.shape}" + ) + selected &= mask + + if not np.any(selected): + raise ValueError("No energy channels selected. Adjust energy_range or mask") + + indices = np.where(selected)[0] + return E[indices], indices + + def _fit_background_spectrum( + self, + spectrum, + energy_axis, + method, + target_edge, + window_size, + polynomial_degree=3, + ): + dataset_type = str(self.dataset_type).lower() + spectrum = np.asarray(spectrum, dtype=float) + + if dataset_type == "eds": + return self.calculate_background_polynomial( + spectrum, + energy_axis=np.asarray(energy_axis, dtype=float), + degree=polynomial_degree, + window_size=50 if window_size is None else window_size, + ) + + if dataset_type != "eels": + raise ValueError(f"Unsupported spectroscopy dataset_type {self.dataset_type!r}") + + method = str(method).lower() + if method == "iterative": + return np.full_like(spectrum, float(self.calculate_background_iterative(spectrum))) + if method != "powerlaw": + raise ValueError("EELS background method must be 'powerlaw' or 'iterative'") + if target_edge is None: + raise ValueError("target_edge is required for EELS powerlaw background fitting") + + return self._fit_eels_powerlaw_background( + spectrum, + np.asarray(energy_axis, dtype=float), + target_edge=target_edge, + window_size=10 if window_size is None else window_size, + ) + + def _fit_eels_powerlaw_background(self, spectrum, energy_axis, target_edge, window_size): + from scipy.optimize import curve_fit + + if window_size < 10 or window_size > 30: + raise ValueError("Invalid window size. Please input a value of between 10 and 30.") + + target_edge = float(target_edge) + if target_edge < energy_axis[0] or target_edge > energy_axis[-1]: + raise ValueError("Target edge is outside of energy range.") + + window_minE = (target_edge - 5) - target_edge * (float(window_size) / 100) + window_maxE = target_edge - 5 + if window_minE < energy_axis[0]: + raise ValueError( + "Insufficient pre-edge background fitting region for this target edge " + "and window size within given energy range." + ) + + window_indices = np.where((energy_axis >= window_minE) & (energy_axis <= window_maxE))[0] + if len(window_indices) < 2: + raise ValueError("Insufficient points in EELS pre-edge background fitting window.") + + window_E = energy_axis[window_indices] + window_I = np.asarray(spectrum, dtype=float)[window_indices] + + def powerlaw_function(E, A, r): + return A * (E ** (-r)) + + popt, _ = curve_fit(powerlaw_function, window_E, window_I, maxfev=2000) + return powerlaw_function(energy_axis, popt[0], popt[1]) + + def _fit_local_background_cube( + self, + array3d, + energy_axis, + method, + target_edge, + window_size, + polynomial_degree, + kernel_width, + ): + from scipy.spatial import cKDTree + + n_energy, ny, nx = array3d.shape + n_pixels = ny * nx + try: + n_neighbors = int(kernel_width) + except (TypeError, ValueError) as exc: + raise TypeError("kernel_width must be an integer") from exc + if n_neighbors < 1: + raise ValueError("kernel_width must be >= 1") + n_neighbors = min(n_neighbors, n_pixels) + + yy, xx = np.indices((ny, nx)) + coords = np.column_stack((yy.reshape(-1), xx.reshape(-1))) + _, neighbor_indices = cKDTree(coords).query(coords, k=n_neighbors) + if n_neighbors == 1: + neighbor_indices = neighbor_indices[:, None] + + spectra = np.moveaxis(array3d, 0, -1).reshape(n_pixels, n_energy) + background = np.empty_like(spectra) + + for pixel_index, neighbors in enumerate(neighbor_indices): + local_spectrum = spectra[neighbors].mean(axis=0) + try: + background[pixel_index] = self._fit_background_spectrum( + local_spectrum, + energy_axis, + method=method, + target_edge=target_edge, + window_size=window_size, + polynomial_degree=polynomial_degree, + ) + except Exception as exc: + y, x = divmod(pixel_index, nx) + raise RuntimeError(f"Background fit failed at pixel ({y}, {x})") from exc + + return background.reshape(ny, nx, n_energy).transpose(2, 0, 1) + + def _plot_background_subtraction( + self, + energy_axis, + input_spectrum, + background_spectrum, + subtracted_spectrum, + fit_mode, + show_subtracted, + ): + fig, (ax_specbacksub) = plt.subplots(1, 1, figsize=(12, 4)) + + ax_specbacksub.plot(energy_axis, input_spectrum, linewidth=1.2, label="Input") + ax_specbacksub.plot(energy_axis, background_spectrum, linewidth=1.2, label="Background") + if show_subtracted: + ax_specbacksub.plot( + energy_axis, + subtracted_spectrum, + linewidth=1.5, + label="Background-subtracted", + ) + if self.dataset_type == "eds": + ax_specbacksub.set_xlabel("Energy (keV)") + else: + ax_specbacksub.set_xlabel("Energy (eV)") + ax_specbacksub.set_ylabel("Intensity") + ax_specbacksub.set_title(f"Background-subtracted spectrum from ROI ({fit_mode})") + ax_specbacksub.grid(True, alpha=0.1) + ax_specbacksub.legend() + + fig.tight_layout() + plt.show() + + @property + def energy_axis(self): + energy_axis = np.arange(self.shape[0]) * self.sampling[0] + self.origin[0] + return energy_axis diff --git a/src/quantem/spectroscopy/eels_edges.csv b/src/quantem/spectroscopy/eels_edges.csv new file mode 100644 index 00000000..08a780e9 --- /dev/null +++ b/src/quantem/spectroscopy/eels_edges.csv @@ -0,0 +1,1047 @@ +atomic_number,symbol,element,edge_label,edge_energy_eV +1,H,Hydrogen,major,14 +2,He,Helium,major,25 +3,Li,Lithium,major,55 +4,Be,Beryllium,major,111 +5,B,Boron,major,188 +6,C,Carbon,major,284 +7,N,Nitrogen,major,402 +8,O,Oxygen,major,532 +8,O,Oxygen,minor,24 +9,F,Fluorine,major,685 +9,F,Fluorine,minor,31 +10,Ne,Neon,major,867 +10,Ne,Neon,major,18 +10,Ne,Neon,minor,45 +11,Na,Sodium,major,1072 +11,Na,Sodium,major,31 +11,Na,Sodium,minor,63 +12,Mg,Magnesium,major,1305 +12,Mg,Magnesium,major,51 +12,Mg,Magnesium,minor,89 +13,Al,Aluminum,major,1560 +13,Al,Aluminum,major,73 +13,Al,Aluminum,minor,118 +14,Si,Silicon,major,1839 +14,Si,Silicon,major,99 +14,Si,Silicon,minor,149 +15,P,Phosphorus,major,2146 +15,P,Phosphorus,major,132 +15,P,Phosphorus,minor,189 +16,S,Sulfur,major,2472 +16,S,Sulfur,major,165 +16,S,Sulfur,minor,229 +17,Cl,Chlorine,major,2822 +17,Cl,Chlorine,major,202 +17,Cl,Chlorine,major,200 +17,Cl,Chlorine,minor,270 +17,Cl,Chlorine,minor,18 +18,Ar,Argon,major,3203 +18,Ar,Argon,major,247 +18,Ar,Argon,major,245 +18,Ar,Argon,major,12 +18,Ar,Argon,minor,320 +18,Ar,Argon,minor,25 +19,K,Potassium,major,3607 +19,K,Potassium,major,296 +19,K,Potassium,major,294 +19,K,Potassium,major,18 +19,K,Potassium,minor,377 +19,K,Potassium,minor,34 +20,Ca,Calcium,major,4038 +20,Ca,Calcium,major,350 +20,Ca,Calcium,major,346 +20,Ca,Calcium,major,25 +20,Ca,Calcium,minor,438 +20,Ca,Calcium,minor,44 +21,Sc,Scandium,major,4493 +21,Sc,Scandium,major,407 +21,Sc,Scandium,major,402 +21,Sc,Scandium,major,32 +21,Sc,Scandium,minor,501 +21,Sc,Scandium,minor,54 +22,Ti,Titanium,major,4966 +22,Ti,Titanium,major,462 +22,Ti,Titanium,major,456 +22,Ti,Titanium,major,35 +22,Ti,Titanium,minor,564 +22,Ti,Titanium,minor,60 +23,V,Vanadium,major,5465 +23,V,Vanadium,major,521 +23,V,Vanadium,major,513 +23,V,Vanadium,major,38 +23,V,Vanadium,minor,628 +23,V,Vanadium,minor,67 +24,Cr,Chromium,major,5989 +24,Cr,Chromium,major,584 +24,Cr,Chromium,major,575 +24,Cr,Chromium,major,43 +24,Cr,Chromium,minor,695 +24,Cr,Chromium,minor,74 +25,Mn,Manganese,major,6539 +25,Mn,Manganese,major,651 +25,Mn,Manganese,major,640 +25,Mn,Manganese,major,49 +25,Mn,Manganese,minor,769 +25,Mn,Manganese,minor,84 +26,Fe,Iron,major,7112 +26,Fe,Iron,major,721 +26,Fe,Iron,major,708 +26,Fe,Iron,major,54 +26,Fe,Iron,minor,846 +26,Fe,Iron,minor,93 +27,Co,Cobalt,major,7709 +27,Co,Cobalt,major,794 +27,Co,Cobalt,major,779 +27,Co,Cobalt,major,60 +27,Co,Cobalt,minor,926 +27,Co,Cobalt,minor,101 +28,Ni,Nickel,major,8333 +28,Ni,Nickel,major,872 +28,Ni,Nickel,major,855 +28,Ni,Nickel,major,68 +28,Ni,Nickel,minor,1008 +28,Ni,Nickel,minor,112 +29,Cu,Copper,major,8979 +29,Cu,Copper,major,951 +29,Cu,Copper,major,931 +29,Cu,Copper,major,74 +29,Cu,Copper,minor,1097 +29,Cu,Copper,minor,120 +30,Zn,Zinc,major,9659 +30,Zn,Zinc,major,1043 +30,Zn,Zinc,major,1020 +30,Zn,Zinc,minor,1194 +30,Zn,Zinc,minor,136 +30,Zn,Zinc,minor,87 +31,Ga,Gallium,major,10367 +31,Ga,Gallium,major,1142 +31,Ga,Gallium,major,1115 +31,Ga,Gallium,minor,1298 +31,Ga,Gallium,minor,158 +31,Ga,Gallium,minor,107 +31,Ga,Gallium,minor,103 +31,Ga,Gallium,minor,17 +32,Ge,Germanium,major,11103 +32,Ge,Germanium,major,1248 +32,Ge,Germanium,major,1217 +32,Ge,Germanium,major,29 +32,Ge,Germanium,minor,1414 +32,Ge,Germanium,minor,180 +32,Ge,Germanium,minor,128 +32,Ge,Germanium,minor,121 +33,As,Arsenic,major,11867 +33,As,Arsenic,major,1359 +33,As,Arsenic,major,1323 +33,As,Arsenic,major,41 +33,As,Arsenic,minor,1527 +33,As,Arsenic,minor,204 +33,As,Arsenic,minor,146 +33,As,Arsenic,minor,141 +34,Se,Selenium,major,12658 +34,Se,Selenium,major,1476 +34,Se,Selenium,major,1436 +34,Se,Selenium,major,57 +34,Se,Selenium,minor,1654 +34,Se,Selenium,minor,232 +34,Se,Selenium,minor,168 +34,Se,Selenium,minor,162 +35,Br,Bromine,major,13474 +35,Br,Bromine,major,1596 +35,Br,Bromine,major,1550 +35,Br,Bromine,major,70 +35,Br,Bromine,major,69 +35,Br,Bromine,minor,1782 +35,Br,Bromine,minor,257 +35,Br,Bromine,minor,189 +35,Br,Bromine,minor,182 +35,Br,Bromine,minor,27 +36,Kr,Krypton,major,14326 +36,Kr,Krypton,major,1727 +36,Kr,Krypton,major,1675 +36,Kr,Krypton,major,89 +36,Kr,Krypton,major,11 +36,Kr,Krypton,minor,1921 +36,Kr,Krypton,minor,287 +36,Kr,Krypton,minor,223 +36,Kr,Krypton,minor,214 +36,Kr,Krypton,minor,24 +37,Rb,Rubidium,major,15200 +37,Rb,Rubidium,major,1864 +37,Rb,Rubidium,major,1804 +37,Rb,Rubidium,major,112 +37,Rb,Rubidium,major,110 +37,Rb,Rubidium,major,15 +37,Rb,Rubidium,major,14 +37,Rb,Rubidium,minor,2065 +37,Rb,Rubidium,minor,322 +37,Rb,Rubidium,minor,247 +37,Rb,Rubidium,minor,239 +37,Rb,Rubidium,minor,29 +38,Sr,Strontium,major,16105 +38,Sr,Strontium,major,2007 +38,Sr,Strontium,major,1940 +38,Sr,Strontium,major,135 +38,Sr,Strontium,major,133 +38,Sr,Strontium,major,20 +38,Sr,Strontium,minor,2216 +38,Sr,Strontium,minor,358 +38,Sr,Strontium,minor,280 +38,Sr,Strontium,minor,269 +38,Sr,Strontium,minor,38 +39,Y,Yttrium,major,17038 +39,Y,Yttrium,major,2156 +39,Y,Yttrium,major,2080 +39,Y,Yttrium,major,160 +39,Y,Yttrium,major,157 +39,Y,Yttrium,major,26 +39,Y,Yttrium,minor,2373 +39,Y,Yttrium,minor,394 +39,Y,Yttrium,minor,312 +39,Y,Yttrium,minor,300 +39,Y,Yttrium,minor,45 +40,Zr,Zirconium,major,17998 +40,Zr,Zirconium,major,2307 +40,Zr,Zirconium,major,2222 +40,Zr,Zirconium,major,182 +40,Zr,Zirconium,major,180 +40,Zr,Zirconium,major,29 +40,Zr,Zirconium,minor,2532 +40,Zr,Zirconium,minor,430 +40,Zr,Zirconium,minor,344 +40,Zr,Zirconium,minor,331 +40,Zr,Zirconium,minor,51 +41,Nb,Niobium,major,18986 +41,Nb,Niobium,major,2465 +41,Nb,Niobium,major,2371 +41,Nb,Niobium,major,207 +41,Nb,Niobium,major,205 +41,Nb,Niobium,major,34 +41,Nb,Niobium,minor,2698 +41,Nb,Niobium,minor,468 +41,Nb,Niobium,minor,378 +41,Nb,Niobium,minor,363 +41,Nb,Niobium,minor,58 +42,Mo,Molybdenum,major,20000 +42,Mo,Molybdenum,major,2625 +42,Mo,Molybdenum,major,2520 +42,Mo,Molybdenum,major,230 +42,Mo,Molybdenum,major,227 +42,Mo,Molybdenum,major,35 +42,Mo,Molybdenum,minor,2866 +42,Mo,Molybdenum,minor,505 +42,Mo,Molybdenum,minor,410 +42,Mo,Molybdenum,minor,392 +42,Mo,Molybdenum,minor,62 +43,Tc,Technetium,major,21044 +43,Tc,Technetium,major,2793 +43,Tc,Technetium,major,2677 +43,Tc,Technetium,major,256 +43,Tc,Technetium,major,253 +43,Tc,Technetium,major,39 +43,Tc,Technetium,minor,3043 +43,Tc,Technetium,minor,544 +43,Tc,Technetium,minor,445 +43,Tc,Technetium,minor,425 +43,Tc,Technetium,minor,68 +44,Ru,Ruthenium,major,22117 +44,Ru,Ruthenium,major,2967 +44,Ru,Ruthenium,major,2838 +44,Ru,Ruthenium,major,284 +44,Ru,Ruthenium,major,279 +44,Ru,Ruthenium,major,43 +44,Ru,Ruthenium,minor,3224 +44,Ru,Ruthenium,minor,585 +44,Ru,Ruthenium,minor,483 +44,Ru,Ruthenium,minor,407 +44,Ru,Ruthenium,minor,75 +45,Rh,Rhodium,major,23220 +45,Rh,Rhodium,major,3146 +45,Rh,Rhodium,major,3004 +45,Rh,Rhodium,major,312 +45,Rh,Rhodium,major,307 +45,Rh,Rhodium,major,48 +45,Rh,Rhodium,minor,3412 +45,Rh,Rhodium,minor,627 +45,Rh,Rhodium,minor,521 +45,Rh,Rhodium,minor,496 +45,Rh,Rhodium,minor,81 +46,Pd,Palladium,major,24350 +46,Pd,Palladium,major,3330 +46,Pd,Palladium,major,3173 +46,Pd,Palladium,major,340 +46,Pd,Palladium,major,335 +46,Pd,Palladium,major,51 +46,Pd,Palladium,minor,3604 +46,Pd,Palladium,minor,670 +46,Pd,Palladium,minor,559 +46,Pd,Palladium,minor,532 +46,Pd,Palladium,minor,86 +47,Ag,Silver,major,25514 +47,Ag,Silver,major,3524 +47,Ag,Silver,major,3351 +47,Ag,Silver,major,373 +47,Ag,Silver,major,367 +47,Ag,Silver,minor,3806 +47,Ag,Silver,minor,718 +47,Ag,Silver,minor,602 +47,Ag,Silver,minor,571 +47,Ag,Silver,minor,95 +47,Ag,Silver,minor,63 +47,Ag,Silver,minor,56 +48,Cd,Cadmium,major,26711 +48,Cd,Cadmium,major,3727 +48,Cd,Cadmium,major,3538 +48,Cd,Cadmium,major,411 +48,Cd,Cadmium,major,404 +48,Cd,Cadmium,minor,4018 +48,Cd,Cadmium,minor,770 +48,Cd,Cadmium,minor,651 +48,Cd,Cadmium,minor,617 +48,Cd,Cadmium,minor,108 +48,Cd,Cadmium,minor,67 +49,In,Indium,major,27940 +49,In,Indium,major,3938 +49,In,Indium,major,3730 +49,In,Indium,major,451 +49,In,Indium,major,443 +49,In,Indium,minor,4238 +49,In,Indium,minor,826 +49,In,Indium,minor,702 +49,In,Indium,minor,664 +49,In,Indium,minor,122 +49,In,Indium,minor,77 +50,Sn,Tin,major,29200 +50,Sn,Tin,major,4156 +50,Sn,Tin,major,3929 +50,Sn,Tin,major,493 +50,Sn,Tin,major,485 +50,Sn,Tin,major,24 +50,Sn,Tin,minor,4465 +50,Sn,Tin,minor,884 +50,Sn,Tin,minor,756 +50,Sn,Tin,minor,714 +50,Sn,Tin,minor,137 +50,Sn,Tin,minor,89 +51,Sb,Antimony,major,30491 +51,Sb,Antimony,major,4380 +51,Sb,Antimony,major,4132 +51,Sb,Antimony,major,537 +51,Sb,Antimony,major,528 +51,Sb,Antimony,major,31 +51,Sb,Antimony,minor,4698 +51,Sb,Antimony,minor,944 +51,Sb,Antimony,minor,812 +51,Sb,Antimony,minor,766 +51,Sb,Antimony,minor,152 +51,Sb,Antimony,minor,98 +52,Te,Tellurium,major,31814 +52,Te,Tellurium,major,4341 +52,Te,Tellurium,major,583 +52,Te,Tellurium,major,572 +52,Te,Tellurium,major,40 +52,Te,Tellurium,minor,4939 +52,Te,Tellurium,minor,4612 +52,Te,Tellurium,minor,1006 +52,Te,Tellurium,minor,870 +52,Te,Tellurium,minor,819 +52,Te,Tellurium,minor,168 +52,Te,Tellurium,minor,110 +53,I,Iodine,major,33169 +53,I,Iodine,major,4557 +53,I,Iodine,major,631 +53,I,Iodine,major,619 +53,I,Iodine,major,50 +53,I,Iodine,minor,5188 +53,I,Iodine,minor,4852 +53,I,Iodine,minor,1072 +53,I,Iodine,minor,931 +53,I,Iodine,minor,875 +53,I,Iodine,minor,186 +53,I,Iodine,minor,123 +54,Xe,Xenon,major,34561 +54,Xe,Xenon,major,4782 +54,Xe,Xenon,major,684 +54,Xe,Xenon,major,672 +54,Xe,Xenon,major,63 +54,Xe,Xenon,minor,5453 +54,Xe,Xenon,minor,5104 +54,Xe,Xenon,minor,1143 +54,Xe,Xenon,minor,999 +54,Xe,Xenon,minor,937 +54,Xe,Xenon,minor,208 +54,Xe,Xenon,minor,147 +55,Cs,Cesium,major,5012 +55,Cs,Cesium,major,740 +55,Cs,Cesium,major,726 +55,Cs,Cesium,major,79 +55,Cs,Cesium,major,77 +55,Cs,Cesium,major,13 +55,Cs,Cesium,major,11 +55,Cs,Cesium,minor,5715 +55,Cs,Cesium,minor,5359 +55,Cs,Cesium,minor,1217 +55,Cs,Cesium,minor,1065 +55,Cs,Cesium,minor,998 +55,Cs,Cesium,minor,231 +55,Cs,Cesium,minor,172 +55,Cs,Cesium,minor,162 +55,Cs,Cesium,minor,23 +56,Ba,Barium,major,5247 +56,Ba,Barium,major,796 +56,Ba,Barium,major,781 +56,Ba,Barium,major,93 +56,Ba,Barium,major,90 +56,Ba,Barium,major,17 +56,Ba,Barium,major,15 +56,Ba,Barium,minor,5989 +56,Ba,Barium,minor,5624 +56,Ba,Barium,minor,1293 +56,Ba,Barium,minor,1137 +56,Ba,Barium,minor,1062 +56,Ba,Barium,minor,253 +56,Ba,Barium,minor,180 +56,Ba,Barium,minor,180 +56,Ba,Barium,minor,39 +57,La,Lanthanum,major,849 +57,La,Lanthanum,major,832 +57,La,Lanthanum,major,99 +57,La,Lanthanum,major,14 +57,La,Lanthanum,minor,6266 +57,La,Lanthanum,minor,5891 +57,La,Lanthanum,minor,1361 +57,La,Lanthanum,minor,1204 +57,La,Lanthanum,minor,1123 +57,La,Lanthanum,minor,270 +57,La,Lanthanum,minor,206 +57,La,Lanthanum,minor,191 +57,La,Lanthanum,minor,32 +58,Ce,Cerium,major,5723 +58,Ce,Cerium,major,901 +58,Ce,Cerium,major,883 +58,Ce,Cerium,major,110 +58,Ce,Cerium,major,20 +58,Ce,Cerium,minor,6549 +58,Ce,Cerium,minor,6164 +58,Ce,Cerium,minor,1435 +58,Ce,Cerium,minor,1273 +58,Ce,Cerium,minor,1185 +58,Ce,Cerium,minor,290 +58,Ce,Cerium,minor,233 +58,Ce,Cerium,minor,207 +58,Ce,Cerium,minor,38 +59,Pr,Praseodymium,major,5964 +59,Pr,Praseodymium,major,951 +59,Pr,Praseodymium,major,931 +59,Pr,Praseodymium,major,113 +59,Pr,Praseodymium,major,22 +59,Pr,Praseodymium,minor,6835 +59,Pr,Praseodymium,minor,6440 +59,Pr,Praseodymium,minor,1511 +59,Pr,Praseodymium,minor,1337 +59,Pr,Praseodymium,minor,1242 +59,Pr,Praseodymium,minor,305 +59,Pr,Praseodymium,minor,236 +59,Pr,Praseodymium,minor,218 +59,Pr,Praseodymium,minor,37 +60,Nd,Neodymium,major,5964 +60,Nd,Neodymium,major,1000 +60,Nd,Neodymium,major,978 +60,Nd,Neodymium,major,118 +60,Nd,Neodymium,major,21 +60,Nd,Neodymium,minor,6835 +60,Nd,Neodymium,minor,5964 +60,Nd,Neodymium,minor,1575 +60,Nd,Neodymium,minor,1403 +60,Nd,Neodymium,minor,1297 +60,Nd,Neodymium,minor,315 +60,Nd,Neodymium,minor,225 +60,Nd,Neodymium,minor,225 +60,Nd,Neodymium,minor,38 +61,Pm,Promethium,major,6459 +61,Pm,Promethium,major,1052 +61,Pm,Promethium,major,1027 +61,Pm,Promethium,major,120 +61,Pm,Promethium,major,121 +61,Pm,Promethium,major,24 +61,Pm,Promethium,minor,7428 +61,Pm,Promethium,minor,7013 +61,Pm,Promethium,minor,1646 +61,Pm,Promethium,minor,1471 +61,Pm,Promethium,minor,1357 +61,Pm,Promethium,minor,330 +61,Pm,Promethium,minor,242 +62,Sm,Samarium,major,6716 +62,Sm,Samarium,major,1106 +62,Sm,Samarium,major,1080 +62,Sm,Samarium,major,129 +62,Sm,Samarium,major,21 +62,Sm,Samarium,minor,7737 +62,Sm,Samarium,minor,7312 +62,Sm,Samarium,minor,1723 +62,Sm,Samarium,minor,1541 +62,Sm,Samarium,minor,1420 +62,Sm,Samarium,minor,346 +62,Sm,Samarium,minor,266 +62,Sm,Samarium,minor,247 +62,Sm,Samarium,minor,37 +63,Eu,Europium,major,6977 +63,Eu,Europium,major,1161 +63,Eu,Europium,major,1131 +63,Eu,Europium,major,133 +63,Eu,Europium,major,22 +63,Eu,Europium,minor,8052 +63,Eu,Europium,minor,7617 +63,Eu,Europium,minor,1800 +63,Eu,Europium,minor,1614 +63,Eu,Europium,minor,1481 +63,Eu,Europium,minor,360 +63,Eu,Europium,minor,284 +63,Eu,Europium,minor,257 +63,Eu,Europium,minor,32 +64,Gd,Gadolinium,major,7243 +64,Gd,Gadolinium,major,1217 +64,Gd,Gadolinium,major,1185 +64,Gd,Gadolinium,major,141 +64,Gd,Gadolinium,major,20 +64,Gd,Gadolinium,minor,8376 +64,Gd,Gadolinium,minor,7930 +64,Gd,Gadolinium,minor,1881 +64,Gd,Gadolinium,minor,1688 +64,Gd,Gadolinium,minor,1544 +64,Gd,Gadolinium,minor,376 +64,Gd,Gadolinium,minor,289 +64,Gd,Gadolinium,minor,271 +64,Gd,Gadolinium,minor,36 +65,Tb,Terbium,major,7514 +65,Tb,Terbium,major,1275 +65,Tb,Terbium,major,1241 +65,Tb,Terbium,major,147 +65,Tb,Terbium,major,25 +65,Tb,Terbium,minor,8708 +65,Tb,Terbium,minor,8252 +65,Tb,Terbium,minor,1968 +65,Tb,Terbium,minor,1768 +65,Tb,Terbium,minor,1611 +65,Tb,Terbium,minor,398 +65,Tb,Terbium,minor,310 +65,Tb,Terbium,minor,285 +65,Tb,Terbium,minor,39 +66,Dy,Dysprosium,major,1333 +66,Dy,Dysprosium,major,1295 +66,Dy,Dysprosium,major,154 +66,Dy,Dysprosium,major,26 +66,Dy,Dysprosium,major,7790 +66,Dy,Dysprosium,minor,9046 +66,Dy,Dysprosium,minor,8581 +66,Dy,Dysprosium,minor,2047 +66,Dy,Dysprosium,minor,1842 +66,Dy,Dysprosium,minor,1676 +66,Dy,Dysprosium,minor,416 +66,Dy,Dysprosium,minor,332 +66,Dy,Dysprosium,minor,293 +66,Dy,Dysprosium,minor,63 +67,Ho,Holmium,major,8071 +67,Ho,Holmium,major,1392 +67,Ho,Holmium,major,1351 +67,Ho,Holmium,major,161 +67,Ho,Holmium,major,20 +67,Ho,Holmium,minor,9394 +67,Ho,Holmium,minor,8918 +67,Ho,Holmium,minor,2128 +67,Ho,Holmium,minor,1923 +67,Ho,Holmium,minor,1741 +67,Ho,Holmium,minor,436 +67,Ho,Holmium,minor,344 +67,Ho,Holmium,minor,307 +67,Ho,Holmium,minor,51 +68,Er,Erbium,major,8358 +68,Er,Erbium,major,1453 +68,Er,Erbium,major,1409 +68,Er,Erbium,major,177 +68,Er,Erbium,major,168 +68,Er,Erbium,major,29 +68,Er,Erbium,minor,9751 +68,Er,Erbium,minor,9264 +68,Er,Erbium,minor,2207 +68,Er,Erbium,minor,2006 +68,Er,Erbium,minor,1812 +68,Er,Erbium,minor,449 +68,Er,Erbium,minor,366 +68,Er,Erbium,minor,320 +68,Er,Erbium,minor,60 +69,Tm,Thulium,major,8648 +69,Tm,Thulium,major,1515 +69,Tm,Thulium,major,1468 +69,Tm,Thulium,major,180 +69,Tm,Thulium,major,32 +69,Tm,Thulium,minor,10116 +69,Tm,Thulium,minor,9617 +69,Tm,Thulium,minor,2307 +69,Tm,Thulium,minor,2090 +69,Tm,Thulium,minor,1885 +69,Tm,Thulium,minor,472 +69,Tm,Thulium,minor,386 +69,Tm,Thulium,minor,337 +69,Tm,Thulium,minor,53 +70,Yb,Ytterbium,major,8944 +70,Yb,Ytterbium,major,1576 +70,Yb,Ytterbium,major,1528 +70,Yb,Ytterbium,major,198 +70,Yb,Ytterbium,major,185 +70,Yb,Ytterbium,major,23 +70,Yb,Ytterbium,minor,10486 +70,Yb,Ytterbium,minor,9978 +70,Yb,Ytterbium,minor,2398 +70,Yb,Ytterbium,minor,2173 +70,Yb,Ytterbium,minor,1950 +70,Yb,Ytterbium,minor,487 +70,Yb,Ytterbium,minor,397 +70,Yb,Ytterbium,minor,344 +70,Yb,Ytterbium,minor,54 +71,Lu,Lutetium,major,9244 +71,Lu,Lutetium,major,1639 +71,Lu,Lutetium,major,1589 +71,Lu,Lutetium,major,195 +71,Lu,Lutetium,major,195 +71,Lu,Lutetium,major,28 +71,Lu,Lutetium,minor,10870 +71,Lu,Lutetium,minor,10349 +71,Lu,Lutetium,minor,2491 +71,Lu,Lutetium,minor,2264 +71,Lu,Lutetium,minor,2024 +71,Lu,Lutetium,minor,506 +71,Lu,Lutetium,minor,410 +71,Lu,Lutetium,minor,359 +71,Lu,Lutetium,minor,57 +72,Hf,Hafnium,major,9561 +72,Hf,Hafnium,major,1716 +72,Hf,Hafnium,major,1662 +72,Hf,Hafnium,major,38 +72,Hf,Hafnium,major,31 +72,Hf,Hafnium,minor,11271 +72,Hf,Hafnium,minor,10739 +72,Hf,Hafnium,minor,2601 +72,Hf,Hafnium,minor,2365 +72,Hf,Hafnium,minor,2108 +72,Hf,Hafnium,minor,538 +72,Hf,Hafnium,minor,437 +72,Hf,Hafnium,minor,380 +72,Hf,Hafnium,minor,224 +72,Hf,Hafnium,minor,214 +72,Hf,Hafnium,minor,65 +73,Ta,Tantalum,major,9881 +73,Ta,Tantalum,major,1793 +73,Ta,Tantalum,major,1735 +73,Ta,Tantalum,major,45 +73,Ta,Tantalum,major,36 +73,Ta,Tantalum,minor,11682 +73,Ta,Tantalum,minor,11136 +73,Ta,Tantalum,minor,2708 +73,Ta,Tantalum,minor,2469 +73,Ta,Tantalum,minor,2194 +73,Ta,Tantalum,minor,566 +73,Ta,Tantalum,minor,465 +73,Ta,Tantalum,minor,405 +73,Ta,Tantalum,minor,241 +73,Ta,Tantalum,minor,229 +73,Ta,Tantalum,minor,71 +74,W,Tungsten,major,10207 +74,W,Tungsten,major,1872 +74,W,Tungsten,major,1809 +74,W,Tungsten,major,47 +74,W,Tungsten,major,36 +74,W,Tungsten,minor,12100 +74,W,Tungsten,minor,11544 +74,W,Tungsten,minor,2820 +74,W,Tungsten,minor,2575 +74,W,Tungsten,minor,2281 +74,W,Tungsten,minor,595 +74,W,Tungsten,minor,492 +74,W,Tungsten,minor,425 +74,W,Tungsten,minor,259 +74,W,Tungsten,minor,245 +74,W,Tungsten,minor,37 +74,W,Tungsten,minor,34 +74,W,Tungsten,minor,77 +75,Re,Rhenium,major,10535 +75,Re,Rhenium,major,1949 +75,Re,Rhenium,major,1883 +75,Re,Rhenium,major,46 +75,Re,Rhenium,major,35 +75,Re,Rhenium,minor,12527 +75,Re,Rhenium,minor,11959 +75,Re,Rhenium,minor,2932 +75,Re,Rhenium,minor,2682 +75,Re,Rhenium,minor,2367 +75,Re,Rhenium,minor,625 +75,Re,Rhenium,minor,518 +75,Re,Rhenium,minor,444 +75,Re,Rhenium,minor,274 +75,Re,Rhenium,minor,260 +75,Re,Rhenium,minor,41 +75,Re,Rhenium,minor,83 +76,Os,Osmium,major,10871 +76,Os,Osmium,major,2031 +76,Os,Osmium,major,1960 +76,Os,Osmium,major,58 +76,Os,Osmium,major,45 +76,Os,Osmium,minor,12968 +76,Os,Osmium,minor,12385 +76,Os,Osmium,minor,3049 +76,Os,Osmium,minor,2792 +76,Os,Osmium,minor,2457 +76,Os,Osmium,minor,654 +76,Os,Osmium,minor,547 +76,Os,Osmium,minor,468 +76,Os,Osmium,minor,289 +76,Os,Osmium,minor,273 +76,Os,Osmium,minor,46 +76,Os,Osmium,minor,84 +77,Ir,Iridium,major,11215 +77,Ir,Iridium,major,2116 +77,Ir,Iridium,major,2040 +77,Ir,Iridium,major,63 +77,Ir,Iridium,major,51 +77,Ir,Iridium,minor,13419 +77,Ir,Iridium,minor,12824 +77,Ir,Iridium,minor,3174 +77,Ir,Iridium,minor,2909 +77,Ir,Iridium,minor,2551 +77,Ir,Iridium,minor,690 +77,Ir,Iridium,minor,577 +77,Ir,Iridium,minor,494 +77,Ir,Iridium,minor,311 +77,Ir,Iridium,minor,295 +77,Ir,Iridium,minor,63 +77,Ir,Iridium,minor,61 +77,Ir,Iridium,minor,95 +78,Pt,Platinum,major,11564 +78,Pt,Platinum,major,2202 +78,Pt,Platinum,major,2122 +78,Pt,Platinum,minor,13880 +78,Pt,Platinum,minor,13273 +78,Pt,Platinum,minor,3296 +78,Pt,Platinum,minor,3027 +78,Pt,Platinum,minor,2645 +78,Pt,Platinum,minor,722 +78,Pt,Platinum,minor,609 +78,Pt,Platinum,minor,519 +78,Pt,Platinum,minor,331 +78,Pt,Platinum,minor,313 +78,Pt,Platinum,minor,74 +78,Pt,Platinum,minor,71 +78,Pt,Platinum,minor,102 +78,Pt,Platinum,minor,65 +78,Pt,Platinum,minor,52 +79,Au,Gold,major,11919 +79,Au,Gold,major,2291 +79,Au,Gold,major,2206 +79,Au,Gold,minor,14353 +79,Au,Gold,minor,13734 +79,Au,Gold,minor,3425 +79,Au,Gold,minor,3148 +79,Au,Gold,minor,2743 +79,Au,Gold,minor,759 +79,Au,Gold,minor,644 +79,Au,Gold,minor,545 +79,Au,Gold,minor,352 +79,Au,Gold,minor,334 +79,Au,Gold,minor,86 +79,Au,Gold,minor,83 +79,Au,Gold,minor,108 +79,Au,Gold,minor,72 +79,Au,Gold,minor,54 +80,Hg,Mercury,major,12284 +80,Hg,Mercury,major,2385 +80,Hg,Mercury,major,2295 +80,Hg,Mercury,minor,14839 +80,Hg,Mercury,minor,14209 +80,Hg,Mercury,minor,3562 +80,Hg,Mercury,minor,3279 +80,Hg,Mercury,minor,2847 +80,Hg,Mercury,minor,800 +80,Hg,Mercury,minor,677 +80,Hg,Mercury,minor,571 +80,Hg,Mercury,minor,378 +80,Hg,Mercury,minor,360 +80,Hg,Mercury,minor,102 +80,Hg,Mercury,minor,99 +80,Hg,Mercury,minor,120 +80,Hg,Mercury,minor,81 +80,Hg,Mercury,minor,58 +81,Tl,Thallium,major,2485 +81,Tl,Thallium,major,2389 +81,Tl,Thallium,major,15 +81,Tl,Thallium,major,13 +81,Tl,Thallium,minor,15347 +81,Tl,Thallium,minor,14698 +81,Tl,Thallium,minor,12658 +81,Tl,Thallium,minor,3704 +81,Tl,Thallium,minor,3416 +81,Tl,Thallium,minor,2957 +81,Tl,Thallium,minor,846 +81,Tl,Thallium,minor,721 +81,Tl,Thallium,minor,609 +81,Tl,Thallium,minor,407 +81,Tl,Thallium,minor,386 +81,Tl,Thallium,minor,123 +81,Tl,Thallium,minor,119 +81,Tl,Thallium,minor,136 +81,Tl,Thallium,minor,100 +81,Tl,Thallium,minor,75 +82,Pb,Lead,major,13035 +82,Pb,Lead,major,2586 +82,Pb,Lead,major,2484 +82,Pb,Lead,major,22 +82,Pb,Lead,major,19 +82,Pb,Lead,minor,15861 +82,Pb,Lead,minor,15200 +82,Pb,Lead,minor,3851 +82,Pb,Lead,minor,3554 +82,Pb,Lead,minor,3066 +82,Pb,Lead,minor,894 +82,Pb,Lead,minor,764 +82,Pb,Lead,minor,645 +82,Pb,Lead,minor,435 +82,Pb,Lead,minor,413 +82,Pb,Lead,minor,143 +82,Pb,Lead,minor,138 +82,Pb,Lead,minor,147 +82,Pb,Lead,minor,105 +82,Pb,Lead,minor,86 +83,Bi,Bismuth,major,13419 +83,Bi,Bismuth,major,2688 +83,Bi,Bismuth,major,2580 +83,Bi,Bismuth,major,27 +83,Bi,Bismuth,major,24 +83,Bi,Bismuth,minor,16388 +83,Bi,Bismuth,minor,15711 +83,Bi,Bismuth,minor,3999 +83,Bi,Bismuth,minor,3696 +83,Bi,Bismuth,minor,3177 +83,Bi,Bismuth,minor,938 +83,Bi,Bismuth,minor,805 +83,Bi,Bismuth,minor,679 +83,Bi,Bismuth,minor,464 +83,Bi,Bismuth,minor,440 +83,Bi,Bismuth,minor,162 +83,Bi,Bismuth,minor,157 +83,Bi,Bismuth,minor,159 +83,Bi,Bismuth,minor,117 +83,Bi,Bismuth,minor,93 +84,Po,Polonium,major,13814 +84,Po,Polonium,major,2798 +84,Po,Polonium,major,2683 +84,Po,Polonium,major,31 +84,Po,Polonium,minor,16939 +84,Po,Polonium,minor,16244 +84,Po,Polonium,minor,4149 +84,Po,Polonium,minor,3854 +84,Po,Polonium,minor,3302 +84,Po,Polonium,minor,995 +84,Po,Polonium,minor,851 +84,Po,Polonium,minor,705 +84,Po,Polonium,minor,500 +84,Po,Polonium,minor,473 +85,At,Astatine,major,14214 +85,At,Astatine,major,2908 +85,At,Astatine,major,2787 +85,At,Astatine,minor,17493 +85,At,Astatine,minor,16785 +85,At,Astatine,minor,4317 +85,At,Astatine,minor,4008 +85,At,Astatine,minor,3426 +85,At,Astatine,minor,1042 +85,At,Astatine,minor,886 +85,At,Astatine,minor,740 +85,At,Astatine,minor,533 +86,Rn,Radon,major,14619 +86,Rn,Radon,major,3022 +86,Rn,Radon,major,2892 +86,Rn,Radon,minor,18049 +86,Rn,Radon,minor,17337 +86,Rn,Radon,minor,4482 +86,Rn,Radon,minor,4159 +86,Rn,Radon,minor,3538 +86,Rn,Radon,minor,1097 +86,Rn,Radon,minor,929 +86,Rn,Radon,minor,768 +86,Rn,Radon,minor,567 +87,Fr,Francium,major,15031 +87,Fr,Francium,major,3136 +87,Fr,Francium,major,3000 +87,Fr,Francium,minor,18639 +87,Fr,Francium,minor,17907 +87,Fr,Francium,minor,4652 +87,Fr,Francium,minor,4327 +87,Fr,Francium,minor,3663 +87,Fr,Francium,minor,1153 +87,Fr,Francium,minor,980 +87,Fr,Francium,minor,810 +87,Fr,Francium,minor,603 +87,Fr,Francium,minor,577 +88,Ra,Radium,major,15444 +88,Ra,Radium,major,3248 +88,Ra,Radium,major,3105 +88,Ra,Radium,major,299 +88,Ra,Radium,major,67 +88,Ra,Radium,minor,19237 +88,Ra,Radium,minor,18484 +88,Ra,Radium,minor,4822 +88,Ra,Radium,minor,4490 +88,Ra,Radium,minor,3792 +88,Ra,Radium,minor,1208 +88,Ra,Radium,minor,1058 +88,Ra,Radium,minor,879 +88,Ra,Radium,minor,636 +88,Ra,Radium,minor,603 +88,Ra,Radium,minor,254 +88,Ra,Radium,minor,153 +89,Ac,Actinium,major,15871 +89,Ac,Actinium,major,3370 +89,Ac,Actinium,major,3219 +89,Ac,Actinium,minor,19840 +89,Ac,Actinium,minor,19083 +89,Ac,Actinium,minor,5002 +89,Ac,Actinium,minor,4656 +89,Ac,Actinium,minor,3909 +89,Ac,Actinium,minor,1269 +89,Ac,Actinium,minor,1080 +89,Ac,Actinium,minor,890 +89,Ac,Actinium,minor,675 +90,Th,Thorium,major,16300 +90,Th,Thorium,major,3491 +90,Th,Thorium,major,3332 +90,Th,Thorium,major,344 +90,Th,Thorium,major,335 +90,Th,Thorium,major,94 +90,Th,Thorium,major,88 +90,Th,Thorium,minor,20472 +90,Th,Thorium,minor,19693 +90,Th,Thorium,minor,5182 +90,Th,Thorium,minor,4830 +90,Th,Thorium,minor,4046 +90,Th,Thorium,minor,1330 +90,Th,Thorium,minor,1168 +90,Th,Thorium,minor,967 +90,Th,Thorium,minor,714 +90,Th,Thorium,minor,676 +90,Th,Thorium,minor,290 +90,Th,Thorium,minor,229 +90,Th,Thorium,minor,182 +91,Pa,Protactinium,major,16733 +91,Pa,Protactinium,major,3611 +91,Pa,Protactinium,major,3442 +91,Pa,Protactinium,major,371 +91,Pa,Protactinium,major,360 +91,Pa,Protactinium,major,94 +91,Pa,Protactinium,minor,21105 +91,Pa,Protactinium,minor,20314 +91,Pa,Protactinium,minor,5367 +91,Pa,Protactinium,minor,5001 +91,Pa,Protactinium,minor,4174 +91,Pa,Protactinium,minor,1387 +91,Pa,Protactinium,minor,1224 +91,Pa,Protactinium,minor,1007 +91,Pa,Protactinium,minor,743 +91,Pa,Protactinium,minor,708 +91,Pa,Protactinium,minor,310 +91,Pa,Protactinium,minor,223 +92,U,Uranium,major,17166 +92,U,Uranium,major,3728 +92,U,Uranium,major,3552 +92,U,Uranium,major,391 +92,U,Uranium,major,381 +92,U,Uranium,major,105 +92,U,Uranium,major,96 +92,U,Uranium,minor,21757 +92,U,Uranium,minor,20948 +92,U,Uranium,minor,5548 +92,U,Uranium,minor,5182 +92,U,Uranium,minor,4303 +92,U,Uranium,minor,1441 +92,U,Uranium,minor,1273 +92,U,Uranium,minor,1045 +92,U,Uranium,minor,780 +92,U,Uranium,minor,738 +92,U,Uranium,minor,324 +92,U,Uranium,minor,259 +92,U,Uranium,minor,195 +93,Np,Neptunium,major,17610 +93,Np,Neptunium,major,3850 +93,Np,Neptunium,major,3666 +93,Np,Neptunium,major,415 +93,Np,Neptunium,major,404 +93,Np,Neptunium,major,109 +93,Np,Neptunium,major,101 +93,Np,Neptunium,minor,22427 +93,Np,Neptunium,minor,21601 +93,Np,Neptunium,minor,5723 +93,Np,Neptunium,minor,5366 +93,Np,Neptunium,minor,4435 +93,Np,Neptunium,minor,1501 +93,Np,Neptunium,minor,1328 +93,Np,Neptunium,minor,1087 +93,Np,Neptunium,minor,816 +93,Np,Neptunium,minor,770 +93,Np,Neptunium,minor,283 +93,Np,Neptunium,minor,206 +94,Pu,Plutonium,major,18057 +94,Pu,Plutonium,major,3973 +94,Pu,Plutonium,major,3778 +94,Pu,Plutonium,major,446 +94,Pu,Plutonium,major,432 +94,Pu,Plutonium,major,116 +94,Pu,Plutonium,major,105 +94,Pu,Plutonium,minor,23097 +94,Pu,Plutonium,minor,22266 +94,Pu,Plutonium,minor,5933 +94,Pu,Plutonium,minor,5541 +94,Pu,Plutonium,minor,4557 +94,Pu,Plutonium,minor,1559 +94,Pu,Plutonium,minor,1372 +94,Pu,Plutonium,minor,1115 +94,Pu,Plutonium,minor,849 +94,Pu,Plutonium,minor,801 +94,Pu,Plutonium,minor,352 +94,Pu,Plutonium,minor,274 +94,Pu,Plutonium,minor,207 +95,Am,Americium,major,18504 +95,Am,Americium,major,4092 +95,Am,Americium,major,3887 +95,Am,Americium,major,116 +95,Am,Americium,major,103 +95,Am,Americium,minor,23773 +95,Am,Americium,minor,22944 +95,Am,Americium,minor,6121 +95,Am,Americium,minor,5710 +95,Am,Americium,minor,4667 +95,Am,Americium,minor,1617 +95,Am,Americium,minor,1412 +95,Am,Americium,minor,1136 +95,Am,Americium,minor,879 +95,Am,Americium,minor,828 +96,Cm,Curium,, +97,Bk,Berkelium,, +98,Cf,Californium,, +99,Es,Einsteinium,, +100,Fm,Fermium,, +101,Md,Mendelevium,, +102,No,Nobelium,, +103,Lr,Lawrencium,, +104,Rf,Rutherfordium,, +105,Db,Dubnium,, +106,Sg,Seaborgium,, +107,Bh,Bohrium,, +108,Hs,Hassium,, +109,Mt,Meitnerium,, +110,Ds,Darmstadtium,, +111,Rg,Roentgenium,, +112,Cn,Copernicium,, +113,Uut,Ununtrium,, +114,Fl,Flerovium,, +115,Uup,Ununpentium,, +116,Lv,Livermorium,, +117,Uus,Ununseptium,, +118,Uuo,Ununoctium,, \ No newline at end of file diff --git a/src/quantem/spectroscopy/spectroscopy_models.py b/src/quantem/spectroscopy/spectroscopy_models.py new file mode 100644 index 00000000..076b5372 --- /dev/null +++ b/src/quantem/spectroscopy/spectroscopy_models.py @@ -0,0 +1,377 @@ +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn + +from quantem.spectroscopy.utils import load_xray_lines_database + + +def inverse_softplus(x: torch.Tensor, min_value: float = 1e-8) -> torch.Tensor: + """Numerically stable inverse of softplus for positive initialization values.""" + x = torch.clamp(x, min=min_value) + # For large x, log(expm1(x)) can overflow in float32. Use a stable branch. + return torch.where( + x > 20.0, + x + torch.log1p(-torch.exp(-x)), + torch.log(torch.expm1(x)), + ) + + +def eds_data_loss( + predicted: torch.Tensor, target: torch.Tensor, loss: str = "poisson", min_value: float = 1e-8 +) -> torch.Tensor: + """Compute EDS fit loss with clamped positive predictions.""" + pred_safe = torch.nan_to_num(predicted, nan=min_value, posinf=1e8, neginf=min_value) + pred_safe = torch.clamp(pred_safe, min=min_value, max=1e8) + if loss == "poisson": + target_safe = torch.nan_to_num(target, nan=0.0, posinf=1e8, neginf=0.0) + target_safe = torch.clamp(target_safe, min=0.0, max=1e8) + if hasattr(torch, "xlogy"): + log_term = torch.xlogy(target_safe, pred_safe) + elif hasattr(torch.special, "xlogy"): + log_term = torch.special.xlogy(target_safe, pred_safe) + else: + log_term = target_safe * torch.log(pred_safe) + log_term = torch.nan_to_num(log_term, nan=0.0, posinf=1e8, neginf=-1e8) + loss_terms = pred_safe - log_term + return torch.mean(torch.nan_to_num(loss_terms, nan=1e8, posinf=1e8, neginf=-1e8)) + if loss == "mse": + target_safe = torch.nan_to_num(target, nan=0.0, posinf=1e8, neginf=-1e8) + return nn.functional.mse_loss(pred_safe, target_safe) + raise ValueError("loss must be 'poisson' or 'mse'") + + +def polynomial_energy_basis(energy_axis: torch.Tensor, degree: int) -> torch.Tensor: + """Return polynomial basis in normalized energy coordinates.""" + energy_norm = (energy_axis - energy_axis.min()) / ( + energy_axis.max() - energy_axis.min() + 1e-12 + ) + return torch.stack([energy_norm**i for i in range(degree + 1)], dim=0) + + +def build_element_basis( + energy_axis: torch.Tensor, + peak_energies: torch.Tensor, + peak_weights: torch.Tensor, + peak_element_indices: torch.Tensor, + peak_width_by_peak: torch.Tensor, + n_elements: int, + energy_step: float, +) -> torch.Tensor: + """Build matrix mapping per-element concentrations to spectral intensity.""" + fwhm = nn.functional.softplus(peak_width_by_peak) + sigma = (fwhm / 2.355).unsqueeze(1) + centers = peak_energies.unsqueeze(1) + energies = energy_axis.unsqueeze(0) + all_peaks = torch.exp(-0.5 * ((energies - centers) / sigma) ** 2) + sqrt_2pi = torch.sqrt(torch.tensor(2 * np.pi, dtype=all_peaks.dtype, device=all_peaks.device)) + all_peaks = all_peaks * energy_step / (sqrt_2pi * sigma) + weighted_peaks = all_peaks * peak_weights.unsqueeze(1) + + basis = torch.zeros( + (n_elements, energy_axis.shape[0]), + dtype=weighted_peaks.dtype, + device=weighted_peaks.device, + ) + basis.index_add_(0, peak_element_indices.to(weighted_peaks.device), weighted_peaks) + return basis.t() + + +def abundance_smoothness_l2(abundance_maps: torch.Tensor) -> torch.Tensor: + """Spatial L2 smoothness for abundance maps shaped (n_elements, y, x).""" + if abundance_maps.ndim != 3: + raise ValueError("abundance_maps must have shape (n_elements, y, x)") + + loss = abundance_maps.new_tensor(0.0) + if abundance_maps.shape[2] > 1: + dx = abundance_maps[:, :, 1:] - abundance_maps[:, :, :-1] + loss = loss + dx.pow(2).mean() + if abundance_maps.shape[1] > 1: + dy = abundance_maps[:, 1:, :] - abundance_maps[:, :-1, :] + loss = loss + dy.pow(2).mean() + return loss + + +class EDSModel(nn.Module): + """EDS spectrum model = peaks + optional background.""" + + def __init__(self, peak_model, background_model=None): + super().__init__() + self.peak_model = peak_model + self.background_model = background_model + + def forward(self): + spectrum = self.peak_model() + if self.background_model is not None: + spectrum = spectrum + self.background_model() + return spectrum + + +class GaussianPeaks(nn.Module): + """Generate Gaussian peak spectra from X-ray line data.""" + + def __init__( + self, + energy_axis, + peak_width, + elements_to_fit=None, + ): + super().__init__() + + current_dir = Path(__file__).parent + data = load_xray_lines_database(current_dir / "x_ray_lines.csv") + + energy_axis_tensor = ( + energy_axis.float() + if torch.is_tensor(energy_axis) + else torch.tensor(energy_axis, dtype=torch.float32) + ) + self.register_buffer("energy_axis", energy_axis_tensor) + self.energy_min = self.energy_axis.min().item() + self.energy_max = self.energy_axis.max().item() + self.energy_step = (self.energy_axis[1] - self.energy_axis[0]).item() + + all_element_data = {} + for elem, lines in data.items(): + if len(lines) > 0: + energies = [] + weights = [] + line_names = [] + + for line_name, line_data in lines.items(): + energy = line_data["energy (keV)"] + if self.energy_min - 0.5 <= energy <= self.energy_max + 0.5: + energies.append(energy) + weights.append(line_data["weight"]) + line_names.append(line_name) + + if len(energies) > 0: + all_element_data[elem] = { + "energies": energies, + "weights": weights, + "line_names": line_names, + } + + if elements_to_fit is not None: + self.element_data = {} + for elem in elements_to_fit: + if elem in all_element_data: + self.element_data[elem] = all_element_data[elem] + else: + self.element_data = all_element_data + + self.element_names = list(self.element_data.keys()) + n_elements = len(self.element_names) + + all_peak_energies = [] + all_peak_weights = [] + all_peak_element_indices = [] + all_peak_shell_group_indices = [] + shell_group_lookup = {} + shell_group_names = [] + shell_group_element_indices = [] + shell_group_shell_labels = [] + + for elem_idx, elem in enumerate(self.element_names): + energies = self.element_data[elem]["energies"] + weights = np.asarray(self.element_data[elem]["weights"], dtype=np.float32) + line_names = self.element_data[elem]["line_names"] + + shell_labels = [self._line_shell_label(line_name) for line_name in line_names] + normalized_weights = np.zeros_like(weights) + for shell_label in set(shell_labels): + shell_indices = [i for i, label in enumerate(shell_labels) if label == shell_label] + shell_weights = np.clip(weights[shell_indices], a_min=0.0, a_max=None) + shell_sum = float(np.sum(shell_weights)) + if shell_sum <= 0.0: + shell_weights = np.ones(len(shell_indices), dtype=np.float32) / float( + len(shell_indices) + ) + else: + shell_weights = shell_weights / shell_sum + normalized_weights[shell_indices] = shell_weights + weights_to_use = normalized_weights + + all_peak_energies.extend(energies) + all_peak_weights.extend(float(weight) for weight in weights_to_use) + all_peak_element_indices.extend([elem_idx] * len(energies)) + for shell_label in shell_labels: + key = (elem_idx, shell_label) + if key not in shell_group_lookup: + shell_group_lookup[key] = len(shell_group_names) + shell_group_names.append(f"{elem} {shell_label}") + shell_group_element_indices.append(elem_idx) + shell_group_shell_labels.append(shell_label) + all_peak_shell_group_indices.append(shell_group_lookup[key]) + + self.register_buffer( + "peak_energies", + torch.tensor( + all_peak_energies, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ), + ) + self.register_buffer( + "peak_weights", + torch.tensor( + all_peak_weights, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ), + ) + self.register_buffer( + "peak_element_indices", + torch.tensor( + all_peak_element_indices, + dtype=torch.long, + device=self.energy_axis.device, + ), + ) + self.register_buffer( + "peak_shell_group_indices", + torch.tensor( + all_peak_shell_group_indices, + dtype=torch.long, + device=self.energy_axis.device, + ), + ) + self.register_buffer( + "shell_group_element_indices", + torch.tensor( + shell_group_element_indices, + dtype=torch.long, + device=self.energy_axis.device, + ), + ) + self.shell_group_names = shell_group_names + self.shell_group_shell_labels = shell_group_shell_labels + + self.n_peaks = len(all_peak_energies) + init_fwhm = torch.tensor( + peak_width, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ) + self.peak_width_by_peak = nn.Parameter( + inverse_softplus(init_fwhm) + * torch.ones( + self.n_peaks, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ) + ) + + n_shell_groups = len(shell_group_names) + print( + f"Fitting {n_elements} elements with {self.n_peaks} total peaks " + f"across {n_shell_groups} edge groups" + ) + + concentration_size = len(shell_group_names) + if concentration_size > 0: + init_concentration = torch.ones( + concentration_size, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ) + concentration_init_logits = inverse_softplus(init_concentration) + else: + concentration_init_logits = torch.ones( + concentration_size, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ) + + self.concentrations = nn.Parameter(concentration_init_logits) + + @staticmethod + def _line_shell_label(line_name: str) -> str: + text = str(line_name).strip() + for char in text: + if char.isalpha(): + return char.upper() + return "Other" + + def forward(self): + centers = self.peak_energies.unsqueeze(1) + energies = self.energy_axis.unsqueeze(0) + + fwhm = nn.functional.softplus(self.peak_width_by_peak) + sigma = (fwhm / 2.355).unsqueeze(1) + + all_peaks = torch.exp(-0.5 * ((energies - centers) / sigma) ** 2) + + sqrt_2pi = torch.sqrt( + torch.tensor( + 2 * np.pi, + dtype=all_peaks.dtype, + device=all_peaks.device, + ) + ) + all_peaks = all_peaks * self.energy_step / (sqrt_2pi * sigma) + + concentration_lookup = self.peak_shell_group_indices + peak_concentrations = nn.functional.softplus(self.concentrations[concentration_lookup]) + weighted_peaks = all_peaks * (peak_concentrations * self.peak_weights).unsqueeze(1) + + spectrum = weighted_peaks.sum(dim=0) + + return spectrum + + +class PolynomialBackground(nn.Module): + """Polynomial background model""" + + def __init__(self, energy_axis, degree=3): + super().__init__() + energy_axis_tensor = ( + energy_axis.float() + if torch.is_tensor(energy_axis) + else torch.tensor(energy_axis, dtype=torch.float32) + ) + self.register_buffer("energy_axis", energy_axis_tensor) + self.degree = degree + + energy_norm = (self.energy_axis - self.energy_axis.min()) / ( + self.energy_axis.max() - self.energy_axis.min() + ) + self.register_buffer("energy_norm", energy_norm) + + self.coeffs = nn.Parameter( + torch.randn( + degree + 1, + dtype=self.energy_axis.dtype, + device=self.energy_axis.device, + ) + * 0.1 + ) + + def forward(self): + background = torch.zeros_like(self.energy_axis) + for i, coeff in enumerate(self.coeffs): + background += coeff * (self.energy_norm**i) + return background + + +class ExponentialBackground(nn.Module): + """Exponential background for bremsstrahlung""" + + def __init__(self, energy_axis): + super().__init__() + energy_axis_tensor = ( + energy_axis.float() + if torch.is_tensor(energy_axis) + else torch.tensor(energy_axis, dtype=torch.float32) + ) + self.register_buffer("energy_axis", energy_axis_tensor) + dtype = self.energy_axis.dtype + device = self.energy_axis.device + + self.amplitude = nn.Parameter(torch.tensor(1.0, dtype=dtype, device=device)) + self.decay = nn.Parameter(torch.tensor(0.5, dtype=dtype, device=device)) + self.offset = nn.Parameter(torch.tensor(0.1, dtype=dtype, device=device)) + + def forward(self): + return self.amplitude * torch.exp(-self.decay * self.energy_axis) + self.offset diff --git a/src/quantem/spectroscopy/utils.py b/src/quantem/spectroscopy/utils.py new file mode 100644 index 00000000..3ca1712f --- /dev/null +++ b/src/quantem/spectroscopy/utils.py @@ -0,0 +1,110 @@ +import csv +from pathlib import Path +from typing import Optional, Union + + +def _parse_float(row: dict[str, str], keys: tuple[str, ...]) -> Optional[float]: + for key in keys: + value = row.get(key) + if value is None: + continue + text = str(value).strip() + if not text: + continue + try: + return float(text) + except ValueError: + continue + return None + + +def load_xray_lines_database(path: Union[Path, str]) -> dict[str, dict[str, dict[str, float]]]: + """Load X-ray lines CSV into the legacy element->line metadata mapping.""" + elements: dict[str, dict[str, dict[str, float]]] = {} + duplicate_counts: dict[tuple[str, str], int] = {} + + with open(path, "r", encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + for row in reader: + element = str(row.get("element", "")).strip() + line_name = str(row.get("line", "")).strip() + if not element or not line_name: + continue + + energy_kev = _parse_float(row, ("energy_keV", "energy (keV)", "energy")) + if energy_kev is None: + energy_ev = _parse_float(row, ("energy_eV", "energy (eV)")) + if energy_ev is None: + continue + energy_kev = energy_ev / 1000.0 + + # Use the normalized CSV column as the X-ray line weight. + weight = _parse_float(row, ("col4_norm", "weight", "relative_intensity")) + if weight is None: + weight = 0.0 + + element_lines = elements.setdefault(element, {}) + key = (element, line_name) + if line_name in element_lines: + duplicate_counts[key] = duplicate_counts.get(key, 1) + 1 + line_name = f"{line_name}__{duplicate_counts[key]}" + + element_lines[line_name] = { + "energy (keV)": float(energy_kev), + "weight": float(weight), + } + + return elements + + +def load_eels_edges_database(path: Union[Path, str]) -> dict[str, dict[str, dict[str, object]]]: + """Load EELS edge CSV into the legacy element->edge metadata mapping.""" + elements: dict[str, dict[str, dict[str, object]]] = {} + duplicate_counts: dict[tuple[str, str], int] = {} + + with open(path, "r", encoding="utf-8", newline="") as f: + reader = csv.DictReader(f) + fieldnames = set(reader.fieldnames or []) + required_columns = ("symbol", "edge_label", "edge_energy_eV") + missing_columns = [column for column in required_columns if column not in fieldnames] + if missing_columns: + raise ValueError( + f"{path} is missing required EELS edge columns: {', '.join(missing_columns)}" + ) + + for row in reader: + element_symbol = str(row.get("symbol", "")).strip() + if not element_symbol: + continue + + energy_ev = _parse_float(row, ("edge_energy_eV", "onset_energy (eV)", "energy_eV")) + if energy_ev is None: + continue + + edge_label = str(row.get("edge_label", "")).strip() + element_edges = elements.setdefault(element_symbol, {}) + edge_name = f"{energy_ev:g} eV" + key = (element_symbol, edge_name) + if edge_name in element_edges: + duplicate_counts[key] = duplicate_counts.get(key, 1) + 1 + edge_name = f"{edge_name}__{duplicate_counts[key]}" + + edge_info: dict[str, object] = { + "onset_energy (eV)": float(energy_ev), + } + if edge_label: + edge_info["edge_label"] = edge_label + + atomic_number = _parse_float(row, ("atomic_number",)) + if atomic_number is not None: + edge_info["atomic_number"] = ( + int(atomic_number) if atomic_number.is_integer() else float(atomic_number) + ) + + element_name = str(row.get("element", "")).strip() + if element_name: + edge_info["element"] = element_name + + element_edges[edge_name] = edge_info + + return elements diff --git a/src/quantem/spectroscopy/x_ray_lines.csv b/src/quantem/spectroscopy/x_ray_lines.csv new file mode 100644 index 00000000..e0fd039a --- /dev/null +++ b/src/quantem/spectroscopy/x_ray_lines.csv @@ -0,0 +1,702 @@ +energy_eV,atomic_number,element,line,relative_intensity,weight +2633.7,47,Ag,Ll,4,0.021053 +2978.2,47,Ag,La2,11,0.057895 +2984.3,47,Ag,La1,100,0.526316 +3150.9,47,Ag,Lb1,56,0.294737 +3347.8,47,Ag,"Lb2,15",13,0.068421 +3519.6,47,Ag,Lg1,6,0.031579 +21990.3,47,Ag,Ka2,53,0.291209 +22162.9,47,Ag,Ka1,100,0.549451 +24911.5,47,Ag,Kb3,9,0.049451 +24942.4,47,Ag,Kb1,16,0.087912 +25456.4,47,Ag,Kb2,4,0.021978 +1486.3,13,Al,Ka2,50,0.331126 +1486.7,13,Al,Ka1,100,0.662252 +1557.4,13,Al,Kb1,1,0.006623 +2955.6,18,Ar,Ka2,50,0.3125 +2957.7,18,Ar,Ka1,100,0.625 +3190.5,18,Ar,"Kb1,3",10,0.0625 +1120,33,As,Ll,6,0.033898 +1282,33,As,"La1,2",111,0.627119 +1317,33,As,Lb1,60,0.338983 +10508,33,As,Ka2,51,0.298246 +10543.7,33,As,Ka1,100,0.584795 +11720.3,33,As,Kb3,6,0.035088 +11726.2,33,As,Kb1,13,0.076023 +11864,33,As,Kb2,1,0.005848 +2122.9,79,Au,Ma1,100,1 +8493.9,79,Au,Ll,5,0.022831 +9628,79,Au,La2,11,0.050228 +9713.3,79,Au,La1,100,0.456621 +11442.3,79,Au,Lb1,67,0.305936 +11584.7,79,Au,Lb2,23,0.105023 +13381.7,79,Au,Lg1,13,0.059361 +66989.5,79,Au,Ka2,59,0.292079 +68803.7,79,Au,Ka1,100,0.49505 +77580,79,Au,Kb3,12,0.059406 +77984,79,Au,Kb1,23,0.113861 +80150,79,Au,Kb2,8,0.039604 +183.3,5,B,"Ka1,2",151,1 +3954.1,56,Ba,Ll,4,0.019608 +4450.9,56,Ba,La2,11,0.053922 +4466.3,56,Ba,La1,100,0.490196 +4827.5,56,Ba,Lb1,60,0.294118 +5156.5,56,Ba,"Lb2,15",20,0.098039 +5531.1,56,Ba,Lg1,9,0.044118 +31817.1,56,Ba,Ka2,54,0.287234 +32193.6,56,Ba,Ka1,100,0.531915 +36304,56,Ba,Kb3,10,0.053191 +36378.2,56,Ba,Kb1,18,0.095745 +37257,56,Ba,Kb2,6,0.031915 +108.5,4,Be,"Ka1,2",150,1 +2422.6,83,Bi,Ma1,100,1 +9420.4,83,Bi,Ll,6,0.026906 +10730.9,83,Bi,La2,11,0.049327 +10838.8,83,Bi,La1,100,0.44843 +12979.9,83,Bi,Lb2,25,0.112108 +13023.5,83,Bi,Lb1,67,0.300448 +15247.7,83,Bi,Lg1,14,0.06278 +74814.8,83,Bi,Ka2,60,0.294118 +77107.9,83,Bi,Ka1,100,0.490196 +86834,83,Bi,Kb3,12,0.058824 +87343,83,Bi,Kb1,23,0.112745 +89830,83,Bi,Kb2,9,0.044118 +1293.5,35,Br,Ll,5,0.028571 +1480.4,35,Br,"La1,2",111,0.634286 +1525.9,35,Br,Lb1,59,0.337143 +11877.6,35,Br,Ka2,52,0.298851 +11924.2,35,Br,Ka1,100,0.574713 +13284.5,35,Br,Kb3,7,0.04023 +13291.4,35,Br,Kb1,14,0.08046 +13469.5,35,Br,Kb2,1,0.005747 +277,6,C,"Ka1,2",147,1 +3688.1,20,Ca,Ka2,50,0.306748 +3691.7,20,Ca,Ka1,100,0.613497 +4012.7,20,Ca,"Kb1,3",13,0.079755 +2767.4,48,Cd,Ll,4,0.020619 +3126.9,48,Cd,La2,11,0.056701 +3133.7,48,Cd,La1,100,0.515464 +3316.6,48,Cd,Lb1,58,0.298969 +3528.1,48,Cd,"Lb2,15",15,0.07732 +3716.9,48,Cd,Lg1,6,0.030928 +22984.1,48,Cd,Ka2,53,0.289617 +23173.6,48,Cd,Ka1,100,0.546448 +26061.2,48,Cd,Kb3,9,0.04918 +26095.5,48,Cd,Kb1,17,0.092896 +26643.8,48,Cd,Kb2,4,0.021858 +883,58,Ce,Ma1,100,1 +4287.5,58,Ce,Ll,4,0.019417 +4823,58,Ce,La2,11,0.053398 +4840.2,58,Ce,La1,100,0.485437 +5262.2,58,Ce,Lb1,61,0.296117 +5613.4,58,Ce,"Lb2,15",21,0.101942 +6052,58,Ce,Lg1,9,0.043689 +34278.9,58,Ce,Ka2,55,0.289474 +34719.7,58,Ce,Ka1,100,0.526316 +39170.1,58,Ce,Kb3,10,0.052632 +39257.3,58,Ce,Kb1,19,0.1 +40233,58,Ce,Kb2,6,0.031579 +2620.8,17,Cl,Ka2,50,0.320513 +2622.4,17,Cl,Ka1,100,0.641026 +2815.6,17,Cl,Kb1,6,0.038462 +677.8,27,Co,Ll,10,0.050761 +776.2,27,Co,"La1,2",111,0.563452 +791.4,27,Co,Lb1,76,0.385787 +6915.3,27,Co,Ka2,51,0.303571 +6930.3,27,Co,Ka1,100,0.595238 +7649.4,27,Co,"Kb1,3",17,0.10119 +500.3,24,Cr,Ll,17,0.082126 +572.8,24,Cr,"La1,2",111,0.536232 +582.8,24,Cr,Lb1,79,0.381643 +5405.5,24,Cr,Ka2,50,0.30303 +5414.7,24,Cr,Ka1,100,0.606061 +5946.7,24,Cr,"Kb1,3",15,0.090909 +3795,55,Cs,Ll,4,0.019608 +4272.2,55,Cs,La2,11,0.053922 +4286.5,55,Cs,La1,100,0.490196 +4619.8,55,Cs,Lb1,61,0.29902 +4935.9,55,Cs,"Lb2,15",20,0.098039 +5280.4,55,Cs,Lg1,8,0.039216 +30625.1,55,Cs,Ka2,54,0.28877 +30972.8,55,Cs,Ka1,100,0.534759 +34919.4,55,Cs,Kb3,9,0.048128 +34986.9,55,Cs,Kb1,18,0.096257 +35822,55,Cs,Kb2,6,0.032086 +811.1,29,Cu,Ll,8,0.043478 +929.7,29,Cu,"La1,2",111,0.603261 +949.8,29,Cu,Lb1,65,0.353261 +8027.8,29,Cu,Ka2,51,0.303571 +8047.8,29,Cu,Ka1,100,0.595238 +8905.3,29,Cu,"Kb1,3",17,0.10119 +1293,66,Dy,Ma1,100,1 +5743.1,66,Dy,Ll,4,0.019231 +6457.7,66,Dy,La2,11,0.052885 +6495.2,66,Dy,La1,100,0.480769 +7247.7,66,Dy,Lb1,62,0.298077 +7635.7,66,Dy,Lb2,20,0.096154 +8418.8,66,Dy,Lg1,11,0.052885 +45207.8,66,Dy,Ka2,56,0.290155 +45998.4,66,Dy,Ka1,100,0.518135 +51957,66,Dy,Kb3,10,0.051813 +52119,66,Dy,Kb1,20,0.103627 +53476,66,Dy,Kb2,7,0.036269 +1406,68,Er,Ma1,100,1 +6152,68,Er,Ll,4,0.019048 +6905,68,Er,La2,11,0.052381 +6948.7,68,Er,La1,100,0.47619 +7810.9,68,Er,Lb1,64,0.304762 +8189,68,Er,"Lb2,15",20,0.095238 +9089,68,Er,Lg1,11,0.052381 +48221.1,68,Er,Ka2,56,0.287179 +49127.7,68,Er,Ka1,100,0.512821 +55494,68,Er,Kb3,11,0.05641 +55681,68,Er,Kb1,21,0.107692 +57210,68,Er,Kb2,7,0.035897 +1131,63,Eu,Ma1,100,1 +5177.2,63,Eu,Ll,4,0.019231 +5816.6,63,Eu,La2,11,0.052885 +5845.7,63,Eu,La1,100,0.480769 +6456.4,63,Eu,Lb1,62,0.298077 +6843.2,63,Eu,"Lb2,15",21,0.100962 +7480.3,63,Eu,Lg1,10,0.048077 +40901.9,63,Eu,Ka2,56,0.293194 +41542.2,63,Eu,Ka1,100,0.52356 +46903.6,63,Eu,Kb3,10,0.052356 +47037.9,63,Eu,Kb1,19,0.099476 +48256,63,Eu,Kb2,6,0.031414 +676.8,9,F,"Ka1,2",148,1 +615.2,26,Fe,Ll,10,0.053476 +705,26,Fe,"La1,2",111,0.593583 +718.5,26,Fe,Lb1,66,0.352941 +6390.8,26,Fe,Ka2,50,0.299401 +6403.8,26,Fe,Ka1,100,0.598802 +7058,26,Fe,"Kb1,3",17,0.101796 +957.2,31,Ga,Ll,7,0.038043 +1097.9,31,Ga,"La1,2",111,0.603261 +1124.8,31,Ga,Lb1,66,0.358696 +9224.8,31,Ga,Ka2,51,0.22973 +9251.7,31,Ga,Ka1,100,0.45045 +10260.3,31,Ga,Kb3,5,0.022523 +10264.2,31,Ga,Kb1,66,0.297297 +1185,64,Gd,Ma1,100,1 +5362.1,64,Gd,Ll,4,0.019139 +6025,64,Gd,La2,11,0.052632 +6057.2,64,Gd,La1,100,0.478469 +6713.2,64,Gd,Lb1,62,0.296651 +7102.8,64,Gd,"Lb2,15",21,0.100478 +7785.8,64,Gd,Lg1,11,0.052632 +42308.9,64,Gd,Ka2,56,0.290155 +42996.2,64,Gd,Ka1,100,0.518135 +48555,64,Gd,Kb3,10,0.051813 +48697,64,Gd,Kb1,20,0.103627 +49959,64,Gd,Kb2,7,0.036269 +1036.2,32,Ge,Ll,6,0.033898 +1188,32,Ge,"La1,2",111,0.627119 +1218.5,32,Ge,Lb1,60,0.338983 +9855.3,32,Ge,Ka2,51,0.235023 +9886.4,32,Ge,Ka1,100,0.460829 +10978,32,Ge,Kb3,6,0.02765 +10982.1,32,Ge,Kb1,60,0.276498 +1644.6,72,Hf,Ma1,100,1 +6959.6,72,Hf,Ll,5,0.023256 +7844.6,72,Hf,La2,11,0.051163 +7899,72,Hf,La1,100,0.465116 +9022.7,72,Hf,Lb1,67,0.311628 +9347.3,72,Hf,Lb2,20,0.093023 +10515.8,72,Hf,Lg1,12,0.055814 +54611.4,72,Hf,Ka2,57,0.28934 +55790.2,72,Hf,Ka1,100,0.507614 +62980,72,Hf,Kb3,11,0.055838 +63234,72,Hf,Kb1,22,0.111675 +64980,72,Hf,Kb2,7,0.035533 +2195.3,80,Hg,Ma1,100,1 +8721,80,Hg,Ll,5,0.022624 +9897.6,80,Hg,La2,11,0.049774 +9988.8,80,Hg,La1,100,0.452489 +11822.6,80,Hg,Lb1,67,0.303167 +11924.1,80,Hg,Lb2,24,0.108597 +13830.1,80,Hg,Lg1,14,0.063348 +68895,80,Hg,Ka2,59,0.292079 +70819,80,Hg,Ka1,100,0.49505 +79822,80,Hg,Kb3,12,0.059406 +80253,80,Hg,Kb1,23,0.113861 +82515,80,Hg,Kb2,8,0.039604 +1348,67,Ho,Ma1,100,1 +5943.4,67,Ho,Ll,4,0.019048 +6679.5,67,Ho,La2,11,0.052381 +6719.8,67,Ho,La1,100,0.47619 +7525.3,67,Ho,Lb1,64,0.304762 +7911,67,Ho,"Lb2,15",20,0.095238 +8747,67,Ho,Lg1,11,0.052381 +46699.7,67,Ho,Ka2,56,0.28866 +47546.7,67,Ho,Ka1,100,0.515464 +53711,67,Ho,Kb3,11,0.056701 +53877,67,Ho,Kb1,20,0.103093 +55293,67,Ho,Kb2,7,0.036082 +3485,53,I,Ll,4,0.019704 +3926,53,I,La2,11,0.054187 +3937.6,53,I,La1,100,0.492611 +4220.7,53,I,Lb1,61,0.300493 +4507.5,53,I,"Lb2,15",19,0.093596 +4800.9,53,I,Lg1,8,0.039409 +28317.2,53,I,Ka2,54,0.290323 +28612,53,I,Ka1,100,0.537634 +32239.4,53,I,Kb3,9,0.048387 +32294.7,53,I,Kb1,18,0.096774 +33042,53,I,Kb2,5,0.026882 +2904.4,49,In,Ll,4,0.020619 +3279.3,49,In,La2,11,0.056701 +3286.9,49,In,La1,100,0.515464 +3487.2,49,In,Lb1,58,0.298969 +3713.8,49,In,"Lb2,15",15,0.07732 +3920.8,49,In,Lg1,6,0.030928 +24002,49,In,Ka2,53,0.288043 +24209.7,49,In,Ka1,100,0.543478 +27237.7,49,In,Kb3,9,0.048913 +27275.9,49,In,Kb1,17,0.092391 +27860.8,49,In,Kb2,5,0.027174 +1979.9,77,Ir,Ma1,100,1 +8045.8,77,Ir,Ll,5,0.023041 +9099.5,77,Ir,La2,11,0.050691 +9175.1,77,Ir,La1,100,0.460829 +10708.3,77,Ir,Lb1,66,0.304147 +10920.3,77,Ir,Lb2,22,0.101382 +12512.6,77,Ir,Lg1,13,0.059908 +63286.7,77,Ir,Ka2,58,0.288557 +64895.6,77,Ir,Ka1,100,0.497512 +73202.7,77,Ir,Kb3,12,0.059701 +73560.8,77,Ir,Kb1,23,0.114428 +75575,77,Ir,Kb2,8,0.039801 +3311.1,19,K,Ka2,50,0.310559 +3313.8,19,K,Ka1,100,0.621118 +3589.6,19,K,"Kb1,3",11,0.068323 +1386,36,Kr,Ll,5,0.028902 +1586,36,Kr,"La1,2",111,0.641618 +1636.6,36,Kr,Lb1,57,0.32948 +12598,36,Kr,Ka2,52,0.297143 +12649,36,Kr,Ka1,100,0.571429 +14104,36,Kr,Kb3,7,0.04 +14112,36,Kr,Kb1,14,0.08 +14315,36,Kr,Kb2,2,0.011429 +833,57,La,Ma1,100,1 +4124,57,La,Ll,4,0.019512 +4634.2,57,La,La2,11,0.053659 +4651,57,La,La1,100,0.487805 +5042.1,57,La,Lb1,60,0.292683 +5383.5,57,La,"Lb2,15",21,0.102439 +5788.5,57,La,Lg1,9,0.043902 +33034.1,57,La,Ka2,54,0.285714 +33441.8,57,La,Ka1,100,0.529101 +37720.2,57,La,Kb3,10,0.05291 +37801,57,La,Kb1,19,0.100529 +38729.9,57,La,Kb2,6,0.031746 +54.3,3,Li,"Ka1,2",150,1 +1581.3,71,Lu,Ma1,100,1 +6752.8,71,Lu,Ll,4,0.018868 +7604.9,71,Lu,La2,11,0.051887 +7655.5,71,Lu,La1,100,0.471698 +8709,71,Lu,Lb1,66,0.311321 +9048.9,71,Lu,Lb2,19,0.089623 +10143.4,71,Lu,Lg1,12,0.056604 +52965,71,Lu,Ka2,57,0.290816 +54069.8,71,Lu,Ka1,100,0.510204 +61050,71,Lu,Kb3,11,0.056122 +61283,71,Lu,Kb1,21,0.107143 +62970,71,Lu,Kb2,7,0.035714 +1253.6,12,Mg,"Ka1,2",150,1 +556.3,25,Mn,Ll,15,0.073892 +637.4,25,Mn,"La1,2",111,0.546798 +648.8,25,Mn,Lb1,77,0.37931 +5887.6,25,Mn,Ka2,50,0.299401 +5898.8,25,Mn,Ka1,100,0.598802 +6490.4,25,Mn,"Kb1,3",17,0.101796 +2015.7,42,Mo,Ll,5,0.028249 +2289.8,42,Mo,La2,11,0.062147 +2293.2,42,Mo,La1,100,0.564972 +2394.8,42,Mo,Lb1,53,0.299435 +2518.3,42,Mo,"Lb2,15",5,0.028249 +2623.5,42,Mo,Lg1,3,0.016949 +17374.3,42,Mo,Ka2,52,0.292135 +17479.3,42,Mo,Ka1,100,0.561798 +19590.3,42,Mo,Kb3,8,0.044944 +19608.3,42,Mo,Kb1,15,0.08427 +19965.2,42,Mo,Kb2,3,0.016854 +392.4,7,N,"Ka1,2",150,1 +1041,11,Na,"Ka1,2",150,1 +1902.2,41,Nb,Ll,5,0.028902 +2163,41,Nb,La2,11,0.063584 +2165.9,41,Nb,La1,100,0.578035 +2257.4,41,Nb,Lb1,52,0.300578 +2367,41,Nb,"Lb2,15",3,0.017341 +2461.8,41,Nb,Lg1,2,0.011561 +16521,41,Nb,Ka2,52,0.292135 +16615.1,41,Nb,Ka1,100,0.561798 +18606.3,41,Nb,Kb3,8,0.044944 +18622.5,41,Nb,Kb1,15,0.08427 +18953,41,Nb,Kb2,3,0.016854 +978,60,Nd,Ma1,100,1 +4633,60,Nd,Ll,4,0.019417 +5207.7,60,Nd,La2,11,0.053398 +5230.4,60,Nd,La1,100,0.485437 +5721.6,60,Nd,Lb1,60,0.291262 +6089.4,60,Nd,"Lb2,15",21,0.101942 +6602.1,60,Nd,Lg1,10,0.048544 +36847.4,60,Nd,Ka2,55,0.289474 +37361,60,Nd,Ka1,100,0.526316 +42166.5,60,Nd,Kb3,10,0.052632 +42271.3,60,Nd,Kb1,19,0.1 +43335,60,Nd,Kb2,6,0.031579 +848.6,10,Ne,"Ka1,2",150,1 +742.7,28,Ni,Ll,9,0.047872 +851.5,28,Ni,"La1,2",111,0.590426 +868.8,28,Ni,Lb1,68,0.361702 +7460.9,28,Ni,Ka2,51,0.303571 +7478.2,28,Ni,Ka1,100,0.595238 +8264.7,28,Ni,"Kb1,3",17,0.10119 +524.9,8,O,"Ka1,2",151,1 +1910.2,76,Os,Ma1,100,1 +7822.2,76,Os,Ll,5,0.022936 +8841,76,Os,La2,11,0.050459 +8911.7,76,Os,La1,100,0.458716 +10355.3,76,Os,Lb1,67,0.307339 +10598.5,76,Os,Lb2,22,0.100917 +12095.3,76,Os,Lg1,13,0.059633 +61486.7,76,Os,Ka2,58,0.288557 +63000.5,76,Os,Ka1,100,0.497512 +71077,76,Os,Kb3,12,0.059701 +71413,76,Os,Kb1,23,0.114428 +73363,76,Os,Kb2,8,0.039801 +2012.7,15,P,Ka2,50,0.326797 +2013.7,15,P,Ka1,100,0.653595 +2139.1,15,P,Kb1,3,0.019608 +2345.5,82,Pb,Ma1,100,1 +9184.5,82,Pb,Ll,6,0.027027 +10449.5,82,Pb,La2,11,0.04955 +10551.5,82,Pb,La1,100,0.45045 +12613.7,82,Pb,Lb1,66,0.297297 +12622.6,82,Pb,Lb2,25,0.112613 +14764.4,82,Pb,Lg1,14,0.063063 +72804.2,82,Pb,Ka2,60,0.295567 +74969.4,82,Pb,Ka1,100,0.492611 +84450,82,Pb,Kb3,12,0.059113 +84936,82,Pb,Kb1,23,0.1133 +87320,82,Pb,Kb2,8,0.039409 +2503.4,46,Pd,Ll,4,0.021505 +2833.3,46,Pd,La2,11,0.05914 +2838.6,46,Pd,La1,100,0.537634 +2990.2,46,Pd,Lb1,53,0.284946 +3171.8,46,Pd,"Lb2,15",12,0.064516 +3328.7,46,Pd,Lg1,6,0.032258 +21020.1,46,Pd,Ka2,53,0.292818 +21177.1,46,Pd,Ka1,100,0.552486 +23791.1,46,Pd,Kb3,8,0.044199 +23818.7,46,Pd,Kb1,16,0.088398 +24299.1,46,Pd,Kb2,4,0.022099 +4809,61,Pm,Ll,4,0.019324 +5408,61,Pm,La2,11,0.05314 +5432,61,Pm,La1,100,0.483092 +5961,61,Pm,Lb1,61,0.294686 +6339,61,Pm,Lb2,21,0.101449 +6892,61,Pm,Lg1,10,0.048309 +38171.2,61,Pm,Ka2,55,0.289474 +38724.7,61,Pm,Ka1,100,0.526316 +43713,61,Pm,Kb3,10,0.052632 +43826,61,Pm,Kb1,19,0.1 +44942,61,Pm,Kb2,6,0.031579 +929.2,59,Pr,Ma1,100,1 +4453.2,59,Pr,Ll,4,0.019417 +5013.5,59,Pr,La2,11,0.053398 +5033.7,59,Pr,La1,100,0.485437 +5488.9,59,Pr,Lb1,61,0.296117 +5850,59,Pr,"Lb2,15",21,0.101942 +6322.1,59,Pr,Lg1,9,0.043689 +35550.2,59,Pr,Ka2,55,0.289474 +36026.3,59,Pr,Ka1,100,0.526316 +40652.9,59,Pr,Kb3,10,0.052632 +40748.2,59,Pr,Kb1,19,0.1 +41773,59,Pr,Kb2,6,0.031579 +2050.5,78,Pt,Ma1,100,1 +8268,78,Pt,Ll,5,0.022831 +9361.8,78,Pt,La2,11,0.050228 +9442.3,78,Pt,La1,100,0.456621 +11070.7,78,Pt,Lb1,67,0.305936 +11250.5,78,Pt,Lb2,23,0.105023 +12942,78,Pt,Lg1,13,0.059361 +65112,78,Pt,Ka2,58,0.288557 +66832,78,Pt,Ka1,100,0.497512 +75368,78,Pt,Kb3,12,0.059701 +75748,78,Pt,Kb1,23,0.114428 +77850,78,Pt,Kb2,8,0.039801 +1482.4,37,Rb,Ll,5,0.028736 +1692.6,37,Rb,La2,11,0.063218 +1694.1,37,Rb,La1,100,0.574713 +1752.2,37,Rb,Lb1,58,0.333333 +13335.8,37,Rb,Ka2,52,0.297143 +13395.3,37,Rb,Ka1,100,0.571429 +14951.7,37,Rb,Kb3,7,0.04 +14961.3,37,Rb,Kb1,14,0.08 +15185,37,Rb,Kb2,2,0.011429 +1842.5,75,Re,Ma1,100,1 +7603.6,75,Re,Ll,5,0.023041 +8586.2,75,Re,La2,11,0.050691 +8652.5,75,Re,La1,100,0.460829 +10010,75,Re,Lb1,66,0.304147 +10275.2,75,Re,Lb2,22,0.101382 +11685.4,75,Re,Lg1,13,0.059908 +59717.9,75,Re,Ka2,58,0.29 +61140.3,75,Re,Ka1,100,0.5 +68994,75,Re,Kb3,12,0.06 +69310,75,Re,Kb1,22,0.11 +71232,75,Re,Kb2,8,0.04 +2376.5,45,Rh,Ll,4,0.021978 +2692,45,Rh,La2,11,0.06044 +2696.7,45,Rh,La1,100,0.549451 +2834.4,45,Rh,Lb1,52,0.285714 +3001.3,45,Rh,"Lb2,15",10,0.054945 +3143.8,45,Rh,Lg1,5,0.027473 +20073.7,45,Rh,Ka2,53,0.292818 +20216.1,45,Rh,Ka1,100,0.552486 +22698.9,45,Rh,Kb3,8,0.044199 +22723.6,45,Rh,Kb1,16,0.088398 +23172.8,45,Rh,Kb2,4,0.022099 +2252.8,44,Ru,Ll,4,0.021858 +2554.3,44,Ru,La2,11,0.060109 +2558.6,44,Ru,La1,100,0.546448 +2683.2,44,Ru,Lb1,54,0.295082 +2836,44,Ru,"Lb2,15",10,0.054645 +2964.5,44,Ru,Lg1,4,0.021858 +19150.4,44,Ru,Ka2,53,0.292818 +19279.2,44,Ru,Ka1,100,0.552486 +21634.6,44,Ru,Kb3,8,0.044199 +21656.8,44,Ru,Kb1,16,0.088398 +22074,44,Ru,Kb2,4,0.022099 +2306.6,16,S,Ka2,50,0.322581 +2307.8,16,S,Ka1,100,0.645161 +2464,16,S,Kb1,5,0.032258 +3188.6,51,Sb,Ll,4,0.0199 +3595.3,51,Sb,La2,11,0.054726 +3604.7,51,Sb,La1,100,0.497512 +3843.6,51,Sb,Lb1,61,0.303483 +4100.8,51,Sb,"Lb2,15",17,0.084577 +4347.8,51,Sb,Lg1,8,0.039801 +26110.8,51,Sb,Ka2,54,0.290323 +26359.1,51,Sb,Ka1,100,0.537634 +29679.2,51,Sb,Kb3,9,0.048387 +29725.6,51,Sb,Kb1,18,0.096774 +30389.5,51,Sb,Kb2,5,0.026882 +348.3,21,Sc,Ll,21,0.100478 +395.4,21,Sc,"La1,2",111,0.5311 +399.6,21,Sc,Lb1,77,0.368421 +4086.1,21,Sc,Ka2,50,0.30303 +4090.6,21,Sc,Ka1,100,0.606061 +4460.5,21,Sc,"Kb1,3",15,0.090909 +1204.4,34,Se,Ll,6,0.034091 +1379.1,34,Se,"La1,2",111,0.630682 +1419.2,34,Se,Lb1,59,0.335227 +11181.4,34,Se,Ka2,52,0.302326 +11222.4,34,Se,Ka1,100,0.581395 +12489.6,34,Se,Kb3,6,0.034884 +12495.9,34,Se,Kb1,13,0.075581 +12652,34,Se,Kb2,1,0.005814 +1739.4,14,Si,Ka2,50,0.328947 +1740,14,Si,Ka1,100,0.657895 +1835.9,14,Si,Kb1,2,0.013158 +1081,62,Sm,Ma1,100,1 +4994.5,62,Sm,Ll,4,0.019324 +5609,62,Sm,La2,11,0.05314 +5636.1,62,Sm,La1,100,0.483092 +6205.1,62,Sm,Lb1,61,0.294686 +6587,62,Sm,"Lb2,15",21,0.101449 +7178,62,Sm,Lg1,10,0.048309 +39522.4,62,Sm,Ka2,55,0.289474 +40118.1,62,Sm,Ka1,100,0.526316 +45289,62,Sm,Kb3,10,0.052632 +45413,62,Sm,Kb1,19,0.1 +46578,62,Sm,Kb2,6,0.031579 +3045,50,Sn,Ll,4,0.020202 +3435.4,50,Sn,La2,11,0.055556 +3444,50,Sn,La1,100,0.505051 +3662.8,50,Sn,Lb1,60,0.30303 +3904.9,50,Sn,"Lb2,15",16,0.080808 +4131.1,50,Sn,Lg1,7,0.035354 +25044,50,Sn,Ka2,53,0.288043 +25271.3,50,Sn,Ka1,100,0.543478 +28444,50,Sn,Kb3,9,0.048913 +28486,50,Sn,Kb1,17,0.092391 +29109.3,50,Sn,Kb2,5,0.027174 +1582.2,38,Sr,Ll,5,0.028736 +1804.7,38,Sr,La2,11,0.063218 +1806.6,38,Sr,La1,100,0.574713 +1871.7,38,Sr,Lb1,58,0.333333 +14097.9,38,Sr,Ka2,52,0.295455 +14165,38,Sr,Ka1,100,0.568182 +15824.9,38,Sr,Kb3,7,0.039773 +15835.7,38,Sr,Kb1,14,0.079545 +16084.6,38,Sr,Kb2,3,0.017045 +1709.6,73,Ta,Ma1,100,1 +7173.1,73,Ta,Ll,5,0.023256 +8087.9,73,Ta,La2,11,0.051163 +8146.1,73,Ta,La1,100,0.465116 +9343.1,73,Ta,Lb1,67,0.311628 +9651.8,73,Ta,Lb2,20,0.093023 +10895.2,73,Ta,Lg1,12,0.055814 +56277,73,Ta,Ka2,57,0.28934 +57532,73,Ta,Ka1,100,0.507614 +64948.8,73,Ta,Kb3,11,0.055838 +65223,73,Ta,Kb1,22,0.111675 +66990,73,Ta,Kb2,7,0.035533 +1240,65,Tb,Ma1,100,1 +5546.7,65,Tb,Ll,4,0.019231 +6238,65,Tb,La2,11,0.052885 +6272.8,65,Tb,La1,100,0.480769 +6978,65,Tb,Lb1,61,0.293269 +7366.7,65,Tb,"Lb2,15",21,0.100962 +8102,65,Tb,Lg1,11,0.052885 +43744.1,65,Tb,Ka2,56,0.290155 +44481.6,65,Tb,Ka1,100,0.518135 +50229,65,Tb,Kb3,10,0.051813 +50382,65,Tb,Kb1,20,0.103627 +51698,65,Tb,Kb2,7,0.036269 +2122,43,Tc,Ll,5,0.027778 +2420,43,Tc,La2,11,0.061111 +2424,43,Tc,La1,100,0.555556 +2538,43,Tc,Lb1,54,0.3 +2674,43,Tc,"Lb2,15",7,0.038889 +2792,43,Tc,Lg1,3,0.016667 +18250.8,43,Tc,Ka2,53,0.292818 +18367.1,43,Tc,Ka1,100,0.552486 +20599,43,Tc,Kb3,8,0.044199 +20619,43,Tc,Kb1,16,0.088398 +21005,43,Tc,Kb2,4,0.022099 +3335.6,52,Te,Ll,4,0.019802 +3758.8,52,Te,La2,11,0.054455 +3769.3,52,Te,La1,100,0.49505 +4029.6,52,Te,Lb1,61,0.30198 +4301.7,52,Te,"Lb2,15",18,0.089109 +4570.9,52,Te,Lg1,8,0.039604 +27201.7,52,Te,Ka2,54,0.290323 +27472.3,52,Te,Ka1,100,0.537634 +30944.3,52,Te,Kb3,9,0.048387 +30995.7,52,Te,Kb1,18,0.096774 +31700.4,52,Te,Kb2,5,0.026882 +2996.1,90,Th,Ma1,100,1 +11118.6,90,Th,Ll,6,0.026316 +12809.6,90,Th,La2,11,0.048246 +12968.7,90,Th,La1,100,0.438596 +15623.7,90,Th,Lb2,26,0.114035 +16202.2,90,Th,Lb1,69,0.302632 +18982.5,90,Th,Lg1,16,0.070175 +89953,90,Th,Ka2,62,0.299517 +93350,90,Th,Ka1,100,0.483092 +104831,90,Th,Kb3,12,0.057971 +105609,90,Th,Kb1,24,0.115942 +108640,90,Th,Kb2,9,0.043478 +395.3,22,Ti,Ll,46,0.194915 +452.2,22,Ti,"La1,2",111,0.470339 +458.4,22,Ti,Lb1,79,0.334746 +4504.9,22,Ti,Ka2,50,0.30303 +4510.8,22,Ti,Ka1,100,0.606061 +4931.8,22,Ti,"Kb1,3",15,0.090909 +2270.6,81,Tl,Ma1,100,1 +8953.2,81,Tl,Ll,6,0.026906 +10172.8,81,Tl,La2,11,0.049327 +10268.5,81,Tl,La1,100,0.44843 +12213.3,81,Tl,Lb1,67,0.300448 +12271.5,81,Tl,Lb2,25,0.112108 +14291.5,81,Tl,Lg1,14,0.06278 +70831.9,81,Tl,Ka2,60,0.295567 +72871.5,81,Tl,Ka1,100,0.492611 +82118,81,Tl,Kb3,12,0.059113 +82576,81,Tl,Kb1,23,0.1133 +84910,81,Tl,Kb2,8,0.039409 +1462,69,Tm,Ma1,100,1 +6341.9,69,Tm,Ll,4,0.018957 +7133.1,69,Tm,La2,11,0.052133 +7179.9,69,Tm,La1,100,0.473934 +8101,69,Tm,Lb1,64,0.303318 +8468,69,Tm,"Lb2,15",20,0.094787 +9426,69,Tm,Lg1,12,0.056872 +49772.6,69,Tm,Ka2,57,0.290816 +50741.6,69,Tm,Ka1,100,0.510204 +57304,69,Tm,Kb3,11,0.056122 +57517,69,Tm,Kb1,21,0.107143 +59090,69,Tm,Kb2,7,0.035714 +3170.8,92,U,Ma1,100,1 +11618.3,92,U,Ll,7,0.031818 +13438.8,92,U,La2,11,0.05 +13614.7,92,U,La1,100,0.454545 +16428.3,92,U,Lb2,26,0.118182 +17220,92,U,Lb1,61,0.277273 +20167.1,92,U,Lg1,15,0.068182 +94665,92,U,Ka2,62,0.298077 +98439,92,U,Ka1,100,0.480769 +110406,92,U,Kb3,13,0.0625 +111300,92,U,Kb1,24,0.115385 +114530,92,U,Kb2,9,0.043269 +446.5,23,V,Ll,28,0.127854 +511.3,23,V,"La1,2",111,0.506849 +519.2,23,V,Lb1,80,0.365297 +4944.6,23,V,Ka2,50,0.30303 +4952.2,23,V,Ka1,100,0.606061 +5427.3,23,V,"Kb1,3",15,0.090909 +1775.4,74,W,Ma1,100,1 +7387.8,74,W,Ll,5,0.023041 +8335.2,74,W,La2,11,0.050691 +8397.6,74,W,La1,100,0.460829 +9672.4,74,W,Lb1,67,0.308756 +9961.5,74,W,Lb2,21,0.096774 +11285.9,74,W,Lg1,13,0.059908 +57981.7,74,W,Ka2,58,0.291457 +59318.2,74,W,Ka1,100,0.502513 +66951.4,74,W,Kb3,11,0.055276 +67244.3,74,W,Kb1,22,0.110553 +69067,74,W,Kb2,8,0.040201 +3636,54,Xe,Ll,4,0.019704 +4093,54,Xe,La2,11,0.054187 +4109.9,54,Xe,La1,100,0.492611 +4414,54,Xe,Lb1,60,0.295567 +4714,54,Xe,"Lb2,15",20,0.098522 +5034,54,Xe,Lg1,8,0.039409 +29458,54,Xe,Ka2,54,0.290323 +29779,54,Xe,Ka1,100,0.537634 +33562,54,Xe,Kb3,9,0.048387 +33624,54,Xe,Kb1,18,0.096774 +34415,54,Xe,Kb2,5,0.026882 +1685.4,39,Y,Ll,5,0.028902 +1920.5,39,Y,La2,11,0.063584 +1922.6,39,Y,La1,100,0.578035 +1995.8,39,Y,Lb1,57,0.32948 +14882.9,39,Y,Ka2,52,0.292135 +14958.4,39,Y,Ka1,100,0.561798 +16725.8,39,Y,Kb3,8,0.044944 +16737.8,39,Y,Kb1,15,0.08427 +17015.4,39,Y,Kb2,3,0.016854 +1521.4,70,Yb,Ma1,100,1 +6545.5,70,Yb,Ll,4,0.018868 +7367.3,70,Yb,La2,11,0.051887 +7415.6,70,Yb,La1,100,0.471698 +8401.8,70,Yb,Lb1,65,0.306604 +8758.8,70,Yb,"Lb2,15",20,0.09434 +9780.1,70,Yb,Lg1,12,0.056604 +51354,70,Yb,Ka2,57,0.290816 +52388.9,70,Yb,Ka1,100,0.510204 +59140,70,Yb,Kb3,11,0.056122 +59370,70,Yb,Kb1,21,0.107143 +60980,70,Yb,Kb2,7,0.035714 +884,30,Zn,Ll,7,0.038251 +1011.7,30,Zn,"La1,2",111,0.606557 +1034.7,30,Zn,Lb1,65,0.355191 +8615.8,30,Zn,Ka2,51,0.303571 +8638.9,30,Zn,Ka1,100,0.595238 +9572,30,Zn,"Kb1,3",17,0.10119 +1792,40,Zr,Ll,5,0.028902 +2039.9,40,Zr,La2,11,0.063584 +2042.4,40,Zr,La1,100,0.578035 +2124.4,40,Zr,Lb1,54,0.312139 +2219.4,40,Zr,"Lb2,15",1,0.00578 +2302.7,40,Zr,Lg1,2,0.011561 +15690.9,40,Zr,Ka2,52,0.292135 +15775.1,40,Zr,Ka1,100,0.561798 +17654,40,Zr,Kb3,8,0.044944 +17667.8,40,Zr,Kb1,15,0.08427 +17970,40,Zr,Kb2,3,0.016854 \ No newline at end of file