
mansour2002/PSMA_FM

PSMA-FM: A Three-Stage Multimodal Foundation Model for Whole-Body PSMA PET/CT

Official implementation of PSMA-FM, a foundation model for whole-body PSMA PET/CT imaging trained via three-stage report-supervised learning.


Overview

PSMA-FM is trained in three sequential stages:

  1. Stage 1 — MAE Pre-training: A 3D Vision Transformer (ViT3D-B/8) is pre-trained via Masked Autoencoding on dual-channel PET/CT volumes.
  2. Stage 2 — Text Encoder Adaptation: BiomedVLP-CXR-BERT-general is adapted to PSMA radiology reports via LoRA-augmented Masked Language Modeling.
  3. Stage 3 — Vision-Language Alignment: The vision and text encoders are jointly aligned using multi-crop contrastive learning and auxiliary MAE regularization.

Repository Structure

PSMA-FM/
├── code/
│   ├── configs/vlp/
│   │   ├── stage1_mae.yaml
│   │   ├── stage2_text_lora.yaml
│   │   └── stage3_lit_multicrop_mae.yaml
│   ├── data/
│   │   ├── mae_dataset.py
│   │   ├── text_mlm_dataset.py
│   │   └── vlp_dataset.py
│   ├── experiments/vlp/
│   │   ├── stage1_mae_pretraining.py
│   │   ├── stage2_text_lora_mlm.py
│   │   └── stage3_lit_multicrop_mae.py
│   ├── models/
│   │   └── vit3d_mae.py
│   └── optimizers/
│       └── lr_scheduler.py
├── notebooks/
│   └── test_stage3_lit_multicrop_mae.ipynb
├── train.sh
└── requirements.txt

Installation

conda create -n psma-fm python=3.10
conda activate psma-fm
pip install -r requirements.txt

Dataset Format

{
  "training": [
    {
      "case_id": "patient_001",
      "ct_image": "/path/to/ct.nii",
      "pt_image": "/path/to/pet.nii",
      "report": {
        "findings": "...",
        "impression": "...",
        "full_text": "...",
        "mi-T-stage": "T2",
        "mi-N-stage": "N1",
        "mi-M-stage": "M1",
        "miTNM_stage": "T2N1M1"
      }
    }
  ],
  "validation": [...]
}
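A minimal sketch of reading and sanity-checking this file with the standard library. It assumes that the keys shown above are required (the `mi-*` staging fields are treated as optional), and the helper name `load_split` is hypothetical, not part of the repository's `data/` modules. The `"validation": [...]` entry above is elided in the example, so a real file must contain actual case objects there.

```python
import json

# Keys each case is assumed to carry, taken from the example above.
REQUIRED_CASE_KEYS = ("case_id", "ct_image", "pt_image", "report")
REQUIRED_REPORT_KEYS = ("findings", "impression", "full_text")


def load_split(json_path, split="training"):
    """Hypothetical helper: return the cases of one split, raising KeyError on missing keys."""
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    cases = data[split]  # raises KeyError if the split itself is absent
    for case in cases:
        missing = [k for k in REQUIRED_CASE_KEYS if k not in case]
        missing += [f"report.{k}" for k in REQUIRED_REPORT_KEYS
                    if k not in case.get("report", {})]
        if missing:
            raise KeyError(f"{case.get('case_id', '<unknown>')}: missing {missing}")
    return cases
```

For example, `load_split("/path/to/dataset.json")[0]["report"]["miTNM_stage"]` would return `"T2N1M1"` for the case shown above.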

Before training, set json_path in each YAML config to the location of this dataset JSON file.


Training

Stage 1: MAE Pre-training

sh train.sh code/experiments/vlp/stage1_mae_pretraining.py \
            code/configs/vlp/stage1_mae.yaml
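For intuition, the patchify-and-mask step of MAE pre-training on a dual-channel volume can be sketched in NumPy as follows. This is an illustration, not the code in `models/vit3d_mae.py`: the 8-voxel patch size and 96×96×96 crop come from the Architecture section, while the 75% mask ratio (the default in the original MAE paper) and the PET/CT channel order are assumptions; check `stage1_mae.yaml` for the actual settings.

```python
import numpy as np


def patchify_3d(volume, patch=8):
    """Split a (C, D, H, W) volume into non-overlapping cubes -> (N, C * patch**3)."""
    c, d, h, w = volume.shape
    assert d % patch == 0 and h % patch == 0 and w % patch == 0
    x = volume.reshape(c, d // patch, patch, h // patch, patch, w // patch, patch)
    # Move the three patch-grid axes to the front; keep channel + intra-patch voxels together.
    x = x.transpose(1, 3, 5, 0, 2, 4, 6)
    return x.reshape(-1, c * patch ** 3)


def random_masking(tokens, mask_ratio=0.75, rng=None):
    """MAE-style random masking: only a random subset of tokens is passed to the encoder.

    Returns the visible tokens, their indices, and a boolean mask (True = masked).
    """
    rng = np.random.default_rng() if rng is None else rng
    n = tokens.shape[0]
    n_keep = int(round(n * (1 - mask_ratio)))
    keep_idx = np.sort(rng.permutation(n)[:n_keep])
    mask = np.ones(n, dtype=bool)
    mask[keep_idx] = False
    return tokens[keep_idx], keep_idx, mask


# Dual-channel crop; channel 0 = PET, channel 1 = CT is an assumed ordering.
rng = np.random.default_rng(0)
crop = rng.standard_normal((2, 96, 96, 96)).astype(np.float32)
tokens = patchify_3d(crop)                                  # (1728, 1024)
visible, keep_idx, mask = random_masking(tokens, 0.75, rng)  # 432 visible, 1296 masked
```

In standard MAE, a lightweight decoder then predicts the voxels of the masked patches, and the reconstruction (MSE) loss is computed only where `mask` is True.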

Stage 2: Text Encoder Adaptation

sh train.sh code/experiments/vlp/stage2_text_lora_mlm.py \
            code/configs/vlp/stage2_text_lora.yaml
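LoRA keeps the pre-trained BERT weights frozen and learns a low-rank update for selected weight matrices. Below is a minimal NumPy sketch of a LoRA-augmented linear layer in the standard formulation (B initialised to zero so training starts from the unmodified model). The rank `r`, scaling `alpha`, and which BERT layers receive adapters are assumptions here (see `stage2_text_lora.yaml`), and the MLM objective itself is not shown.

```python
import numpy as np


class LoRALinear:
    """y = x W^T + b + (alpha / r) * x A^T B^T, with W and b frozen; only A and B are trained."""

    def __init__(self, weight, bias, r=8, alpha=16, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        out_features, in_features = weight.shape
        self.weight = weight          # frozen pre-trained weight, shape (out, in)
        self.bias = bias              # frozen bias, shape (out,)
        self.scaling = alpha / r
        # A gets a small random init, B starts at zero, so the adapter is initially a no-op.
        self.A = rng.normal(0.0, 0.01, size=(r, in_features))
        self.B = np.zeros((out_features, r))

    def __call__(self, x):
        base = x @ self.weight.T + self.bias
        update = (x @ self.A.T) @ self.B.T
        return base + self.scaling * update

    def merged_weight(self):
        """Fold the adapter into one dense weight for inference."""
        return self.weight + self.scaling * (self.B @ self.A)


rng = np.random.default_rng(0)
W = rng.standard_normal((768, 768))       # stand-in for a frozen 768x768 BERT projection
layer = LoRALinear(W, np.zeros(768), r=8, alpha=16, rng=rng)
x = rng.standard_normal((4, 768))
y = layer(x)                              # equals x @ W.T at initialisation because B = 0
```

With r = 8 on a 768×768 projection, the adapter adds 2 × 8 × 768 = 12,288 trainable parameters against 589,824 frozen ones, which is why LoRA makes adapting a ~109M-parameter text encoder cheap.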

Stage 3: Vision-Language Alignment

Set vision_encoder_ckpt and text_encoder_ckpt in the Stage 3 config to the checkpoints produced by Stage 1 and Stage 2, respectively, then run:

sh train.sh code/experiments/vlp/stage3_lit_multicrop_mae.py \
            code/configs/vlp/stage3_lit_multicrop_mae.yaml
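The alignment objective can be illustrated with a symmetric, CLIP-style InfoNCE loss in which the four per-crop image embeddings are pooled with learned softmax weights before being matched to the report embedding. This NumPy sketch is not the repository's implementation: the softmax-weighted pooling is an assumption suggested by the "learned crop aggregation weights" evaluated in the notebook, the temperature of 0.07 is assumed, and the auxiliary MAE term is omitted. Embeddings stay 768-dimensional with no projection head, as in the Architecture section.

```python
import numpy as np


def l2_normalize(x, axis=-1, eps=1e-8):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)


def logsumexp(x, axis):
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def pool_crops(crop_emb, crop_logits):
    """Weighted sum over crops.

    crop_emb:    (B, K, D) per-crop embeddings (K = 4 axial crops).
    crop_logits: (K,) learnable scores; a softmax turns them into aggregation weights.
    """
    w = np.exp(crop_logits - crop_logits.max())
    w = w / w.sum()
    return np.einsum("bkd,k->bd", crop_emb, w), w


def symmetric_info_nce(img_emb, txt_emb, temperature=0.07):
    """CLIP-style loss: matched (image, report) pairs lie on the diagonal of the logit matrix."""
    img, txt = l2_normalize(img_emb), l2_normalize(txt_emb)
    logits = img @ txt.T / temperature                     # (B, B) cosine similarities / T
    idx = np.arange(logits.shape[0])
    loss_i2t = (logsumexp(logits, axis=1) - logits[idx, idx]).mean()  # image -> report
    loss_t2i = (logsumexp(logits, axis=0) - logits[idx, idx]).mean()  # report -> image
    return 0.5 * (loss_i2t + loss_t2i)


rng = np.random.default_rng(0)
crops = rng.standard_normal((8, 4, 768))                  # batch of 8 cases, 4 crops each
image_emb, weights = pool_crops(crops, np.zeros(4))       # zero logits give uniform weights
report_emb = rng.standard_normal((8, 768))                # stand-in for text-encoder outputs
loss = symmetric_info_nce(image_emb, report_emb)
```

In a trainable version, `crop_logits` would be a learned parameter and the Stage 3 objective would add the MAE reconstruction loss with some weighting; both details are assumptions to check against `stage3_lit_multicrop_mae.yaml`.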

Evaluation

Open notebooks/test_stage3_lit_multicrop_mae.ipynb to evaluate:

  • Cross-modal retrieval (R@1, R@5, R@10)
  • Zero-shot miTNM binary classification
  • Learned crop aggregation weights
  • t-SNE embedding visualization
  • Image-to-image retrieval (MAP@5)
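The retrieval metrics above can be computed from a query-by-candidate similarity matrix. This pure-Python sketch assumes that for cross-modal R@K the true match of query i is candidate i (one report per scan), and that for image-to-image MAP@5 a retrieved scan is relevant when it shares a label with the query (for example, the same miTNM stage). Both relevance conventions, and the AP normalisation by min(K, number of relevant items), are assumptions; the notebook may define them differently.

```python
def rank_candidates(sim_row):
    """Candidate indices sorted by decreasing similarity."""
    return sorted(range(len(sim_row)), key=lambda j: sim_row[j], reverse=True)


def recall_at_k(sim, k):
    """Fraction of queries whose true match (candidate i for query i) appears in the top k."""
    hits = sum(1 for i, row in enumerate(sim) if i in rank_candidates(row)[:k])
    return hits / len(sim)


def average_precision_at_k(ranked_labels, query_label, k, n_relevant):
    """AP@k, normalised by min(k, number of relevant items in the gallery)."""
    if n_relevant == 0:
        return 0.0
    hits, score = 0, 0.0
    for rank, label in enumerate(ranked_labels[:k], start=1):
        if label == query_label:
            hits += 1
            score += hits / rank          # precision at this rank
    return score / min(k, n_relevant)


def map_at_k(sim, labels, k=5):
    """Image-to-image MAP@k; each query is excluded from its own ranking."""
    aps = []
    for i, row in enumerate(sim):
        ranked = [j for j in rank_candidates(row) if j != i]
        n_rel = sum(1 for j in ranked if labels[j] == labels[i])
        aps.append(average_precision_at_k([labels[j] for j in ranked], labels[i], k, n_rel))
    return sum(aps) / len(aps)


# Toy 3x3 image-to-report similarity matrix (rows = image queries).
sim = [[0.9, 0.1, 0.3],
       [0.2, 0.4, 0.8],
       [0.1, 0.7, 0.6]]
r1, r2 = recall_at_k(sim, 1), recall_at_k(sim, 2)   # R@1 = 1/3, R@2 = 1.0
```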

Architecture

| Component | Details |
|---|---|
| Vision encoder | ViT3D-B/8 (768-dim, 12 layers, 12 heads) |
| Text encoder | BiomedVLP-CXR-BERT-general |
| Input modalities | Dual-channel PET + CT (96×96×96 crops) |
| Axial crops | 4 (pelvis, abdomen, chest, head/neck) |
| Embedding space | 768-dim (direct alignment, no projection) |
| Parameters | ~85M (vision) + ~109M (text) |
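The token count and the approximate vision-encoder size listed in the Architecture section can be checked with a little arithmetic. The sketch assumes a standard pre-norm ViT-B block (MLP ratio 4, biases on every linear layer, two LayerNorms per block) and counts only the 12 transformer blocks; the patch embedding, position embeddings, and MAE decoder are excluded, so the exact total of the released model will differ slightly.

```python
def vit_block_params(dim=768, mlp_ratio=4):
    """Parameters in one standard pre-norm ViT block (weights + biases)."""
    hidden = dim * mlp_ratio
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)   # qkv projection + output projection
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)   # two linear layers
    norms = 2 * (2 * dim)                                  # two LayerNorms (scale + shift)
    return attn + mlp + norms


crop, patch, channels, depth = 96, 8, 2, 12
tokens_per_crop = (crop // patch) ** 3       # 12 ** 3 = 1728 tokens per 96^3 crop
patch_dim = channels * patch ** 3            # 2 * 512 = 1024 values per token
tokens_per_case = 4 * tokens_per_crop        # 6912, assuming the 4 crops are encoded separately
block_total = depth * vit_block_params()     # 85,054,464, i.e. the ~85M in the table
```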

Checkpoints

After Stage 3 training, checkpoints are saved to outputs/PSMA-FM/stage3_vlp/v1/checkpoints/.

Loading for inference:

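import sys
sys.path.insert(0, "code")  # run from the repository root so that experiments.* is importable

from experiments.vlp.stage3_lit_multicrop_mae import PSMAFMTrainer

# load_from_checkpoint is the PyTorch Lightning classmethod that restores a LightningModule
model = PSMAFMTrainer.load_from_checkpoint('path/to/last.ckpt')
model.eval()  # inference mode: disables dropout and similar training-only behaviour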
