From 09773908265e0a74a3e8b7ce2507f7188ddcb3d5 Mon Sep 17 00:00:00 2001 From: Nicholas Marchese Date: Fri, 5 Dec 2025 11:19:39 -0800 Subject: [PATCH 1/5] Added gaussian noise, option to add background to labels, and aperature settings. Implemented compatibility with multichannel labels. --- src/quantem/core/utils/augment_dp.py | 228 ++++++++++++++++++++++++--- 1 file changed, 208 insertions(+), 20 deletions(-) diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index b7b55f91..5bacc78b 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -20,17 +20,18 @@ ArrayLike = Union[np.ndarray, "torch.Tensor"] + # TODO # add dark background # add gaussian noise - - class DPAugmentor(RNGMixin): def __init__( self, add_bkg: bool = False, bkg_weight: list[float] | float = [0.001, 0.05], bkg_q: list[float] | float = [0.01, 0.1], + apply_background_to_label: bool = False, + background_label_application: list[bool] | None = None, add_shot: bool = False, e_dose: list[float] | float = [1e4, 1e7], add_shift: bool = False, @@ -41,6 +42,9 @@ def __init__( add_ellipticity_to_label: bool = True, add_salt_and_pepper: bool = False, salt_and_pepper: list[float] | float = [0, 5e-4], + add_gaussian_noise: bool = False, + gaussian_noise_mu: float = 1.0, + gaussian_noise_std: float = 0.5, add_scale: bool = False, scale_factor: list[float] | float = [0.9, 1.1], add_blur: bool = False, @@ -51,6 +55,9 @@ def __init__( log_file: os.PathLike | None = None, rng: np.random.Generator | int | None = None, device: str = "cpu", + add_aperture: bool = False, + radius_factor: list[float] | float = [0.8, 1], + aperture_shift: list[float] | float = [0, 10], ): """ Initialize diffraction pattern augmentor with configurable transformations. @@ -63,6 +70,10 @@ def __init__( Range for background weight (fraction of total intensity). bkg_q : list[float] | float, default=[0.01, 0.1] Range for plasmon scattering parameter q₀ in 1/(q² + q₀²) form factor. + apply_background_to_label: bool, defautl=False + Flag for whether background should be applied to labels, and which ones according to background_label_application + background_label_application: list[bool] + List of 1/0 for if background should be applied to label add_shot : bool, default=False Enable Poisson shot noise based on electron dose. @@ -89,6 +100,13 @@ def __init__( salt_and_pepper : list[float] | float, default=[0, 5e-4] Range for fraction of pixels affected by salt and pepper noise. + add_gaussian_noise : bool, default=False + Enable gaussian noise. + gaussian_noise_mu : float, default=1.0 + Mean for gaussian noise distribution + gaussian_noise_std : float, defualt=0.5 + Standard deviation for gaussian noise distribution + add_scale : bool, default=False Enable uniform scaling of the diffraction pattern. scale_factor : list[float] | float, default=[0.9, 1.1] @@ -130,6 +148,8 @@ def __init__( add_bkg, bkg_weight, bkg_q, + apply_background_to_label, + background_label_application, add_shot, e_dose, add_shift, @@ -140,6 +160,9 @@ def __init__( add_ellipticity_to_label, add_salt_and_pepper, salt_and_pepper, + add_gaussian_noise, + gaussian_noise_mu, + gaussian_noise_std, add_scale, scale_factor, add_blur, @@ -147,6 +170,9 @@ def __init__( add_flipshift, free_rotation, rotation_range, + add_aperture, + radius_factor, + aperture_shift, ) self.generate_params() self._init_log_file() @@ -168,16 +194,19 @@ def _init_log_file(self) -> None: if self.log_file is not None: with open(self.log_file, "a") as f: f.write( - "bkg_weight,bkg_q,e_dose,xshift,yshift,exx,eyy,exy," - "scale_factor,flip_horizontal,flip_vertical,rotation_angle," - "blur_sigma,salt_and_pepper,rng_seed\n" + "bkg_weight,bkg_q,background_label_application,e_dose,xshift,yshift,exx,eyy,exy," + "gaussian_noise_mu,gaussian_noise_std,scale_factor,flip_horizontal,flip_vertical," + "rotation_angle,blur_sigma,salt_and_pepper,rng_seed\n" ) + def set_params( self, add_bkg: bool = False, bkg_weight: list[float] | float = [0.01, 0.1], bkg_q: list[float] | float = [0.01, 0.1], + apply_background_to_label: list[bool] | None = None, + background_label_application: bool = False, add_shot: bool = False, e_dose: list[float] | float = [1e5, 1e10], add_shift: bool = False, @@ -188,6 +217,9 @@ def set_params( add_ellipticity_to_label: bool = True, add_salt_and_pepper: bool = False, salt_and_pepper: list[float] | float = [0, 1e-3], + add_gaussian_noise: bool = False, + gaussian_noise_mu: list[float] | float = 1, + gaussian_noise_std: list[float] | float = 0.5, add_scale: bool = False, scale_factor: list[float] | float = [0.9, 1.1], add_blur: bool = False, @@ -195,19 +227,28 @@ def set_params( add_flipshift: bool = False, free_rotation: bool = False, rotation_range: list[float] | float = [-180, 180], + add_aperture: bool = False, + radius_factor: list[float] | float = [0.8, 1], + aperture_shift: list[float] | float = [0, 10], ) -> None: self.add_bkg = add_bkg self.add_shot = add_shot self.add_shift = add_shift self.add_ellipticity = add_ellipticity - self.add_ellipticity_to_label = add_ellipticity_to_label + self.add_ellipticity_to_label = add_ellipticity_to_label or [] self.add_salt_and_pepper = add_salt_and_pepper + self.add_gaussian_noise = add_gaussian_noise + self.gaussian_noise_mu = gaussian_noise_mu + self.gaussian_noise_std = gaussian_noise_std self.add_scale = add_scale self.add_blur = add_blur self.add_flipshift = add_flipshift + self.add_aperture = add_aperture self._bkg_weight_range = self._check_input(bkg_weight) if add_bkg else [0, 0] self._bkg_q_range = self._check_input(bkg_q) if add_bkg else [0, 0] + self.apply_background_to_label = apply_background_to_label + self.background_label_application = background_label_application self._e_dose_range = self._check_input(e_dose) if add_shot else [np.inf, np.inf] self._xshift_range = self._check_input(xshift) if add_shift else [0, 0] self._yshift_range = self._check_input(yshift) if add_shift else [0, 0] @@ -223,6 +264,10 @@ def set_params( self.free_rotation = free_rotation self._rotation_range = self._check_input(rotation_range) if add_flipshift else [0, 0] + # modification: aperture, by HAADF detector + self._radius_range = self._check_input(radius_factor) if add_aperture else [0, 0] + self._aptshift_range = self._check_input(aperture_shift) if add_aperture else [0, 0] + def generate_params(self) -> None: self.bkg_weight = self._uniform_or_zero(self._bkg_weight_range, self.add_bkg) self.bkg_q = self._uniform_or_zero(self._bkg_q_range, self.add_bkg) @@ -233,6 +278,8 @@ def generate_params(self) -> None: self.blur_sigma = self._uniform_or_zero(self._blur_range, self.add_blur) self.xshift = self._uniform_with_sign(self._xshift_range, self.add_shift) self.yshift = self._uniform_with_sign(self._yshift_range, self.add_shift) + self.xshiftapt = self._uniform_with_sign(self._aptshift_range, self.add_aperture) + self.yshiftapt = self._uniform_with_sign(self._aptshift_range, self.add_aperture) self._generate_ellipticity_params() self._generate_flipshift_params() @@ -241,6 +288,11 @@ def generate_params(self) -> None: else: self.scale_factor = 0 + if self.add_aperture: + self.radius_factor = self.rng.uniform(self._radius_range[0], self._radius_range[1]) + else: + self.radius_factor = 0 + def _uniform_or_zero(self, range_vals: list, enabled: bool) -> float: return self.rng.uniform(range_vals[0], range_vals[1]) if enabled else 0 @@ -310,6 +362,7 @@ def print_params(self, print_all: bool = False) -> None: f"Flip: H={self.flip_horizontal}, V={self.flip_vertical}, Rot: {self.rotation_angle:.1f}°", ), ("Salt & pepper", self.add_salt_and_pepper, f"Amount: {self.salt_and_pepper:.2e}"), + ("Gaussian noise", self.add_gaussian_noise, f"Mean: {self.gaussian_noise_mu:.2e}", f"Std: {self.gaussian_noise_std:.2e}"), ("Gaussian blur", self.add_blur, f"Sigma: {self.blur_sigma:.2f}"), ] @@ -360,17 +413,32 @@ def _augment_stack( if probe_stack is not None and probe_stack.shape[0] != batch_size: raise ValueError(f"Probe stack size {probe_stack.shape[0]} != DP size {batch_size}") - if label_stack is not None and label_stack.shape[0] != batch_size: + # Make exception for batch_size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + pass + elif label_stack is not None and label_stack.shape[0] != batch_size: raise ValueError(f"Label stack size {label_stack.shape[0]} != DP size {batch_size}") augmented_dps = [] augmented_labels = [] if label_stack is not None else None - for i in tqdm(range(batch_size), desc="augmenting"): + # Create iterator with condition for batch_size of 1 + iterator = tqdm(range(batch_size), desc="augmenting") if batch_size > 1 else range(batch_size) + for i in iterator: dp_single = dp_stack[i] probe_single = probe_stack[i] if probe_stack is not None else None - label_single = label_stack[i] if label_stack is not None else None - + + # Check for multichannel labels + if label_stack is not None: + if batch_size == 1 and len(label_stack.shape) == 3: + # Single image with multichannel labels + label_single = label_stack # Use entire multichannel label + else: + # If multiple images take labels for current iterant + label_single = label_stack[i] + else: + label_single = None + if label_single is not None: aug_dp, aug_label = self._augment_single(dp_single, probe_single, label_single) augmented_dps.append(aug_dp) @@ -382,13 +450,21 @@ def _augment_stack( if self.use_torch: stacked_dps = torch.stack(augmented_dps) # type: ignore if augmented_labels is not None: - stacked_labels = torch.stack(augmented_labels) # type: ignore + # Check for batch size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + stacked_labels = augmented_labels[0] # If multichannel just return, don't stack + else: + stacked_labels = torch.stack(augmented_labels) # type: ignore return stacked_dps, stacked_labels return stacked_dps else: stacked_dps = np.stack(augmented_dps) if augmented_labels is not None: - stacked_labels = np.stack(augmented_labels) + # Check for batch size of 1 + if batch_size == 1 and len(label_stack.shape) == 3: + stacked_labels = augmented_labels[0] # If multichannel just return, don't stack + else: + stacked_labels = np.stack(augmented_labels) return stacked_dps, stacked_labels return stacked_dps @@ -401,15 +477,35 @@ def _augment_single( if self.add_flipshift: result = self._apply_flipshift(result) if transformed_label is not None: - transformed_label = self._apply_flipshift(transformed_label) - if self.add_bkg: - result = self._apply_bkg(result, probe) + # Check if label is multichannel + if len(transformed_label.shape) == 3: + transformed_label = self._apply_flipshift_to_multichannel_label(label) + else: + transformed_label = self._apply_flipshift(label) + if self.add_ellipticity or self.add_shift or self.add_scale: result = self._apply_elastic(result) if transformed_label is not None: - transformed_label = self._apply_elastic_to_label(transformed_label) + # Check if label is multichannel + if len(transformed_label.shape) == 3: + transformed_label = self._apply_elastic_to_multichannel_label(transformed_label) + else: + transformed_label = self._apply_elastic_to_label(transformed_label) + + if self.add_bkg: + result = self._apply_bkg(result, probe) + # Apply background to specified label channels BEFORE elastic transforms + if transformed_label is not None and self.apply_background_to_label and self.background_label_application is not None: + if len(self.background_label_application) > 0: + if len(transformed_label.shape) == 3: + transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) + # modification: aperture + if self.add_aperture: # currently input can only be Tensor + result = self._apply_aperture(result) if self.add_shot: result = self._apply_shot(result) + if self.add_gaussian_noise: + result = self._apply_gaussian_noise(result) if self.add_blur: result = self._apply_blur(result) if self.add_salt_and_pepper: @@ -461,6 +557,34 @@ def _maybe_switch_to_torch( self.use_torch = True self._rng_to_device(self.device) + def _apply_flipshift_to_multichannel_label(self, label: ArrayLike) -> ArrayLike: + """Apply flipshift to multichannel label""" + if len(label.shape) == 3: # Multichannel (C, H, W) + transformed_channels = [] + for c in range(label.shape[0]): + transformed_channels.append(self._apply_flipshift(label[c])) + if self.use_torch: + return torch.stack(transformed_channels) + else: + return np.stack(transformed_channels) + else: + # Single channel label + return self._apply_flipshift(label) + + def _apply_elastic_to_multichannel_label(self, label: ArrayLike) -> ArrayLike: + """Apply elastic transforms to multichannel label""" + if len(label.shape) == 3: # Multichannel (C, H, W) + transformed_channels = [] + for c in range(label.shape[0]): + transformed_channels.append(self._apply_elastic_to_label(label[c])) + if self.use_torch: + return torch.stack(transformed_channels) + else: + return np.stack(transformed_channels) + else: + # Single channel label + return self._apply_elastic_to_label(label) + def _apply_shot(self, inputs: ArrayLike) -> ArrayLike: """Apply Poisson shot noise""" if self.use_torch: @@ -468,14 +592,35 @@ def _apply_shot(self, inputs: ArrayLike) -> ArrayLike: offset = image.min() image = (image - offset) / (image - offset).sum() return torch.poisson(image * self.e_dose, generator=self._rng_torch) + offset + # Below version preserves total intensity + # sum_int = (image - offset).sum() + # image = (image - offset) / sum_int + # return torch.poisson(image * self.e_dose, generator=self._rng_torch) * sum_int / self.e_dose + offset else: image = np.array(inputs) offset = image.min() image = (image - offset) / (image - offset).sum() return self.rng.poisson(image * self.e_dose) + offset + def _apply_aperture(self, inputs: "torch.Tensor") -> "torch.Tensor": + height, width = inputs.shape + device = inputs.device + y, x = torch.meshgrid( + torch.arange(height, dtype=torch.float32, device=device), + torch.arange(width, dtype=torch.float32, device=device), + indexing="ij", + ) + y_center, x_center = height // 2, width // 2 + y = y.clone() - y_center + self.yshiftapt + x = x.clone() - x_center + self.xshiftapt + r = torch.sqrt(x**2+y**2) + + aperture_mask = (r <= self.radius_factor*np.sqrt(y_center**2+x_center**2)).float() + output = inputs * aperture_mask + return output + def _apply_elastic(self, inputs: ArrayLike) -> ArrayLike: - """Apply elastic transformations (scaling, rotation, translation)""" + """Apply elastic transformations (scaling, translation)""" if self.use_torch: return self._apply_elastic_torch(inputs) # type: ignore else: @@ -502,7 +647,7 @@ def _apply_elastic_torch(self, inputs: "torch.Tensor") -> "torch.Tensor": if self.add_shift: x_new += self.xshift y_new += self.yshift - + x_norm = 2.0 * x_new / (width - 1) - 1.0 y_norm = 2.0 * y_new / (height - 1) - 1.0 grid = torch.stack([x_norm, y_norm], dim=-1).unsqueeze(0) @@ -551,6 +696,30 @@ def _apply_bkg(self, inputs: ArrayLike, probe: ArrayLike | None = None) -> Array inputs_float = af.as_type(inputs, torch.float32 if self.use_torch else np.float32) return inputs_float * (1 - self.bkg_weight) + CBEDbgConv.real * self.bkg_weight + def _apply_bkg_to_multichannel_label(self, label: ArrayLike, probe: ArrayLike | None = None) -> ArrayLike: + """Apply background to specified channels of multichannel label""" + if len(label.shape) != 3: + # Single channel label - shouldn't normally happen + return label + if len(self.background_label_application) == 0: + # No background specification + return label + + # Process each channel + result_channels = [] + for c in range(label.shape[0]): + if c < len(self.background_label_application) and self.background_label_application[c]: + # Apply background to this channel per background_label_application + result_channels.append(self._apply_bkg(label[c], probe)) + else: + # Keep channel as-is + result_channels.append(label[c]) + + if self.use_torch: + return torch.stack(result_channels) + else: + return np.stack(result_channels) + def _apply_blur(self, inputs: ArrayLike) -> ArrayLike: """Apply Gaussian blur""" if self.use_torch: @@ -611,14 +780,33 @@ def _get_salt_and_pepper( out[flipped & ~salted] = pepper_val return out + def _apply_gaussian_noise(self, inputs: ArrayLike) -> ArrayLike: + # Constant background applied to everything, scaled by electron dose + # Gaussian uniform to whole image, clipped to 0 + # Just camera noise, electronic noise + # Just some random scale value (std 5 e- for example, mean is std, then clip. Makes it so gaussian shifted so half isn't negative) + mean = self.gaussian_noise_mu * self.e_dose if self.add_shot else self.gaussian_noise_mu + std = self.gaussian_noise_std * self.e_dose if self.add_shot else self.gaussian_noise_std + + if self.use_torch: + image = inputs.clone() + noise = torch.clip(torch.normal(mean=mean, std=std, size=inputs.shape), min=0) + image += noise + return image + else: + image = np.array(inputs).copy() + noise = np.clip(np.random.normal(loc=mean, scale=std, size=inputs.shape), a_min=0, a_max=None) + image += noise + return image + def write_logs(self) -> None: if self.log_file is None: return with open(self.log_file, "a") as f: f.write( - f"{self.bkg_weight},{self.bkg_q},{self.e_dose},{self.xshift}," + f"{self.bkg_weight},{self.bkg_q},{self.background_label_application},{self.e_dose},{self.xshift}," f"{self.yshift},{self.exx},{self.eyy},{self.exy}," - f"{self.scale_factor},{self.flip_horizontal},{self.flip_vertical}," + f"{self.gaussian_noise_mu},{self.gaussian_noise_std},{self.scale_factor},{self.flip_horizontal},{self.flip_vertical}," f"{self.rotation_angle},{self.blur_sigma},{self.salt_and_pepper}," f"{self._rng_seed}\n" ) From c988ee31e13f14f032c98ea50d645135712f6672 Mon Sep 17 00:00:00 2001 From: Nicholas Marchese Date: Fri, 5 Dec 2025 11:23:51 -0800 Subject: [PATCH 2/5] Undid oversight of putting elastic before background, now agrees with original augmentor order. --- src/quantem/core/utils/augment_dp.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index 5bacc78b..4b79ae8f 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -264,7 +264,6 @@ def set_params( self.free_rotation = free_rotation self._rotation_range = self._check_input(rotation_range) if add_flipshift else [0, 0] - # modification: aperture, by HAADF detector self._radius_range = self._check_input(radius_factor) if add_aperture else [0, 0] self._aptshift_range = self._check_input(aperture_shift) if add_aperture else [0, 0] @@ -483,6 +482,13 @@ def _augment_single( else: transformed_label = self._apply_flipshift(label) + if self.add_bkg: + result = self._apply_bkg(result, probe) + if transformed_label is not None and self.apply_background_to_label and self.background_label_application is not None: + if len(self.background_label_application) > 0: + if len(transformed_label.shape) == 3: + transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) + if self.add_ellipticity or self.add_shift or self.add_scale: result = self._apply_elastic(result) if transformed_label is not None: @@ -492,14 +498,6 @@ def _augment_single( else: transformed_label = self._apply_elastic_to_label(transformed_label) - if self.add_bkg: - result = self._apply_bkg(result, probe) - # Apply background to specified label channels BEFORE elastic transforms - if transformed_label is not None and self.apply_background_to_label and self.background_label_application is not None: - if len(self.background_label_application) > 0: - if len(transformed_label.shape) == 3: - transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) - # modification: aperture if self.add_aperture: # currently input can only be Tensor result = self._apply_aperture(result) if self.add_shot: From 969123751d281e8d3279f775615893185d0d1906 Mon Sep 17 00:00:00 2001 From: Nicholas Marchese Date: Tue, 16 Dec 2025 08:25:03 -0800 Subject: [PATCH 3/5] Added changes based on PR comments. Removed background_label_application and implemented that functionality in apply_background_to_label, which is now either list or None to reflect this. If None, will not apply, otherwise will apply background to labels according to boolean list. Updated _apply_bkg to incorporate ZiXi's correction for shifts. Changed defaults of gaussian noise to be 0 for mu, 1e-5 for std, and updated docstring to reflect that this scaling is based on total electron dose. Changed application order to be elastic->background. Added aperature related attributes to docstring, and fixed typo in docstring. Added warning if single channel passed to multichannel functions. --- src/quantem/core/utils/augment_dp.py | 75 ++++++++++++++++------------ 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index 4b79ae8f..535f45a3 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -1,4 +1,5 @@ import os +import warnings from typing import TYPE_CHECKING, Union import numpy as np @@ -30,8 +31,7 @@ def __init__( add_bkg: bool = False, bkg_weight: list[float] | float = [0.001, 0.05], bkg_q: list[float] | float = [0.01, 0.1], - apply_background_to_label: bool = False, - background_label_application: list[bool] | None = None, + apply_background_to_label: list[bool] | None = None, add_shot: bool = False, e_dose: list[float] | float = [1e4, 1e7], add_shift: bool = False, @@ -43,8 +43,8 @@ def __init__( add_salt_and_pepper: bool = False, salt_and_pepper: list[float] | float = [0, 5e-4], add_gaussian_noise: bool = False, - gaussian_noise_mu: float = 1.0, - gaussian_noise_std: float = 0.5, + gaussian_noise_mu: float = 0.0, + gaussian_noise_std: float = 1e-5, add_scale: bool = False, scale_factor: list[float] | float = [0.9, 1.1], add_blur: bool = False, @@ -70,11 +70,9 @@ def __init__( Range for background weight (fraction of total intensity). bkg_q : list[float] | float, default=[0.01, 0.1] Range for plasmon scattering parameter q₀ in 1/(q² + q₀²) form factor. - apply_background_to_label: bool, defautl=False - Flag for whether background should be applied to labels, and which ones according to background_label_application - background_label_application: list[bool] - List of 1/0 for if background should be applied to label - + apply_background_to_label: list[bool] | None, default=None + Flag for whether background should be applied to labels, and which ones based on 1/0 list. + List of 1/0 for if background should be applied to label. None if no application. add_shot : bool, default=False Enable Poisson shot noise based on electron dose. e_dose : list[float] | float, default=[1e4, 1e7] @@ -102,10 +100,12 @@ def __init__( add_gaussian_noise : bool, default=False Enable gaussian noise. - gaussian_noise_mu : float, default=1.0 - Mean for gaussian noise distribution - gaussian_noise_std : float, defualt=0.5 - Standard deviation for gaussian noise distribution + gaussian_noise_mu : float, default=0.0 + Mean for gaussian noise distribution. Should be 0 for scientifically accurate representation. + Scaled by electron dose. So value of 0.1 represents mean = 10% of electron dose. + gaussian_noise_std : float, defualt=1e-5 + Standard deviation for gaussian noise distribution. + Scaled by electron dose. So value of 0.1 represents std. dev. = 10% of electron dose. add_scale : bool, default=False Enable uniform scaling of the diffraction pattern. @@ -132,6 +132,16 @@ def __init__( device : str, default="cpu" Device for computations ("cpu", "cuda", "cuda:0", etc.). + add_aperture : bool, default=False + Enable circular aperture mask to simulate objective aperture effects. + radius_factor : list[float] | float, default=[0.8, 1] + Range for aperture radius as fraction of maximum image radius (distance from + center to corner). Values < 1 create vignetted diffraction patterns. The mask + is centered at (height//2, width//2) + aperture_shift. + aperture_shift : list[float] | float, default=[0, 10] + Range for random shift of aperture center in pixels (applied with random sign + to both x and y). Simulates misalignment of the objective aperture. + Notes ----- - Augmentations are applied in order: flipshift → background → elastic → @@ -149,7 +159,6 @@ def __init__( bkg_weight, bkg_q, apply_background_to_label, - background_label_application, add_shot, e_dose, add_shift, @@ -194,7 +203,7 @@ def _init_log_file(self) -> None: if self.log_file is not None: with open(self.log_file, "a") as f: f.write( - "bkg_weight,bkg_q,background_label_application,e_dose,xshift,yshift,exx,eyy,exy," + "bkg_weight,bkg_q,apply_background_to_label,e_dose,xshift,yshift,exx,eyy,exy," "gaussian_noise_mu,gaussian_noise_std,scale_factor,flip_horizontal,flip_vertical," "rotation_angle,blur_sigma,salt_and_pepper,rng_seed\n" ) @@ -206,7 +215,6 @@ def set_params( bkg_weight: list[float] | float = [0.01, 0.1], bkg_q: list[float] | float = [0.01, 0.1], apply_background_to_label: list[bool] | None = None, - background_label_application: bool = False, add_shot: bool = False, e_dose: list[float] | float = [1e5, 1e10], add_shift: bool = False, @@ -218,8 +226,8 @@ def set_params( add_salt_and_pepper: bool = False, salt_and_pepper: list[float] | float = [0, 1e-3], add_gaussian_noise: bool = False, - gaussian_noise_mu: list[float] | float = 1, - gaussian_noise_std: list[float] | float = 0.5, + gaussian_noise_mu: list[float] | float = 0.0, + gaussian_noise_std: list[float] | float = 1e-5, add_scale: bool = False, scale_factor: list[float] | float = [0.9, 1.1], add_blur: bool = False, @@ -248,7 +256,6 @@ def set_params( self._bkg_weight_range = self._check_input(bkg_weight) if add_bkg else [0, 0] self._bkg_q_range = self._check_input(bkg_q) if add_bkg else [0, 0] self.apply_background_to_label = apply_background_to_label - self.background_label_application = background_label_application self._e_dose_range = self._check_input(e_dose) if add_shot else [np.inf, np.inf] self._xshift_range = self._check_input(xshift) if add_shift else [0, 0] self._yshift_range = self._check_input(yshift) if add_shift else [0, 0] @@ -482,13 +489,6 @@ def _augment_single( else: transformed_label = self._apply_flipshift(label) - if self.add_bkg: - result = self._apply_bkg(result, probe) - if transformed_label is not None and self.apply_background_to_label and self.background_label_application is not None: - if len(self.background_label_application) > 0: - if len(transformed_label.shape) == 3: - transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) - if self.add_ellipticity or self.add_shift or self.add_scale: result = self._apply_elastic(result) if transformed_label is not None: @@ -498,6 +498,13 @@ def _augment_single( else: transformed_label = self._apply_elastic_to_label(transformed_label) + if self.add_bkg: + result = self._apply_bkg(result, probe) + if transformed_label is not None and self.apply_background_to_label is not None: + if len(self.apply_background_to_label) > 0: + if len(transformed_label.shape) == 3: + transformed_label = self._apply_bkg_to_multichannel_label(transformed_label, probe) + if self.add_aperture: # currently input can only be Tensor result = self._apply_aperture(result) if self.add_shot: @@ -683,7 +690,10 @@ def _apply_bkg(self, inputs: ArrayLike, probe: ArrayLike | None = None) -> Array qx = af.view(af.sort(af.fftfreq(height, 0.1, like=inputs), axis=0), (-1, 1)) qy = af.view(af.sort(af.fftfreq(width, 0.1, like=inputs), axis=0), (1, -1)) - CBEDbg = 1.0 / (qx**2 + qy**2 + self.bkg_q**2) # Plasmon form factor: 1/(q² + q₀²) + qxc = self.yshift / (height*0.1) + qyc = self.xshift / (width*0.1) + + CBEDbg = 1.0 / ((qx+qxc)**2 + (qy+qyc)**2 + self.bkg_q**2) # Plasmon form factor: 1/(q² + q₀²) CBEDbg = CBEDbg.squeeze() / af.sum(CBEDbg.squeeze()) if probe is not None: @@ -697,17 +707,18 @@ def _apply_bkg(self, inputs: ArrayLike, probe: ArrayLike | None = None) -> Array def _apply_bkg_to_multichannel_label(self, label: ArrayLike, probe: ArrayLike | None = None) -> ArrayLike: """Apply background to specified channels of multichannel label""" if len(label.shape) != 3: - # Single channel label - shouldn't normally happen + warnings.warn(f"Expected shape (C,H,W), got {label.shape}. Returning unchanged.", stacklevel=2) return label + if len(self.background_label_application) == 0: - # No background specification + warnings.warn("background_label_application is empty. Returning unchanged.", stacklevel=2) return label # Process each channel result_channels = [] for c in range(label.shape[0]): - if c < len(self.background_label_application) and self.background_label_application[c]: - # Apply background to this channel per background_label_application + if c < len(self.apply_background_to_label) and self.apply_background_to_label[c]: + # Apply background to this channel per apply_background_to_label result_channels.append(self._apply_bkg(label[c], probe)) else: # Keep channel as-is @@ -802,7 +813,7 @@ def write_logs(self) -> None: return with open(self.log_file, "a") as f: f.write( - f"{self.bkg_weight},{self.bkg_q},{self.background_label_application},{self.e_dose},{self.xshift}," + f"{self.bkg_weight},{self.bkg_q},{self.apply_background_to_label},{self.e_dose},{self.xshift}," f"{self.yshift},{self.exx},{self.eyy},{self.exy}," f"{self.gaussian_noise_mu},{self.gaussian_noise_std},{self.scale_factor},{self.flip_horizontal},{self.flip_vertical}," f"{self.rotation_angle},{self.blur_sigma},{self.salt_and_pepper}," From 1e7de240cf32ccf94e98627b550a966f43582b03 Mon Sep 17 00:00:00 2001 From: Nicholas Marchese Date: Tue, 16 Dec 2025 08:26:30 -0800 Subject: [PATCH 4/5] Changed docstring to reflect new order of application (elastic -> background). --- src/quantem/core/utils/augment_dp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index 535f45a3..756ff123 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -144,7 +144,7 @@ def __init__( Notes ----- - - Augmentations are applied in order: flipshift → background → elastic → + - Augmentations are applied in order: flipshift → elastic → background → shot noise → blur → salt & pepper - For labels, only geometric transforms (flipshift, elastic) are applied - Ellipticity creates anisotropic scaling via exx, eyy, exy parameters From 17f6d0ccaf44dcd3a52ece8e0496c547aae2e790 Mon Sep 17 00:00:00 2001 From: Nicholas Marchese Date: Mon, 9 Feb 2026 12:48:58 -0800 Subject: [PATCH 5/5] Removed set_params method and incorporated that functionality into the init --- src/quantem/core/utils/augment_dp.py | 112 +++++++-------------------- 1 file changed, 27 insertions(+), 85 deletions(-) diff --git a/src/quantem/core/utils/augment_dp.py b/src/quantem/core/utils/augment_dp.py index 756ff123..44a92022 100644 --- a/src/quantem/core/utils/augment_dp.py +++ b/src/quantem/core/utils/augment_dp.py @@ -154,91 +154,7 @@ def __init__( self._setup_device(device) self.log_file = log_file - self.set_params( - add_bkg, - bkg_weight, - bkg_q, - apply_background_to_label, - add_shot, - e_dose, - add_shift, - xshift, - yshift, - add_ellipticity, - ellipticity_scale, - add_ellipticity_to_label, - add_salt_and_pepper, - salt_and_pepper, - add_gaussian_noise, - gaussian_noise_mu, - gaussian_noise_std, - add_scale, - scale_factor, - add_blur, - blur_sigma, - add_flipshift, - free_rotation, - rotation_range, - add_aperture, - radius_factor, - aperture_shift, - ) - self.generate_params() - self._init_log_file() - - def _setup_device(self, device: str) -> None: - if device == "gpu" or device.startswith("cuda"): - if not config.get("has_torch"): - raise RuntimeError("torch required for GPU operations but not available") - self.device = device if device.startswith("cuda") else "cuda" - self.use_torch = True - else: - self.device = "cpu" - self.use_torch = False - - if hasattr(self, "_rng_seed") and self._rng_seed is not None: - self._rng_to_device(self.device) - - def _init_log_file(self) -> None: - if self.log_file is not None: - with open(self.log_file, "a") as f: - f.write( - "bkg_weight,bkg_q,apply_background_to_label,e_dose,xshift,yshift,exx,eyy,exy," - "gaussian_noise_mu,gaussian_noise_std,scale_factor,flip_horizontal,flip_vertical," - "rotation_angle,blur_sigma,salt_and_pepper,rng_seed\n" - ) - - - def set_params( - self, - add_bkg: bool = False, - bkg_weight: list[float] | float = [0.01, 0.1], - bkg_q: list[float] | float = [0.01, 0.1], - apply_background_to_label: list[bool] | None = None, - add_shot: bool = False, - e_dose: list[float] | float = [1e5, 1e10], - add_shift: bool = False, - xshift: list[float] | float = [0, 10], - yshift: list[float] | float = [0, 10], - add_ellipticity: bool = False, - ellipticity_scale: list[float] | float = [0, 0.15], - add_ellipticity_to_label: bool = True, - add_salt_and_pepper: bool = False, - salt_and_pepper: list[float] | float = [0, 1e-3], - add_gaussian_noise: bool = False, - gaussian_noise_mu: list[float] | float = 0.0, - gaussian_noise_std: list[float] | float = 1e-5, - add_scale: bool = False, - scale_factor: list[float] | float = [0.9, 1.1], - add_blur: bool = False, - blur_sigma: list[float] | float = [0.0, 1.5], - add_flipshift: bool = False, - free_rotation: bool = False, - rotation_range: list[float] | float = [-180, 180], - add_aperture: bool = False, - radius_factor: list[float] | float = [0.8, 1], - aperture_shift: list[float] | float = [0, 10], - ) -> None: + # Setting parameters self.add_bkg = add_bkg self.add_shot = add_shot self.add_shift = add_shift @@ -274,6 +190,32 @@ def set_params( self._radius_range = self._check_input(radius_factor) if add_aperture else [0, 0] self._aptshift_range = self._check_input(aperture_shift) if add_aperture else [0, 0] + # Generate parameters from set parameters + self.generate_params() + self._init_log_file() + + def _setup_device(self, device: str) -> None: + if device == "gpu" or device.startswith("cuda"): + if not config.get("has_torch"): + raise RuntimeError("torch required for GPU operations but not available") + self.device = device if device.startswith("cuda") else "cuda" + self.use_torch = True + else: + self.device = "cpu" + self.use_torch = False + + if hasattr(self, "_rng_seed") and self._rng_seed is not None: + self._rng_to_device(self.device) + + def _init_log_file(self) -> None: + if self.log_file is not None: + with open(self.log_file, "a") as f: + f.write( + "bkg_weight,bkg_q,apply_background_to_label,e_dose,xshift,yshift,exx,eyy,exy," + "gaussian_noise_mu,gaussian_noise_std,scale_factor,flip_horizontal,flip_vertical," + "rotation_angle,blur_sigma,salt_and_pepper,rng_seed\n" + ) + def generate_params(self) -> None: self.bkg_weight = self._uniform_or_zero(self._bkg_weight_range, self.add_bkg) self.bkg_q = self._uniform_or_zero(self._bkg_q_range, self.add_bkg)