diff --git a/configs/tests/vortex/standard.yaml b/configs/tests/vortex/standard.yaml new file mode 100644 index 00000000..68fb32b5 --- /dev/null +++ b/configs/tests/vortex/standard.yaml @@ -0,0 +1,60 @@ +DESCRIPTION: + BRIEF: Testing VORTEX config + EXP_NAME: test/vortex/standard + PROJECT_NAME: vortex + TAGS: ("mridata_knee_2019", "unet", "16x", "vortex") +OUTPUT_DIR: "results://tests/vortex/standard" +DATASETS: + TRAIN: ("mridata_knee_2019_train",) + VAL: ("mridata_knee_2019_val",) + TEST: ("mridata_knee_2019_test",) +AUG_TRAIN: + UNDERSAMPLE: + NAME: "PoissonDiskMaskFunc" + ACCELERATIONS: (16,) + CALIBRATION_SIZE: 20 +AUG_TEST: + UNDERSAMPLE: + ACCELERATIONS: (16,) +DATALOADER: + NUM_WORKERS: 0 + SUBSAMPLE_TRAIN: + NUM_TOTAL: 6 # Number of total scans to use + NUM_UNDERSAMPLED: 5 # Number of undersampled scans to use + SAMPLER_TRAIN: "AlternatingSampler" # Choices: ["", "AlternatingSampler"] + ALT_SAMPLER: + PERIOD_SUPERVISED: 1 + PERIOD_UNSUPERVISED: 1 +MODEL: + META_ARCHITECTURE: "VortexModel" + UNET: + CHANNELS: 32 + DROPOUT: 0.0 + IN_CHANNELS: 2 + NUM_POOL_LAYERS: 4 + OUT_CHANNELS: 2 + RECON_LOSS: + NAME: "l1" + RENORMALIZE_DATA: False + A2R: + META_ARCHITECTURE: "UnetModel" + CONSISTENCY: + AUG: + MRI_RECON: + AUG_SENSITIVITY_MAPS: true + SCHEDULER_P: + IGNORE: false + TRANSFORMS: () # TODO: This must be filled in +SOLVER: + TRAIN_BATCH_SIZE: 2 + TEST_BATCH_SIZE: 2 + CHECKPOINT_PERIOD: 20 + MAX_ITER: 60 +TEST: + EVAL_PERIOD: 40 + VAL_METRICS: + RECON: ("psnr", "psnr_scan", "psnr_mag", "psnr_mag_scan", "nrmse", "nrmse_scan", "nrmse_mag", "nrmse_mag_scan", "ssim (Wang)", "ssim (Wang)_scan") +TIME_SCALE: "iter" +SEED: 1000 +VIS_PERIOD: 10 +VERSION: 1 \ No newline at end of file diff --git a/configs/tests/vortex/vortex_multishot_motion.yaml b/configs/tests/vortex/vortex_multishot_motion.yaml new file mode 100644 index 00000000..58b82a76 --- /dev/null +++ b/configs/tests/vortex/vortex_multishot_motion.yaml @@ -0,0 +1,26 @@ +# Test +_BASE_: "standard.yaml" +DESCRIPTION: + BRIEF: Testing VORTEX multi-shot motion config + EXP_NAME: test/vortex/multi-shot-motion + PROJECT_NAME: vortex + TAGS: ("mridata_knee_2019", "unet", "16x", "vortex") +OUTPUT_DIR: "results://tests/vortex/multi-shot-motion" +MODEL: + CONSISTENCY: + AUG: + MRI_RECON: + AUG_SENSITIVITY_MAPS: false + SCHEDULER_P: + IGNORE: false + TRANSFORMS: + - name: RandomMRIMultiShotMotion + p: 1.0 + nshots: 8 + trajectory: blocked + tfms_or_gens: + - name: RandomAffine + p: + angle: 1.0 + pad_like: MRAugment + angle: 5 \ No newline at end of file diff --git a/configs/tests/vortex/vortex_ssdu.yaml b/configs/tests/vortex/vortex_ssdu.yaml new file mode 100644 index 00000000..c6ee39c1 --- /dev/null +++ b/configs/tests/vortex/vortex_ssdu.yaml @@ -0,0 +1,48 @@ +# Self-Supervised VORTEX (i.e. VORTEX+SSDU) +_BASE_: "standard.yaml" +MODEL: + META_ARCHITECTURE: "VortexModel" + A2R: + META_ARCHITECTURE: "SSDUModel" + SSDU: + META_ARCHITECTURE: "UnetModel" + MASKER: + PARAMS: + p: 1.0 + rhos: 0.4 # Based on hyperparameter search + kind: "uniform" # Choices: ["uniform", "gaussian"] + per_example: True + std_scale: 4 # Default value in SSDU code + UNET: + IN_CHANNELS: 2 + OUT_CHANNELS: 2 + CHANNELS: 32 + NUM_POOL_LAYERS: 4 + DROPOUT: 0. + CONSISTENCY: + AUG: + MRI_RECON: + AUG_SENSITIVITY_MAPS: true + SCHEDULER_P: + IGNORE: false + TRANSFORMS: + # Motion Transform + - name: RandomMRIMotion + p: 1.0 + std_devs: + - 0.2 + - 0.5 + RECON_LOSS: + NAME: "k_l1" + RENORMALIZE_DATA: False +DATALOADER: + NUM_WORKERS: 0 # for debugging purposes + SAMPLER_TRAIN: "" # Random sampler in self-supervised settings + SUBSAMPLE_TRAIN: + NUM_TOTAL: 14 + NUM_UNDERSAMPLED: 14 # undersample all scans - N2R + SSDU should still train. +SOLVER: + TRAIN_BATCH_SIZE: 1 +TIME_SCALE: "iter" +OUTPUT_DIR: "results://tests/vortex/vortex_ssdu" +VERSION: 1 \ No newline at end of file diff --git a/meddlr/config/config.py b/meddlr/config/config.py index a2409d0d..c4114016 100644 --- a/meddlr/config/config.py +++ b/meddlr/config/config.py @@ -413,7 +413,7 @@ def _format_str(val_str: str, *, cfg: CfgNode, unroll: bool): assert len(start) == len(end), f"Could not determine formatting string: {val_str}" if len(start) == 0: - return val_str + return eval(val_str) cfg_keys_to_search = [val_str[s + 1 : e] for s, e in zip(start, end)] values = [cfg.get_recursive(v) for v in cfg_keys_to_search] diff --git a/meddlr/config/defaults.py b/meddlr/config/defaults.py index 4e975a75..d4bf9cdb 100644 --- a/meddlr/config/defaults.py +++ b/meddlr/config/defaults.py @@ -226,6 +226,7 @@ _C.MODEL.A2R = CN() _C.MODEL.A2R.META_ARCHITECTURE = "GeneralizedUnrolledCNN" _C.MODEL.A2R.USE_SUPERVISED_CONSISTENCY = False +_C.MODEL.A2R.EDGE_DC = False # ----------------------------------------------------------------------------- # SSDU model diff --git a/meddlr/data/build.py b/meddlr/data/build.py index 5fab414d..0b79d93b 100644 --- a/meddlr/data/build.py +++ b/meddlr/data/build.py @@ -304,12 +304,17 @@ def build_recon_val_loader( "drop_last": False, } + num_workers = cfg.DATALOADER.NUM_WORKERS + prefetch_factor = cfg.DATALOADER.PREFETCH_FACTOR + if env.pt_version() >= "2.0" and num_workers == 0: + prefetch_factor = None + val_loader = DataLoader( dataset=val_data, num_workers=cfg.DATALOADER.NUM_WORKERS, pin_memory=True, collate_fn=default_collate, - prefetch_factor=cfg.DATALOADER.PREFETCH_FACTOR, + prefetch_factor=prefetch_factor, **dl_kwargs, ) return val_loader diff --git a/meddlr/data/data_utils.py b/meddlr/data/data_utils.py index c1df9266..79a6259a 100644 --- a/meddlr/data/data_utils.py +++ b/meddlr/data/data_utils.py @@ -286,9 +286,14 @@ def collect_mask( If sub-indices are collated, they will be summed. out_channel_first (bool, optional): Reorders dimensions of output mask to Cx(...) """ + # TODO: Make this more efficient. if isinstance(index, int): index = (index,) + is_ndarray = isinstance(mask, np.ndarray) + if is_ndarray: + mask = torch.as_tensor(mask) + if not any(isinstance(idx, Sequence) for idx in index): mask = mask[..., index] else: @@ -296,14 +301,16 @@ def collect_mask( for idx in index: c_seg = mask[..., idx] if isinstance(idx, Sequence): - c_seg = np.sum(c_seg, axis=-1) + c_seg = c_seg.sum(dim=-1) o_seg.append(c_seg) - mask = np.stack(o_seg, axis=-1) + mask = torch.stack(o_seg, axis=-1) if out_channel_first: last_idx = len(mask.shape) - 1 - mask = np.transpose(mask, (last_idx,) + tuple(range(0, last_idx))) + mask = torch.permute(mask, (last_idx,) + tuple(range(0, last_idx))) + if is_ndarray: + mask = mask.numpy() return mask diff --git a/meddlr/data/transforms/transform.py b/meddlr/data/transforms/transform.py index 24db1bcc..3d8a9846 100644 --- a/meddlr/data/transforms/transform.py +++ b/meddlr/data/transforms/transform.py @@ -390,6 +390,7 @@ def __call__( masked_kspace = masked_kspace.squeeze(0) maps = maps.squeeze(0) target = target.squeeze(0) + mask = mask.squeeze(0) out = { "kspace": masked_kspace, @@ -399,7 +400,7 @@ def __call__( "std": std, "norm": norm, "edge_mask": edge_mask.squeeze(0), - "mask": self._get_mask(masked_kspace), + "mask": mask, } if postprocessing_mask is not None: out["postprocessing_mask"] = postprocessing_mask.squeeze(0) diff --git a/meddlr/engine/trainer.py b/meddlr/engine/trainer.py index acd641d4..d0d43544 100644 --- a/meddlr/engine/trainer.py +++ b/meddlr/engine/trainer.py @@ -382,6 +382,7 @@ def build_loss_computer(cls, cfg): or cfg.MODEL.META_ARCHITECTURE == "M2RModel" or cfg.MODEL.META_ARCHITECTURE == "NM2RModel" or cfg.MODEL.META_ARCHITECTURE == "A2RModel" + or cfg.MODEL.META_ARCHITECTURE == "VortexModel" else "BasicLossComputer" ) return build_loss_computer(cfg, loss_computer) diff --git a/meddlr/evaluation/evaluator.py b/meddlr/evaluation/evaluator.py index 99d4976b..e4f5db38 100644 --- a/meddlr/evaluation/evaluator.py +++ b/meddlr/evaluation/evaluator.py @@ -79,6 +79,9 @@ def values(self) -> Sequence[DatasetEvaluator]: else: return self._evaluators + def pop(self, key): + return self._evaluators.pop(key) + def reset(self): for evaluator in self.values(): evaluator.reset() diff --git a/meddlr/evaluation/recon_evaluation.py b/meddlr/evaluation/recon_evaluation.py index 7140d14b..87932076 100644 --- a/meddlr/evaluation/recon_evaluation.py +++ b/meddlr/evaluation/recon_evaluation.py @@ -342,12 +342,14 @@ def evaluate(self): os.path.join(self._save_scan_dir, f"{scan_id}.h5"), ) - if self._group_by_scan: - pred_vals = self._group_results_by_scan() + if self._compute_metrics(): + if self._group_by_scan: + pred_vals = self._group_results_by_scan() + else: + pred_vals = self.slice_metrics.to_dict() + pred_vals.update(self.scan_metrics.to_dict()) else: - pred_vals = self.slice_metrics.to_dict() - pred_vals.update(self.scan_metrics.to_dict()) - + pred_vals = {} self._results = pred_vals if not self._is_flushing: @@ -374,8 +376,11 @@ def _group_results_by_scan(self): return pred_vals + def _compute_metrics(self) -> bool: + return len(self._metric_names) > 0 + def log_summary(self): - if not comm.is_main_process(): + if not comm.is_main_process() or not self._compute_metrics(): return output_dir = self._output_dir @@ -431,6 +436,8 @@ def evaluate_prediction( ex_id: Union[str, Sequence[str]], is_batch=False, ): + if len(metrics) == 0: + return {} output, target = prediction["pred"], prediction["target"] if not is_batch: output, target = output.unsqueeze(0), target.unsqueeze(0) diff --git a/meddlr/evaluation/scan_evaluator.py b/meddlr/evaluation/scan_evaluator.py index 2d2e4c54..6392927b 100644 --- a/meddlr/evaluation/scan_evaluator.py +++ b/meddlr/evaluation/scan_evaluator.py @@ -64,7 +64,10 @@ def flush(self, enter_prediction_scope: bool = True, skip_last_scan: bool = True Args: enter_prediction_scope (bool, optional): If ``True``, enter the - prediction scope. + prediction scope. You may want to set this to ``False`` if an + external object is managing the prediction scope. + This allows the external object to access predictions before + they are released from memory. skip_last_scan (bool, optional): If ``True``, does not flush most recent scan. This avoids prematurely computing metrics before all slices of the scan are available. diff --git a/meddlr/evaluation/seg_evaluation.py b/meddlr/evaluation/seg_evaluation.py index 3f12f2a9..c090d246 100644 --- a/meddlr/evaluation/seg_evaluation.py +++ b/meddlr/evaluation/seg_evaluation.py @@ -4,6 +4,7 @@ import os import time import uuid +from typing import Any, Dict import torch from tqdm import tqdm @@ -136,8 +137,16 @@ def reset(self): self.slice_metrics.eval() self.scan_metrics.eval() - def structure_scans(self, verbose=True): - """Structure scans into volumes to be used to evaluation.""" + def structure_scans(self, verbose=True) -> Dict[str, Dict[str, Any]]: + """Structure scans into volumes to be used to evaluation. + + Returns: + dict: a dictionary mapping scan_id to a dictionary with keys: + - "pred": the predicted segmentation mask - shape: (C, D, H, W) + - "target": the ground truth segmentation mask - shape: (C, D, H, W) + - "voxel_spacing": the voxel spacing of the scan (D, H, W, C) + - "affine": Affine matrix for MedicalVolume (D, H, W, C) + """ out = structure_scans(self._predictions, verbose=verbose, dims={1: "slice_id"}) return out diff --git a/meddlr/evaluation/testing.py b/meddlr/evaluation/testing.py index 8bdbcca8..63e4a7b3 100644 --- a/meddlr/evaluation/testing.py +++ b/meddlr/evaluation/testing.py @@ -236,9 +236,13 @@ def check_consistency(state_dict: Dict[str, Any], model: torch.nn.Module): model (nn.Module): A Pytorch model. """ _state_dict = model.state_dict() + mismatched_keys = [] for k in state_dict: assert k in _state_dict, f"{k} not in model state_dict: {_state_dict.keys()}" - assert torch.equal(state_dict[k], _state_dict[k]), f"Mismatch values: {k}" + if not torch.equal(state_dict[k], _state_dict[k]): + mismatched_keys.append(k) + if len(mismatched_keys): + raise ValueError("Mismatch keys:\n\t{}".format("\t\n".join(mismatched_keys))) def _metrics_from_x(metrics_file, criterion): diff --git a/meddlr/modeling/loss_computer.py b/meddlr/modeling/loss_computer.py index a2669c55..79f23cc0 100644 --- a/meddlr/modeling/loss_computer.py +++ b/meddlr/modeling/loss_computer.py @@ -219,6 +219,9 @@ def __init__(self, cfg): self.use_consistency = cfg.MODEL.CONSISTENCY.USE_CONSISTENCY self.num_latent_layers = cfg.MODEL.CONSISTENCY.NUM_LATENT_LAYERS self.latent_keys = ["E4", "E3", "D3", "E2", "D2", "E1", "D1"] + + # Apply data consistency loss between input and augmented recon. + self.use_dc = False # self.use_robust = cfg.MODEL.LOSS.USE_ROBUST # self.beta = cfg.MODEL.LOSS.BETA # self.robust_step_size = cfg.MODEL.LOSS.ROBUST_STEP_SIZE @@ -285,6 +288,17 @@ def __call__(self, input, output): } if output_consistency is not None and self.use_consistency: loss += self.consistency_weight * metrics_consistency["cons_loss"] + if output_consistency is not None and self.use_dc: + pred = output["dc"]["pred"].contiguous() + target = output["dc"]["target"].contiguous() + abs_error = cplx.abs(pred - target) + kl1_norm = torch.sum(abs_error) / torch.sum(cplx.abs(target)) + kl2_norm = torch.sqrt(torch.sum(abs_error**2)) / torch.sqrt( + torch.sum(cplx.abs(target) ** 2) + ) # noqa: E501 + dc_loss = 0.5 * kl1_norm + 0.5 * kl2_norm + # dc_loss = torch.mean(torch.abs(pred - target)) + loss += (self.consistency_weight / 10) * dc_loss if self.use_latent: num_losses = self.num_latent_layers * 2 - 1 @@ -314,6 +328,8 @@ def __call__(self, input, output): if self.use_latent: for i in range(num_losses): metrics.update(all_metrics_latent[i]) + if self.use_dc: + metrics["dc_loss"] = dc_loss metrics["loss"] = loss return metrics diff --git a/meddlr/modeling/meta_arch/unrolled.py b/meddlr/modeling/meta_arch/unrolled.py index e8dca82b..6897de49 100644 --- a/meddlr/modeling/meta_arch/unrolled.py +++ b/meddlr/modeling/meta_arch/unrolled.py @@ -104,7 +104,12 @@ def __init__( self._dc_first = order[0] == "dc" def visualize_training( - self, kspace: torch.Tensor, zfs: torch.Tensor, targets: torch.Tensor, preds: torch.Tensor + self, + kspace: torch.Tensor, + zfs: torch.Tensor, + targets: torch.Tensor, + preds: torch.Tensor, + dc_mask: torch.Tensor = None, ): """Visualize kspace data and reconstructions. @@ -139,6 +144,9 @@ def visualize_training( "errors": cplx.abs(preds - targets), "masks": cplx.get_mask(kspace), } + if dc_mask is not None: + # Take the mask for the first coil, it should be the same for all coils. + imgs_to_write["dc_mask"] = dc_mask[0:1, ..., 0].cpu() for name, data in imgs_to_write.items(): data = data.squeeze(-1).unsqueeze(1) @@ -170,7 +178,7 @@ def reg(self, *, image: torch.Tensor, model: nn.Module, dims: torch.Size): image = torch.view_as_real(image) # prox update - image = image.reshape(dims[0:3] + (self.num_emaps * 2,)).permute(0, 3, 1, 2) + image = image.reshape(dims[0:3] + (self.num_emaps * 2,)).permute(0, 3, 1, 2).contiguous() if hasattr(model, "base_forward") and callable(model.base_forward): image = model.base_forward(image) else: @@ -179,9 +187,7 @@ def reg(self, *, image: torch.Tensor, model: nn.Module, dims: torch.Size): # This doesn't work when padding is not the same. # i.e. when the output is a different shape than the input. # However, this should not ever happen. - image = image.permute(0, 2, 3, 1).reshape(dims[0:3] + (self.num_emaps, 2)) - if not image.is_contiguous(): - image = image.contiguous() + image = image.permute(0, 2, 3, 1).reshape(dims[0:3] + (self.num_emaps, 2)).contiguous() if use_cplx: image = torch.view_as_complex(image) return image @@ -302,7 +308,7 @@ def forward(self, inputs: Dict[str, Any], return_pp: bool = False, vis_training: if self.training and (vis_training or self.vis_period > 0): storage = get_event_storage() if vis_training or storage.iter % self.vis_period == 0: - self.visualize_training(kspace, zf_image, target, image) + self.visualize_training(kspace, zf_image, target, image, dc_mask=mask) output_dict["zf_image"] = zf_image diff --git a/meddlr/modeling/meta_arch/vortex.py b/meddlr/modeling/meta_arch/vortex.py index 3a703c95..dcf3e236 100644 --- a/meddlr/modeling/meta_arch/vortex.py +++ b/meddlr/modeling/meta_arch/vortex.py @@ -1,15 +1,19 @@ import logging +from typing import Dict, Optional import torch import torchvision.utils as tv_utils from torch import nn +import meddlr.ops as oF from meddlr.config.config import configurable +from meddlr.forward.mri import SenseModel +from meddlr.modeling.meta_arch import SSDUModel from meddlr.modeling.meta_arch.build import META_ARCH_REGISTRY, build_model from meddlr.ops import complex as cplx from meddlr.transforms.builtin.mri import MRIReconAugmentor from meddlr.utils.events import get_event_storage -from meddlr.utils.general import move_to_device +from meddlr.utils.general import flatten_dict, move_to_device @META_ARCH_REGISTRY.register() @@ -43,6 +47,7 @@ def __init__( model: nn.Module, augmentor: MRIReconAugmentor, use_supervised_consistency: bool = False, + edge_dc: bool = False, vis_period: int = -1, ): """ @@ -61,18 +66,33 @@ def __init__( self.augmentor = augmentor self.use_base_grad = False # Keep gradient for base images in transform. self.use_supervised_consistency = use_supervised_consistency + self.edge_dc = edge_dc + self._multicoil_image = True # Visualization done by this model - if hasattr(model, "vis_period") and vis_period > 0: + if ( + not isinstance(self.model, SSDUModel) + and hasattr(model, "vis_period") + and vis_period > 0 + ): self.model.vis_period = -1 self.vis_period = vis_period - def augment(self, inputs, pred_base): + def augment(self, inputs: Dict[str, torch.Tensor], pred_base: torch.Tensor): + """Apply augmentations to inputs and base image reconstruction. + + Args: + inputs (Dict[str, torch.Tensor]): The inputs to augment. + pred_base (torch.Tensor): The reconstruction of the base (i.e. non-augmented) image. + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: The augmented inputs and pseudo-targets. + """ inputs = move_to_device(inputs, device="cuda") pred_base = move_to_device(pred_base, device="cuda") kspace, maps = inputs["kspace"].clone(), inputs["maps"].clone() - out, _, _ = self.augmentor(kspace, maps, pred_base, mask=True) + out, _, _ = self.augmentor(kspace=kspace, maps=maps, target=pred_base, mask=True) inputs = { k: v.clone() if isinstance(v, torch.Tensor) else v @@ -84,7 +104,16 @@ def augment(self, inputs, pred_base): aug_pred_base = out["target"] return inputs, aug_pred_base - def visualize_aug_training(self, kspace, kspace_aug, preds, preds_base, target=None): + def log_augmentor_params(self): + scheduler_params = self.augmentor.get_tfm_gen_params(scalars_only=True) + if len(scheduler_params): + storage = get_event_storage() + scheduler_params = flatten_dict({"scheduler": scheduler_params}) + storage.put_scalars(**scheduler_params) + + def visualize_aug_training( + self, kspace, kspace_aug, preds, preds_base, target=None, dc_mask=None + ): """Visualize training of augmented data. Args: @@ -124,12 +153,51 @@ def visualize_aug_training(self, kspace, kspace_aug, preds, preds_base, target=N "masks": cplx.get_mask(kspace), "kspace": cplx.abs(all_kspace), } + if dc_mask is not None: + # Take the mask for the first coil, it should be the same for all coils. + imgs_to_write["dc_mask"] = dc_mask[0:1, ..., 0].cpu() for name, data in imgs_to_write.items(): data = data.squeeze(-1).unsqueeze(1) data = tv_utils.make_grid(data, nrow=1, padding=1, normalize=True, scale_each=True) storage.put_image("train_aug/{}".format(name), data.numpy(), data_format="CHW") + def _aggregate_consistency_inputs( + self, + inputs_supervised: Optional[Dict[str, torch.Tensor]] = None, + inputs_unsupervised: Optional[Dict[str, torch.Tensor]] = None, + ) -> Dict[str, torch.Tensor]: + """Aggregates consistency inputs into a single dictionary. + + Args: + inputs_supervised: A dict of inputs, their metadata, and their ground truth references. + inputs_unsupervised: A dict of inputs and their metadata. + + Returns: + Dict[str, Dict[str, Tensor]]: A dictionary of base inputs and augmented inputs: + - 'base': Inputs to be used to generate the pseudo-label (i.e. target) + for consistency optimization. + - 'aug': Augmented inputs to use for consistency training. + """ + inputs_consistency = [] + if inputs_unsupervised is not None: + inputs_consistency.append(inputs_unsupervised) + if self.use_supervised_consistency and inputs_supervised is not None: + inputs_consistency.append({k: v for k, v in inputs_supervised.items() if k != "target"}) + + if len(inputs_consistency) == 0: + return {} # No consistency training. + + if len(inputs_consistency) > 1: + inputs_consistency = { + k: torch.cat([x[k] for x in inputs_consistency], dim=0) + for k in inputs_consistency[0].keys() + } + else: + inputs_consistency = inputs_consistency[0] + + return inputs_consistency + def forward(self, inputs): if not self.training: assert ( @@ -148,42 +216,57 @@ def forward(self, inputs): inputs_unsupervised = inputs.get("unsupervised", None) if inputs_supervised is None and inputs_unsupervised is None: raise ValueError("Examples not formatted in the proper way") + # Whether to use self-supervised via data undersampling (SSDU) for reconstruction. + is_ssdu_enabled = isinstance(self.model, SSDUModel) output_dict = {} - # Recon - if inputs_supervised is not None: + # Reconstruction (supervised). + if is_ssdu_enabled: + output_dict["recon"] = self.model(inputs) + elif inputs_supervised is not None: output_dict["recon"] = self.model( inputs_supervised, return_pp=True, vis_training=vis_training ) # Consistency. - # kspace_aug = kspace + U \sigma \mathcal{N} - # Loss = L(f(Ti(Te(kspace)), \theta), Te(f(kspace, \theta))) - inputs_consistency = [] - if inputs_unsupervised is not None: - inputs_consistency.append(inputs_unsupervised) - if self.use_supervised_consistency and inputs_supervised is not None: - inputs_consistency.append({k: v for k, v in inputs_supervised.items() if k != "target"}) - + model = self.model.model if is_ssdu_enabled else self.model + inputs_consistency = self._aggregate_consistency_inputs( + inputs_supervised, inputs_unsupervised + ) + inputs_consistency = move_to_device(inputs_consistency, device="cuda") if len(inputs_consistency) > 0: - if len(inputs_consistency) > 1: - inputs_consistency = { - k: torch.cat([x[k] for x in inputs_consistency], dim=0) - for k in inputs_consistency[0].keys() - } - else: - inputs_consistency = inputs_consistency[0] + # Add the edge mask to the DC layers of the unrolled network. + # TODO: Make this configurable. + if self.edge_dc and "mask" not in inputs_consistency: + inputs_consistency["mask"] = ( + (cplx.get_mask(inputs_consistency["kspace"]) + inputs_consistency["edge_mask"]) + .bool() + .to(torch.float32) + ) + with torch.no_grad(): - pred_base = self.model(inputs_consistency) + pred_base = model(inputs_consistency) # Target only used for visualization purposes not for loss. target = inputs_unsupervised.get("target", None) + pred_base_zf = pred_base["zf_image"] # noqa: F841 pred_base = pred_base["pred"] - inputs_consistency_aug, pred_base = self.augment(inputs_consistency, pred_base) - pred_aug = self.model(inputs_consistency_aug, return_pp=True) + # Augment the inputs. + inputs_consistency_aug, pred_base = self.augment(inputs_consistency, pred_base) + + pred_aug = model(inputs_consistency_aug, return_pp=True) + if "target" in pred_aug: del pred_aug["target"] pred_aug["target"] = pred_base.detach() output_dict["consistency"] = pred_aug + + # Add DC loss between prediction and consistency inputs. + # A = SenseModel(maps=inputs_consistency["maps"], weights=cplx.get_mask(inputs_consistency["kspace"])) # noqa: E501 + # output_dict["dc"] = { + # "pred": A(pred_aug["pred"]), + # "target": inputs_consistency["kspace"], + # } + if vis_training: self.visualize_aug_training( inputs_consistency["kspace"], @@ -191,8 +274,25 @@ def forward(self, inputs): pred_aug["pred"], pred_base, target=target, + dc_mask=inputs_consistency.get("mask", None), ) + # Convert to multicoil image. + # This is needed for loss to be computed over each coil rather than + # a coil combined k-space/image. + # We do this after visualization to avoid visualizing multicoil images. + if self._multicoil_image: + with torch.no_grad(): + pred_aug["target"] = _to_multicoil_image( + x=pred_aug["target"], maps=inputs_consistency["maps"] + ) + pred_aug["pred"] = _to_multicoil_image( + x=pred_aug["pred"], maps=inputs_consistency_aug["maps"] + ) + + # Log augmentor parameters. + self.log_augmentor_params() + return output_dict @classmethod @@ -214,5 +314,13 @@ def from_config(cls, cfg): "model": model, "augmentor": augmentor, "use_supervised_consistency": cfg.MODEL.A2R.USE_SUPERVISED_CONSISTENCY, + "edge_dc": cfg.MODEL.A2R.EDGE_DC, "vis_period": cfg.VIS_PERIOD, } + + +def _to_multicoil_image(x, maps): + """Convert image x into multicoil image for loss computer compatibility.""" + # Use signal model (SENSE) to get weighted kspace. + A = SenseModel(maps=maps) # no weights - we do not want to mask the data. + return oF.ifft2c(A(x, adjoint=False), channels_last=True) diff --git a/meddlr/transforms/builtin/mri.py b/meddlr/transforms/builtin/mri.py index c9b9678c..741cc8a7 100644 --- a/meddlr/transforms/builtin/mri.py +++ b/meddlr/transforms/builtin/mri.py @@ -192,7 +192,7 @@ def schedulers(self) -> List[TFScheduler]: ] return [x for y in schedulers for x in y] - def get_tfm_gen_params(self, scalars_only: bool = True): + def get_tfm_gen_params(self, scalars_only: bool = True, filter_na=True): """Get dictionary of scheduler parameters.""" schedulers: Dict[str, Sequence[TFScheduler]] = { type(tfm).__name__: tfm._get_param_values(use_schedulers=True) @@ -205,7 +205,18 @@ def get_tfm_gen_params(self, scalars_only: bool = True): # Filter out values that are not scalars p = {f"{tfm_name}/{k}": v for k, v in p.items()} if scalars_only: - p = {k: v for k, v in p.items() if isinstance(v, Number)} + _params = {} + # TODO: Make this recursive + for k, v in p.items(): + if isinstance(v, Number): + _params[k] = v + elif isinstance(v, (list, tuple)): + for i, _v in enumerate(v): + _params[f"{k}[{i}]"] = _v + p = _params + # p = {k: v for k, v in p.items() if isinstance(v, Number)} + if filter_na: + p = {k: v for k, v in p.items() if v is not None} params.update(p) return params diff --git a/meddlr/transforms/gen/noise.py b/meddlr/transforms/gen/noise.py index 9a353abd..1ce8e909 100644 --- a/meddlr/transforms/gen/noise.py +++ b/meddlr/transforms/gen/noise.py @@ -33,7 +33,7 @@ def __init__( self.use_mask = use_mask super().__init__(params=params, p=p) - def get_transform(self, input: torch.Tensor): + def get_transform(self, input: torch.Tensor) -> NoiseTransform: params = self._get_param_values(use_schedulers=True) std_devs = params["std_devs"] rho = params["rhos"] diff --git a/meddlr/utils/events.py b/meddlr/utils/events.py index 147521d9..7e5caea9 100644 --- a/meddlr/utils/events.py +++ b/meddlr/utils/events.py @@ -25,7 +25,7 @@ _PATH_MANAGER = get_path_manager() -def get_event_storage(): +def get_event_storage() -> "EventStorage": """ Returns: The :class:`EventStorage` object that's currently being used. diff --git a/tools/clean_results.py b/tools/clean_results.py index 922bbd19..a5e1f546 100644 --- a/tools/clean_results.py +++ b/tools/clean_results.py @@ -21,6 +21,7 @@ import pandas as pd from tabulate import tabulate +from tqdm import tqdm from meddlr.config import get_cfg from meddlr.evaluation.testing import find_weights @@ -56,7 +57,7 @@ def clean_results( iter_limit = [iter_limit] all_exp_paths = find_experiment_dirs(dirpath, completed=True) - for exp_path in all_exp_paths: + for exp_path in tqdm(all_exp_paths, desc="Searching for weights", disable=True): exp_path = os.path.abspath(exp_path) print(exp_path) cfg = get_cfg() @@ -91,6 +92,7 @@ def clean_results( if x.endswith(".pth") } remove_files = all_model_paths - filepaths_to_keep + remove_files = sorted(remove_files) print( "Found {}/{} files to remove from {}:\n\t{}".format( diff --git a/tools/eval_net.py b/tools/eval_net.py index 1addccdd..d3d42fe7 100644 --- a/tools/eval_net.py +++ b/tools/eval_net.py @@ -228,7 +228,7 @@ def eval(cfg, args, model, weights_basename, criterion, best_value): noise_vals = [0] if include_motion: - motion_vals = [0] + motion_sweep_vals if motion_arg == "sweep" else [0] + motion_vals = motion_sweep_vals if motion_arg == "sweep" else [0] motion_vals = sorted(set(motion_vals)) else: motion_vals = [0] diff --git a/tools/train_net.py b/tools/train_net.py index 1f709568..a6960030 100644 --- a/tools/train_net.py +++ b/tools/train_net.py @@ -36,6 +36,10 @@ def setup(args): if not cfg.OUTPUT_DIR: raise ValueError("OUTPUT_DIR not specified") + # Always execute in reproducible mode and auto-version mode when training + args.reproducible = True + args.auto_version = True + default_setup(cfg, args) # TODO: Change resume=args.resume once functionality is specified.