Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions configs/tests/vortex/standard.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
DESCRIPTION:
BRIEF: Testing VORTEX config
EXP_NAME: test/vortex/standard
PROJECT_NAME: vortex
TAGS: ("mridata_knee_2019", "unet", "16x", "vortex")
OUTPUT_DIR: "results://tests/vortex/standard"
DATASETS:
TRAIN: ("mridata_knee_2019_train",)
VAL: ("mridata_knee_2019_val",)
TEST: ("mridata_knee_2019_test",)
AUG_TRAIN:
UNDERSAMPLE:
NAME: "PoissonDiskMaskFunc"
ACCELERATIONS: (16,)
CALIBRATION_SIZE: 20
AUG_TEST:
UNDERSAMPLE:
ACCELERATIONS: (16,)
DATALOADER:
NUM_WORKERS: 0
SUBSAMPLE_TRAIN:
NUM_TOTAL: 6 # Number of total scans to use
NUM_UNDERSAMPLED: 5 # Number of undersampled scans to use
SAMPLER_TRAIN: "AlternatingSampler" # Choices: ["", "AlternatingSampler"]
ALT_SAMPLER:
PERIOD_SUPERVISED: 1
PERIOD_UNSUPERVISED: 1
MODEL:
META_ARCHITECTURE: "VortexModel"
UNET:
CHANNELS: 32
DROPOUT: 0.0
IN_CHANNELS: 2
NUM_POOL_LAYERS: 4
OUT_CHANNELS: 2
RECON_LOSS:
NAME: "l1"
RENORMALIZE_DATA: False
A2R:
META_ARCHITECTURE: "UnetModel"
CONSISTENCY:
AUG:
MRI_RECON:
AUG_SENSITIVITY_MAPS: true
SCHEDULER_P:
IGNORE: false
TRANSFORMS: () # TODO: This must be filled in
SOLVER:
TRAIN_BATCH_SIZE: 2
TEST_BATCH_SIZE: 2
CHECKPOINT_PERIOD: 20
MAX_ITER: 60
TEST:
EVAL_PERIOD: 40
VAL_METRICS:
RECON: ("psnr", "psnr_scan", "psnr_mag", "psnr_mag_scan", "nrmse", "nrmse_scan", "nrmse_mag", "nrmse_mag_scan", "ssim (Wang)", "ssim (Wang)_scan")
TIME_SCALE: "iter"
SEED: 1000
VIS_PERIOD: 10
VERSION: 1
26 changes: 26 additions & 0 deletions configs/tests/vortex/vortex_multishot_motion.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Test
_BASE_: "standard.yaml"
DESCRIPTION:
BRIEF: Testing VORTEX multi-shot motion config
EXP_NAME: test/vortex/multi-shot-motion
PROJECT_NAME: vortex
TAGS: ("mridata_knee_2019", "unet", "16x", "vortex")
OUTPUT_DIR: "results://tests/vortex/multi-shot-motion"
MODEL:
CONSISTENCY:
AUG:
MRI_RECON:
AUG_SENSITIVITY_MAPS: false
SCHEDULER_P:
IGNORE: false
TRANSFORMS:
- name: RandomMRIMultiShotMotion
p: 1.0
nshots: 8
trajectory: blocked
tfms_or_gens:
- name: RandomAffine
p:
angle: 1.0
pad_like: MRAugment
angle: 5
48 changes: 48 additions & 0 deletions configs/tests/vortex/vortex_ssdu.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Self-Supervised VORTEX (i.e. VORTEX+SSDU)
_BASE_: "standard.yaml"
MODEL:
META_ARCHITECTURE: "VortexModel"
A2R:
META_ARCHITECTURE: "SSDUModel"
SSDU:
META_ARCHITECTURE: "UnetModel"
MASKER:
PARAMS:
p: 1.0
rhos: 0.4 # Based on hyperparameter search
kind: "uniform" # Choices: ["uniform", "gaussian"]
per_example: True
std_scale: 4 # Default value in SSDU code
UNET:
IN_CHANNELS: 2
OUT_CHANNELS: 2
CHANNELS: 32
NUM_POOL_LAYERS: 4
DROPOUT: 0.
CONSISTENCY:
AUG:
MRI_RECON:
AUG_SENSITIVITY_MAPS: true
SCHEDULER_P:
IGNORE: false
TRANSFORMS:
# Motion Transform
- name: RandomMRIMotion
p: 1.0
std_devs:
- 0.2
- 0.5
RECON_LOSS:
NAME: "k_l1"
RENORMALIZE_DATA: False
DATALOADER:
NUM_WORKERS: 0 # for debugging purposes
SAMPLER_TRAIN: "" # Random sampler in self-supervised settings
SUBSAMPLE_TRAIN:
NUM_TOTAL: 14
NUM_UNDERSAMPLED: 14 # undersample all scans - N2R + SSDU should still train.
SOLVER:
TRAIN_BATCH_SIZE: 1
TIME_SCALE: "iter"
OUTPUT_DIR: "results://tests/vortex/vortex_ssdu"
VERSION: 1
2 changes: 1 addition & 1 deletion meddlr/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _format_str(val_str: str, *, cfg: CfgNode, unroll: bool):
assert len(start) == len(end), f"Could not determine formatting string: {val_str}"

if len(start) == 0:
return val_str
return eval(val_str)

cfg_keys_to_search = [val_str[s + 1 : e] for s, e in zip(start, end)]
values = [cfg.get_recursive(v) for v in cfg_keys_to_search]
Expand Down
1 change: 1 addition & 0 deletions meddlr/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@
_C.MODEL.A2R = CN()
_C.MODEL.A2R.META_ARCHITECTURE = "GeneralizedUnrolledCNN"
_C.MODEL.A2R.USE_SUPERVISED_CONSISTENCY = False
_C.MODEL.A2R.EDGE_DC = False

# -----------------------------------------------------------------------------
# SSDU model
Expand Down
7 changes: 6 additions & 1 deletion meddlr/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,17 @@ def build_recon_val_loader(
"drop_last": False,
}

num_workers = cfg.DATALOADER.NUM_WORKERS
prefetch_factor = cfg.DATALOADER.PREFETCH_FACTOR
if env.pt_version() >= "2.0" and num_workers == 0:
prefetch_factor = None

val_loader = DataLoader(
dataset=val_data,
num_workers=cfg.DATALOADER.NUM_WORKERS,
pin_memory=True,
collate_fn=default_collate,
prefetch_factor=cfg.DATALOADER.PREFETCH_FACTOR,
prefetch_factor=prefetch_factor,
**dl_kwargs,
)
return val_loader
Expand Down
13 changes: 10 additions & 3 deletions meddlr/data/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,24 +286,31 @@ def collect_mask(
If sub-indices are collated, they will be summed.
out_channel_first (bool, optional): Reorders dimensions of output mask to Cx(...)
"""
# TODO: Make this more efficient.
if isinstance(index, int):
index = (index,)

is_ndarray = isinstance(mask, np.ndarray)
if is_ndarray:
mask = torch.as_tensor(mask)

if not any(isinstance(idx, Sequence) for idx in index):
mask = mask[..., index]
else:
o_seg = []
for idx in index:
c_seg = mask[..., idx]
if isinstance(idx, Sequence):
c_seg = np.sum(c_seg, axis=-1)
c_seg = c_seg.sum(dim=-1)
o_seg.append(c_seg)
mask = np.stack(o_seg, axis=-1)
mask = torch.stack(o_seg, axis=-1)

if out_channel_first:
last_idx = len(mask.shape) - 1
mask = np.transpose(mask, (last_idx,) + tuple(range(0, last_idx)))
mask = torch.permute(mask, (last_idx,) + tuple(range(0, last_idx)))

if is_ndarray:
mask = mask.numpy()
return mask


Expand Down
3 changes: 2 additions & 1 deletion meddlr/data/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __call__(
masked_kspace = masked_kspace.squeeze(0)
maps = maps.squeeze(0)
target = target.squeeze(0)
mask = mask.squeeze(0)

out = {
"kspace": masked_kspace,
Expand All @@ -399,7 +400,7 @@ def __call__(
"std": std,
"norm": norm,
"edge_mask": edge_mask.squeeze(0),
"mask": self._get_mask(masked_kspace),
"mask": mask,
}
if postprocessing_mask is not None:
out["postprocessing_mask"] = postprocessing_mask.squeeze(0)
Expand Down
1 change: 1 addition & 0 deletions meddlr/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def build_loss_computer(cls, cfg):
or cfg.MODEL.META_ARCHITECTURE == "M2RModel"
or cfg.MODEL.META_ARCHITECTURE == "NM2RModel"
or cfg.MODEL.META_ARCHITECTURE == "A2RModel"
or cfg.MODEL.META_ARCHITECTURE == "VortexModel"
else "BasicLossComputer"
)
return build_loss_computer(cfg, loss_computer)
Expand Down
3 changes: 3 additions & 0 deletions meddlr/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def values(self) -> Sequence[DatasetEvaluator]:
else:
return self._evaluators

def pop(self, key):
return self._evaluators.pop(key)

def reset(self):
for evaluator in self.values():
evaluator.reset()
Expand Down
19 changes: 13 additions & 6 deletions meddlr/evaluation/recon_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,12 +342,14 @@ def evaluate(self):
os.path.join(self._save_scan_dir, f"{scan_id}.h5"),
)

if self._group_by_scan:
pred_vals = self._group_results_by_scan()
if self._compute_metrics():
if self._group_by_scan:
pred_vals = self._group_results_by_scan()
else:
pred_vals = self.slice_metrics.to_dict()
pred_vals.update(self.scan_metrics.to_dict())
else:
pred_vals = self.slice_metrics.to_dict()
pred_vals.update(self.scan_metrics.to_dict())

pred_vals = {}
self._results = pred_vals

if not self._is_flushing:
Expand All @@ -374,8 +376,11 @@ def _group_results_by_scan(self):

return pred_vals

def _compute_metrics(self) -> bool:
return len(self._metric_names) > 0

def log_summary(self):
if not comm.is_main_process():
if not comm.is_main_process() or not self._compute_metrics():
return

output_dir = self._output_dir
Expand Down Expand Up @@ -431,6 +436,8 @@ def evaluate_prediction(
ex_id: Union[str, Sequence[str]],
is_batch=False,
):
if len(metrics) == 0:
return {}
output, target = prediction["pred"], prediction["target"]
if not is_batch:
output, target = output.unsqueeze(0), target.unsqueeze(0)
Expand Down
5 changes: 4 additions & 1 deletion meddlr/evaluation/scan_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ def flush(self, enter_prediction_scope: bool = True, skip_last_scan: bool = True

Args:
enter_prediction_scope (bool, optional): If ``True``, enter the
prediction scope.
prediction scope. You may want to set this to ``False`` if an
external object is managing the prediction scope.
This allows the external object to access predictions before
they are released from memory.
skip_last_scan (bool, optional): If ``True``, does not flush
most recent scan. This avoids prematurely computing metrics
before all slices of the scan are available.
Expand Down
13 changes: 11 additions & 2 deletions meddlr/evaluation/seg_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import time
import uuid
from typing import Any, Dict

import torch
from tqdm import tqdm
Expand Down Expand Up @@ -136,8 +137,16 @@ def reset(self):
self.slice_metrics.eval()
self.scan_metrics.eval()

def structure_scans(self, verbose=True):
"""Structure scans into volumes to be used to evaluation."""
def structure_scans(self, verbose=True) -> Dict[str, Dict[str, Any]]:
"""Structure scans into volumes to be used to evaluation.

Returns:
dict: a dictionary mapping scan_id to a dictionary with keys:
- "pred": the predicted segmentation mask - shape: (C, D, H, W)
- "target": the ground truth segmentation mask - shape: (C, D, H, W)
- "voxel_spacing": the voxel spacing of the scan (D, H, W, C)
- "affine": Affine matrix for MedicalVolume (D, H, W, C)
"""
out = structure_scans(self._predictions, verbose=verbose, dims={1: "slice_id"})
return out

Expand Down
6 changes: 5 additions & 1 deletion meddlr/evaluation/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,13 @@ def check_consistency(state_dict: Dict[str, Any], model: torch.nn.Module):
model (nn.Module): A Pytorch model.
"""
_state_dict = model.state_dict()
mismatched_keys = []
for k in state_dict:
assert k in _state_dict, f"{k} not in model state_dict: {_state_dict.keys()}"
assert torch.equal(state_dict[k], _state_dict[k]), f"Mismatch values: {k}"
if not torch.equal(state_dict[k], _state_dict[k]):
mismatched_keys.append(k)
if len(mismatched_keys):
raise ValueError("Mismatch keys:\n\t{}".format("\t\n".join(mismatched_keys)))


def _metrics_from_x(metrics_file, criterion):
Expand Down
16 changes: 16 additions & 0 deletions meddlr/modeling/loss_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ def __init__(self, cfg):
self.use_consistency = cfg.MODEL.CONSISTENCY.USE_CONSISTENCY
self.num_latent_layers = cfg.MODEL.CONSISTENCY.NUM_LATENT_LAYERS
self.latent_keys = ["E4", "E3", "D3", "E2", "D2", "E1", "D1"]

# Apply data consistency loss between input and augmented recon.
self.use_dc = False
# self.use_robust = cfg.MODEL.LOSS.USE_ROBUST
# self.beta = cfg.MODEL.LOSS.BETA
# self.robust_step_size = cfg.MODEL.LOSS.ROBUST_STEP_SIZE
Expand Down Expand Up @@ -285,6 +288,17 @@ def __call__(self, input, output):
}
if output_consistency is not None and self.use_consistency:
loss += self.consistency_weight * metrics_consistency["cons_loss"]
if output_consistency is not None and self.use_dc:
pred = output["dc"]["pred"].contiguous()
target = output["dc"]["target"].contiguous()
abs_error = cplx.abs(pred - target)
kl1_norm = torch.sum(abs_error) / torch.sum(cplx.abs(target))
kl2_norm = torch.sqrt(torch.sum(abs_error**2)) / torch.sqrt(
torch.sum(cplx.abs(target) ** 2)
) # noqa: E501
dc_loss = 0.5 * kl1_norm + 0.5 * kl2_norm
# dc_loss = torch.mean(torch.abs(pred - target))
loss += (self.consistency_weight / 10) * dc_loss

if self.use_latent:
num_losses = self.num_latent_layers * 2 - 1
Expand Down Expand Up @@ -314,6 +328,8 @@ def __call__(self, input, output):
if self.use_latent:
for i in range(num_losses):
metrics.update(all_metrics_latent[i])
if self.use_dc:
metrics["dc_loss"] = dc_loss

metrics["loss"] = loss
return metrics
Expand Down
Loading