Utility modules for deep learning workflows.
The test_time_augmentation module provides test-time augmentation for 2D+t and 3D+t image sequences. It applies geometric augmentations during inference, runs the model on all variants, reverses the transforms on outputs, and returns aggregated predictions with uncertainty estimates.
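
The augment, predict, reverse, aggregate pipeline described above can be sketched in NumPy for the 2D case. This is a minimal illustration under stated assumptions, not the module's implementation: `make_2d_transforms` and `tta_sketch` are hypothetical names, the flip is assumed to act along W, and the standard deviation is the population form.

```python
import numpy as np


def make_2d_transforms():
    """Return 8 (forward, inverse) pairs: 4 rotations x 2 flip states."""
    pairs = []
    for k in range(4):
        for flip in (False, True):
            def forward(x, k=k, flip=flip):
                # Rotate in the (H, W) plane, then optionally flip along W.
                y = np.rot90(x, k=k, axes=(-2, -1))
                return np.flip(y, axis=-1) if flip else y

            def inverse(y, k=k, flip=flip):
                # Undo the steps in reverse order: un-flip, then rotate back.
                x = np.flip(y, axis=-1) if flip else y
                return np.rot90(x, k=-k, axes=(-2, -1))

            pairs.append((forward, inverse))
    return pairs


def tta_sketch(model_fn, images):
    """Run model_fn on every variant, map outputs back, and aggregate."""
    outputs = [inv(model_fn(fwd(images))) for fwd, inv in make_2d_transforms()]
    stacked = np.stack(outputs, axis=0)
    return stacked.mean(axis=0), stacked.std(axis=0)
```

With a perfectly equivariant model (for example the identity function) the returned std is zero everywhere; nonzero values indicate where predictions change under rotation or flipping.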
- 2D+t support: 8-fold augmentations (4 rotations × 2 flip states)
- 3D+t isotropic: 48-fold augmentations (full octahedral symmetry group)
- 3D+t anisotropic: 16-fold augmentations (XY-plane rotations only, for volumes with different Z resolution)
- Memory efficient: Mini-batch processing with online statistics computation
- Random sampling: Option to use a subset of augmentations
- Flexible output handling: Works with tensor, tuple, or dict model outputs
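
The "online statistics" idea in the list above can be illustrated with Welford's algorithm, which updates a running mean and variance one prediction at a time so that only one mini-batch needs to be held in memory. Whether the module uses exactly this algorithm is an assumption; `RunningStats` is a hypothetical name.

```python
import numpy as np


class RunningStats:
    """Running mean and population std via Welford's algorithm."""

    def __init__(self):
        self.count = 0
        self.mean = None
        self.m2 = None  # running sum of squared deviations from the mean

    def update(self, batch):
        """batch: array of shape (n, ...) holding n de-augmented predictions."""
        for pred in batch:
            if self.mean is None:
                self.mean = np.zeros_like(pred, dtype=np.float64)
                self.m2 = np.zeros_like(pred, dtype=np.float64)
            self.count += 1
            delta = pred - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (pred - self.mean)

    def finalize(self):
        return self.mean, np.sqrt(self.m2 / self.count)
```

Each mini-batch can be discarded after `update`, so peak memory depends on the batch size rather than on the total number of augmentations.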
from test_time_augmentation import TTAWrapper, tta_inference

import torch
from test_time_augmentation import TTAWrapper
# Wrap your model
model = MySegmentationModel()
tta_model = TTAWrapper(model, mode='2d')
# Run inference - returns (mean, std)
images = torch.randn(8, 1, 256, 256) # T, C, H, W
mean, std = tta_model(images)

images = torch.randn(8, 1, 256, 256)
masks = torch.randint(0, 2, (8, 1, 256, 256))
mean, std = tta_model(images, masks=masks)

tta_model = TTAWrapper(model, mode='3d', isotropic=True)
volumes = torch.randn(4, 1, 64, 128, 128) # T, C, D, H, W
mean, std = tta_model(volumes)

For volumes whose Z resolution differs from the in-plane resolution, only XY-plane rotations are valid:
tta_model = TTAWrapper(model, mode='3d', isotropic=False)
mean, std = tta_model(volumes)

Process augmentations in mini-batches to avoid OOM:
tta_model = TTAWrapper(
    model,
    mode='3d',
    batch_size=8  # Process 8 augmentations at a time
)

Use a random subset of transforms (useful for large augmentation spaces):
tta_model = TTAWrapper(
    model,
    mode='3d',
    n_augmentations=12  # Randomly sample 12 of 48 transforms
)

Note: The identity transform is always included when sampling.
tta_model = TTAWrapper(
    model,
    mode='3d',
    isotropic=False,    # 16 transforms (anisotropic)
    n_augmentations=8,  # Sample 8 randomly
    batch_size=4        # Process 4 at a time
)

For one-off inference without creating a wrapper:
from test_time_augmentation import tta_inference
mean, std = tta_inference(model, images, mode='2d')
mean, std = tta_inference(model, volumes, mode='3d', isotropic=False, batch_size=8)

mean, std, all_predictions = tta_model(images, return_all=True)
# all_predictions is a list of individual outputs from each augmentation

| Mode | isotropic | Transforms | Description |
|---|---|---|---|
| 2d | - | 8 | 4 rotations (0°, 90°, 180°, 270°) × 2 flip states |
| 3d | True | 48 | 6 axis permutations × 8 flip combinations |
| 3d | False | 16 | 4 XY rotations × 2 H-flip × 2 D-flip |
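
The counts in the table can be checked by enumerating the transforms directly. The sketch below assumes a single volume with axes (D, H, W) and uses hypothetical function names; the module's own `get_3d_transforms` may order or construct the transforms differently.

```python
import itertools

import numpy as np


def isotropic_3d_transforms():
    """48 transforms: 6 permutations of (D, H, W) x 8 flip combinations."""
    transforms = []
    for perm in itertools.permutations(range(3)):
        for flips in itertools.product((False, True), repeat=3):
            flip_axes = tuple(axis for axis, on in enumerate(flips) if on)

            def apply(vol, perm=perm, flip_axes=flip_axes):
                out = np.transpose(vol, perm)
                return np.flip(out, axis=flip_axes) if flip_axes else out

            transforms.append(apply)
    return transforms


def anisotropic_3d_transforms():
    """16 transforms: 4 rotations in the (H, W) plane x H-flip x D-flip."""
    transforms = []
    for k, flip_h, flip_d in itertools.product(range(4), (False, True), (False, True)):

        def apply(vol, k=k, flip_h=flip_h, flip_d=flip_d):
            out = np.rot90(vol, k=k, axes=(1, 2))
            if flip_h:
                out = np.flip(out, axis=1)
            if flip_d:
                out = np.flip(out, axis=0)
            return out

        transforms.append(apply)
    return transforms
```

The anisotropic set never mixes the D axis with H or W, which is why it stays valid when the Z spacing differs from the in-plane spacing.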
TTAWrapper(
    model: nn.Module,
    mode: str = '2d',                                   # '2d' or '3d'
    spatial_dims: Optional[Tuple[int, ...]] = None,     # Custom spatial dims
    output_spatial_dims: Optional[Tuple[int, ...]] = None,
    batch_size: Optional[int] = None,                   # Mini-batch size for augmentations
    n_augmentations: Optional[int] = None,              # Random sample size
    isotropic: bool = True                              # For 3D: True=48, False=16 transforms
)

- forward(images, masks=None, return_all=False) - Run TTA inference
- get_num_augmentations() - Number of augmentations used per forward pass
- get_total_augmentations() - Total available augmentations

- get_2d_transforms(spatial_dims) - Get list of 8 2D transforms
- get_3d_transforms(spatial_dims) - Get list of 48 3D isotropic transforms
- get_3d_anisotropic_transforms(spatial_dims) - Get list of 16 3D anisotropic transforms
- tta_inference(model, images, ...) - Functional interface for TTA
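
The difference between the number of augmentations used per pass and the total available comes from `n_augmentations` sampling. A rough sketch of how a subset that always contains the identity might be drawn is shown below; `sample_transform_indices` is a hypothetical helper, and placing the identity at index 0 is an assumption.

```python
import random


def sample_transform_indices(total, n_augmentations=None, seed=None):
    """Pick which transforms to run; index 0 is assumed to be the identity."""
    if n_augmentations is None or n_augmentations >= total:
        return list(range(total))  # no sampling: use every transform
    if n_augmentations < 1:
        raise ValueError("n_augmentations must be at least 1")
    rng = random.Random(seed)
    # Draw the remaining slots from the non-identity transforms only.
    others = rng.sample(range(1, total), n_augmentations - 1)
    return [0] + sorted(others)
```

For example, `sample_transform_indices(48, 12)` yields 12 distinct indices beginning with 0, while omitting `n_augmentations` returns all 48.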
For non-PyTorch workflows:
from test_time_augmentation import tta_inference_numpy, get_2d_transforms_numpy
def my_model(images, masks=None):
    # Your inference logic
    return predictions

mean, std = tta_inference_numpy(my_model, images, masks=masks)

The infer module provides a general-purpose inference pipeline for running models on datasets and saving instance segmentation results to zarr format.
- Method-agnostic: Works with any segmentation method via the InferenceProcessor protocol
- Zarr output: Saves predictions to zarr containers with shape (C, T, Y, X)
- Unique instance IDs: Instance labels are globally unique across all frames
- CSV export: Optional centroid and confidence output for tracking
- Raw output saving: Optionally save raw model outputs alongside predictions
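
One simple way to obtain instance labels that are globally unique across frames is to renumber each frame's labels with a running counter. The sketch below assumes per-frame 2D integer label images with 0 as background; `make_ids_globally_unique` is a hypothetical name, and the infer module's actual scheme is not described here.

```python
import numpy as np


def make_ids_globally_unique(frames):
    """Renumber per-frame labels so no ID is reused across frames."""
    next_id = 1
    relabeled = []
    for labels in frames:
        out = np.zeros_like(labels)
        for local_id in np.unique(labels):
            if local_id == 0:
                continue  # keep background as 0
            out[labels == local_id] = next_id
            next_id += 1
        relabeled.append(out)
    return relabeled
```

Unique IDs let a downstream tracker refer to a detection by ID alone, without also carrying the frame index.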
from pathlib import Path

import torch
from x_utils.infer import infer, InferenceProcessor
# Create a method-specific processor (e.g., from embedseg)
from embedseg.inference import EmbedSegProcessor
processor = EmbedSegProcessor(
    seed_thresh=0.9,
    fg_thresh=0.5,
    min_object_size=36,
)
# Run inference
infer(
    test_dataset=dataset,            # Returns (image, time) tuples
    model=model,
    device=torch.device("cuda"),
    zarr_container=Path("output.zarr"),
    processor=processor,
    checkpoint=Path("model.pth"),    # Optional
    csv_output=Path("tracks.csv"),   # Optional
)

Implement the InferenceProcessor protocol for custom methods:
class MyProcessor:
    def setup(self, height: int, width: int, device: torch.device) -> None:
        # One-time setup (e.g., create spatial grids)
        self.device = device

    def __call__(self, output: torch.Tensor, return_metadata: bool = False):
        # Convert model output to instance segmentation
        prediction = my_clustering(output)
        if return_metadata:
            return prediction, {1: 0.95, 2: 0.87}  # id -> confidence
        return prediction
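
The id -> confidence metadata returned above is the kind of information the optional CSV export needs. As a rough sketch of how centroid rows might be derived from a label image, the example below uses hypothetical helper names and a hypothetical column layout (`t, id, y, x, confidence`); the actual CSV schema written by `infer` may differ.

```python
import csv
import io

import numpy as np


def centroid_rows(labels, confidences, t):
    """Build one row per instance in a 2D label image (0 = background)."""
    rows = []
    for instance_id in np.unique(labels):
        if instance_id == 0:
            continue
        ys, xs = np.nonzero(labels == instance_id)
        rows.append({
            "t": t,
            "id": int(instance_id),
            "y": float(ys.mean()),  # centroid = mean pixel coordinate
            "x": float(xs.mean()),
            "confidence": confidences.get(int(instance_id)),
        })
    return rows


def write_rows(rows, handle):
    """Write the rows as CSV to any text file-like object."""
    writer = csv.DictWriter(handle, fieldnames=["t", "id", "y", "x", "confidence"])
    writer.writeheader()
    writer.writerows(rows)
```

Writing to an `io.StringIO` buffer is a convenient way to inspect the output before pointing the same function at a real file.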