
lmanan/x_utils


x_utils

Utility modules for deep learning workflows.

Test-Time Augmentation (TTA)

The test_time_augmentation module provides test-time augmentation for 2D+t and 3D+t image sequences. It applies geometric augmentations during inference, runs the model on every augmented variant, reverses each transform on the corresponding output, and aggregates the results: the mean serves as the prediction and the standard deviation as a per-pixel (per-voxel) uncertainty estimate.
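The augment, predict, invert, aggregate loop described above can be sketched for the 2D case in a few lines of NumPy. This is an illustrative sketch, not the module's implementation: `dihedral_2d`, `TRANSFORMS_2D`, and `tta_predict` are hypothetical names, it assumes the model returns a dense output with the same spatial layout as its input, and it uses NumPy's default population standard deviation (the estimator the module uses is not stated).

```python
import numpy as np

def dihedral_2d(k, flip):
    """Return (forward, inverse) for a k*90° rotation plus optional horizontal flip,
    acting on the last two axes (H, W). Hypothetical helper."""

    def forward(x):
        y = np.rot90(x, k, axes=(-2, -1))
        return np.flip(y, axis=-1) if flip else y

    def inverse(y):
        x = np.flip(y, axis=-1) if flip else y
        return np.rot90(x, -k, axes=(-2, -1))

    return forward, inverse

# 4 rotations x 2 flip states = 8 transforms (the first one is the identity).
TRANSFORMS_2D = [dihedral_2d(k, flip) for flip in (False, True) for k in range(4)]

def tta_predict(model, images):
    """Augment, predict, de-augment, then aggregate into (mean, std)."""
    restored = []
    for forward, inverse in TRANSFORMS_2D:
        prediction = model(forward(images))   # model sees the augmented input
        restored.append(inverse(prediction))  # map the output back to the original frame
    stacked = np.stack(restored)              # (n_augmentations, T, C, H, W)
    return stacked.mean(axis=0), stacked.std(axis=0)
```

Because each inverse exactly undoes its forward transform, a model that commutes with rotations and flips (the identity, for instance) yields `std == 0` everywhere; nonzero `std` marks pixels where the prediction depends on orientation.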

Features

  • 2D+t support: 8-fold augmentations (4 rotations × 2 flip states)
  • 3D+t isotropic: 48-fold augmentations (full octahedral symmetry group)
  • 3D+t anisotropic: 16-fold augmentations (4 XY-plane rotations × H flip × Z flip, never swapping Z with X or Y; for volumes whose Z resolution differs from XY)
  • Memory efficient: Mini-batch processing with online statistics computation
  • Random sampling: Option to use a subset of augmentations
  • Flexible output handling: Works with tensor, tuple, or dict model outputs
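One way to support tensor, tuple, or dict outputs is to apply each inverse transform recursively over the output structure. The sketch below assumes that approach; `map_outputs` is a hypothetical name, and NumPy arrays stand in for tensors.

```python
import numpy as np

def map_outputs(fn, output):
    """Apply `fn` (e.g. an inverse transform) to every array in a nested output.

    Handles a bare array, plain tuples/lists, and dicts (hypothetical helper;
    namedtuples would need type(output)(*values) instead).
    """
    if isinstance(output, dict):
        return {key: map_outputs(fn, value) for key, value in output.items()}
    if isinstance(output, (tuple, list)):
        return type(output)(map_outputs(fn, value) for value in output)
    return fn(output)

# Example: undo a 90° rotation on a model that returns a dict with nested outputs.
a = np.arange(6.0).reshape(2, 3)
undo = lambda y: np.rot90(y, -1)
output = {"seg": np.rot90(a), "aux": (np.rot90(a), [np.rot90(a)])}
restored = map_outputs(undo, output)
```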

Installation

from test_time_augmentation import TTAWrapper, tta_inference

Usage

Basic Usage

import torch
from test_time_augmentation import TTAWrapper

# Wrap your model
model = MySegmentationModel()  # placeholder: any torch.nn.Module
tta_model = TTAWrapper(model, mode='2d')

# Run inference - returns (mean, std)
images = torch.randn(8, 1, 256, 256)  # T, C, H, W
mean, std = tta_model(images)

With Masks

images = torch.randn(8, 1, 256, 256)
masks = torch.randint(0, 2, (8, 1, 256, 256))

mean, std = tta_model(images, masks=masks)

3D Isotropic (48 transforms)

tta_model = TTAWrapper(model, mode='3d', isotropic=True)
volumes = torch.randn(4, 1, 64, 128, 128)  # T, C, D, H, W
mean, std = tta_model(volumes)

3D Anisotropic (16 transforms)

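For volumes whose Z resolution differs from their XY resolution, transforms that swap Z with X or Y are not valid, so only XY-plane rotations combined with H and Z (D) flips are used: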

tta_model = TTAWrapper(model, mode='3d', isotropic=False)
mean, std = tta_model(volumes)

Memory-Efficient Processing

Process augmentations in mini-batches to avoid out-of-memory (OOM) errors:

tta_model = TTAWrapper(
    model,
    mode='3d',
    batch_size=8  # Process 8 augmentations at a time
)
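The online statistics mentioned in the features list can be maintained batch by batch, so only a running mean and a running sum of squared deviations are kept rather than every augmented prediction. The sketch below uses Chan et al.'s pairwise update (the batched form of Welford's algorithm); it is an assumption about how such streaming aggregation could work, `RunningStats` is a hypothetical name, and it reports the population (ddof=0) standard deviation.

```python
import numpy as np

class RunningStats:
    """Running mean and population std, updated one mini-batch at a time (hypothetical helper)."""

    def __init__(self):
        self.count = 0
        self.mean = None
        self.m2 = None  # running sum of squared deviations from the mean

    def update(self, batch: np.ndarray) -> None:
        """batch: (n, ...) array of already de-augmented predictions."""
        n_b = batch.shape[0]
        batch_mean = batch.mean(axis=0)
        batch_m2 = ((batch - batch_mean) ** 2).sum(axis=0)
        if self.count == 0:
            self.count, self.mean, self.m2 = n_b, batch_mean, batch_m2
            return
        total = self.count + n_b
        delta = batch_mean - self.mean
        # Combine the two partial results without revisiting earlier batches.
        self.mean = self.mean + delta * (n_b / total)
        self.m2 = self.m2 + batch_m2 + delta**2 * (self.count * n_b / total)
        self.count = total

    def std(self) -> np.ndarray:
        return np.sqrt(self.m2 / self.count)
```

Only one mini-batch of predictions has to be in memory at a time; the result matches stacking all predictions and calling `.mean(axis=0)` / `.std(axis=0)`.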

Random Augmentation Sampling

Use a random subset of transforms (useful for large augmentation spaces):

tta_model = TTAWrapper(
    model,
    mode='3d',
    n_augmentations=12  # Randomly sample 12 of 48 transforms
)

Note: The identity transform is always included when sampling.
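Sampling a subset while always keeping the identity can be done by drawing the remaining `n_augmentations - 1` transforms without replacement from the non-identity ones. This sketch assumes, hypothetically, that the identity sits at index 0 of the transform list; `sample_augmentation_indices` is not part of the documented API.

```python
import random

def sample_augmentation_indices(n_total, n_augmentations, rng=None):
    """Pick `n_augmentations` distinct transform indices, always including index 0 (identity)."""
    if not 1 <= n_augmentations <= n_total:
        raise ValueError("n_augmentations must be between 1 and n_total")
    rng = rng or random.Random()
    # Sample the rest without replacement from the non-identity transforms.
    others = rng.sample(range(1, n_total), n_augmentations - 1)
    return [0] + sorted(others)

print(sample_augmentation_indices(48, 12, random.Random(0)))
```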

Combining Options

tta_model = TTAWrapper(
    model,
    mode='3d',
    isotropic=False,      # 16 transforms (anisotropic)
    n_augmentations=8,    # Sample 8 randomly
    batch_size=4          # Process 4 at a time
)

Functional Interface

For one-off inference without creating a wrapper:

from test_time_augmentation import tta_inference

mean, std = tta_inference(model, images, mode='2d')
mean, std = tta_inference(model, volumes, mode='3d', isotropic=False, batch_size=8)

Get All Predictions

mean, std, all_predictions = tta_model(images, return_all=True)
# all_predictions is a list of individual outputs from each augmentation

Transform Summary

| Mode | isotropic | Transforms | Description |
| --- | --- | --- | --- |
| 2d | n/a | 8 | 4 rotations (0°, 90°, 180°, 270°) × 2 flip states |
| 3d | True | 48 | 6 axis permutations × 8 flip combinations |
| 3d | False | 16 | 4 XY rotations × 2 H-flip states × 2 D-flip states |
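The transform counts in the table can be checked by enumeration. In this sketch (hypothetical function names; NumPy arrays whose last three axes are D, H, W), the 48 isotropic transforms are every permutation of the three spatial axes combined with every subset of axis flips, i.e. the full octahedral group, while the 16 anisotropic ones keep D in place: four 90° rotations in the H-W plane, an optional H flip, and an optional D flip.

```python
import itertools

import numpy as np

def isotropic_transforms():
    """All 48 (axis permutation, flip flags) pairs: 6 permutations x 8 flip combinations."""
    return [
        (perm, flips)
        for perm in itertools.permutations(range(3))
        for flips in itertools.product((False, True), repeat=3)
    ]

def apply_isotropic(volume, perm, flips):
    """Permute the last three axes (D, H, W) by `perm`, then flip the flagged ones."""
    lead = list(range(volume.ndim - 3))
    out = np.transpose(volume, lead + [volume.ndim - 3 + p for p in perm])
    flip_axes = tuple(volume.ndim - 3 + i for i, f in enumerate(flips) if f)
    return np.flip(out, axis=flip_axes) if flip_axes else out

def anisotropic_transforms():
    """All 16 (k, flip_h, flip_d) triples: 4 in-plane rotations x 2 x 2 flips."""
    return list(itertools.product(range(4), (False, True), (False, True)))

def apply_anisotropic(volume, k, flip_h, flip_d):
    """Rotate by k*90° in the H-W plane, optionally flip H and/or D; D is never swapped."""
    out = np.rot90(volume, k, axes=(-2, -1))
    if flip_h:
        out = np.flip(out, axis=-2)
    if flip_d:
        out = np.flip(out, axis=-3)
    return out

rng = np.random.default_rng(0)
cube = rng.random((1, 4, 4, 4))  # C, D, H, W
slab = rng.random((1, 2, 4, 4))  # fewer Z slices than XY pixels
iso = {apply_isotropic(cube, p, f).tobytes() for p, f in isotropic_transforms()}
aniso = {apply_anisotropic(slab, *t).tobytes() for t in anisotropic_transforms()}
print(len(iso), len(aniso))  # 48 16
```

Every enumerated transform produces a distinct output on a random volume, confirming that none of the 48 (or 16) are duplicates.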

API Reference

TTAWrapper

TTAWrapper(
    model: nn.Module,
    mode: str = '2d',                              # '2d' or '3d'
    spatial_dims: Optional[Tuple[int, ...]] = None, # Custom spatial dims
    output_spatial_dims: Optional[Tuple[int, ...]] = None,
    batch_size: Optional[int] = None,              # Mini-batch size for augmentations
    n_augmentations: Optional[int] = None,         # Random sample size
    isotropic: bool = True                         # For 3D: True=48, False=16 transforms
)

Methods

  • forward(images, masks=None, return_all=False) - Run TTA inference
  • get_num_augmentations() - Number of augmentations used per forward pass
  • get_total_augmentations() - Total available augmentations

Standalone Functions

  • get_2d_transforms(spatial_dims) - Get list of 8 2D transforms
  • get_3d_transforms(spatial_dims) - Get list of 48 3D isotropic transforms
  • get_3d_anisotropic_transforms(spatial_dims) - Get list of 16 3D anisotropic transforms
  • tta_inference(model, images, ...) - Functional interface for TTA

NumPy Support

For non-PyTorch workflows:

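from test_time_augmentation import tta_inference_numpy, get_2d_transforms_numpy

def my_model(images, masks=None):
    # Your inference logic: map a NumPy array of (augmented) images,
    # plus optional masks, to a NumPy array of predictions
    predictions = ...  # replace with your model's output
    return predictions

# images (and optional masks) are NumPy arrays
mean, std = tta_inference_numpy(my_model, images, masks=masks)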

Inference

The infer module provides a general-purpose inference pipeline for running models on datasets and saving instance segmentation results to zarr format.

Features

  • Method-agnostic: Works with any segmentation method via the InferenceProcessor protocol
  • Zarr output: Saves predictions to zarr containers with shape (C, T, Y, X)
  • Unique instance IDs: Instance labels are globally unique across all frames
  • CSV export: Optional centroid and confidence output for tracking
  • Raw output saving: Optionally save raw model outputs alongside predictions
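One simple way to make instance labels unique across frames is to shift each frame's labels by the largest ID used so far. This is an assumption about the approach, not taken from the source; `make_ids_globally_unique` is a hypothetical name, and the zarr write itself is omitted. The result uses the (C, T, Y, X) layout listed above with C = 1.

```python
import numpy as np

def make_ids_globally_unique(frames):
    """Offset per-frame instance labels so no ID repeats across frames.

    frames: list of (Y, X) integer label images, 0 = background.
    Returns a (1, T, Y, X) array.
    """
    offset = 0
    relabeled = []
    for labels in frames:
        labels = labels.astype(np.int64)
        # Shift foreground labels only; background stays 0.
        out = np.where(labels > 0, labels + offset, 0)
        relabeled.append(out)
        offset = max(offset, int(out.max()))  # empty frames leave the offset unchanged
    return np.stack(relabeled)[np.newaxis]
```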

Usage

import torch
from pathlib import Path

from x_utils.infer import infer, InferenceProcessor

# Create a method-specific processor (e.g., from embedseg)
from embedseg.inference import EmbedSegProcessor

processor = EmbedSegProcessor(
    seed_thresh=0.9,
    fg_thresh=0.5,
    min_object_size=36,
)

# Run inference
infer(
    test_dataset=dataset,        # Each item is an (image, time) tuple
    model=model,
    device=torch.device("cuda"),
    zarr_container=Path("output.zarr"),
    processor=processor,
    checkpoint=Path("model.pth"),  # Optional
    csv_output=Path("tracks.csv"), # Optional
)
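A centroid-plus-confidence CSV for tracking could be produced per frame from the label stack and the `{id: confidence}` metadata a processor returns. The column names (`id`, `t`, `y`, `x`, `confidence`) and `write_centroids_csv` are hypothetical; the file actually written via `csv_output` may differ.

```python
import csv
import io

import numpy as np

def write_centroids_csv(labels_tyx, confidences, fh):
    """Write one row per instance per frame: id, frame index, centroid (y, x), confidence.

    labels_tyx: (T, Y, X) integer labels with 0 = background.
    confidences: {instance id: score}, e.g. the metadata a processor returns.
    """
    writer = csv.writer(fh)
    writer.writerow(["id", "t", "y", "x", "confidence"])  # hypothetical column names
    for t, frame in enumerate(labels_tyx):
        for inst_id in np.unique(frame):
            if inst_id == 0:
                continue  # skip background
            ys, xs = np.nonzero(frame == inst_id)
            conf = confidences.get(int(inst_id), "")
            writer.writerow([int(inst_id), t, float(ys.mean()), float(xs.mean()), conf])

labels = np.zeros((2, 4, 4), dtype=np.int64)
labels[0, 0:2, 0:2] = 1
labels[1, 2:4, 2:4] = 2
buffer = io.StringIO()
write_centroids_csv(labels, {1: 0.95, 2: 0.87}, buffer)
print(buffer.getvalue())
```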

Custom Processors

Implement the InferenceProcessor protocol for custom methods:

import torch

class MyProcessor:
    def setup(self, height: int, width: int, device: torch.device) -> None:
        # One-time setup (e.g., create spatial grids)
        self.device = device

    def __call__(self, output: torch.Tensor, return_metadata: bool = False):
        # Convert model output to instance segmentation;
        # my_clustering is a placeholder for your own method
        prediction = my_clustering(output)
        if return_metadata:
            return prediction, {1: 0.95, 2: 0.87}  # instance id -> confidence
        return prediction
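If `InferenceProcessor` is a structural `typing.Protocol` (an assumption; the source only calls it a protocol), any class with matching `setup` and `__call__` methods satisfies it without subclassing. The sketch below restates the two methods as a hypothetical `ProcessorLike` protocol and implements a toy thresholding processor, with NumPy arrays in place of `torch.Tensor`. It labels all foreground as one instance, so it only illustrates the calling contract, not a real clustering method.

```python
from typing import Any, Protocol, runtime_checkable

import numpy as np

@runtime_checkable
class ProcessorLike(Protocol):
    """Hypothetical restatement of the two processor methods shown above."""

    def setup(self, height: int, width: int, device: Any) -> None: ...

    def __call__(self, output: Any, return_metadata: bool = False) -> Any: ...


class ThresholdProcessor:
    """Toy processor: every pixel above `threshold` becomes instance 1."""

    def __init__(self, threshold: float = 0.5) -> None:
        self.threshold = threshold
        self.shape = None
        self.device = None

    def setup(self, height: int, width: int, device: Any) -> None:
        # Nothing expensive to precompute; just remember the frame geometry.
        self.shape = (height, width)
        self.device = device

    def __call__(self, output: Any, return_metadata: bool = False) -> Any:
        probs = np.asarray(output, dtype=float)
        foreground = probs > self.threshold
        prediction = foreground.astype(np.int32)  # 0 = background, 1 = the single instance
        if not return_metadata:
            return prediction
        # Confidence: mean foreground probability (empty metadata if nothing is detected).
        metadata = {1: float(probs[foreground].mean())} if foreground.any() else {}
        return prediction, metadata
```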

About

Repository for handling training and inference.
