Official implementation of PSMA-FM, a foundation model for whole-body PSMA PET/CT imaging trained via three-stage report-supervised learning.
PSMA-FM is trained in three sequential stages:
- Stage 1 — MAE Pre-training: A 3D Vision Transformer (ViT3D-B/8) is pre-trained via Masked Autoencoding on dual-channel PET/CT volumes.
- Stage 2 — Text Encoder Adaptation: BiomedVLP-CXR-BERT-general is adapted to PSMA radiology reports via LoRA-augmented Masked Language Modeling.
- Stage 3 — Vision-Language Alignment: The vision and text encoders are jointly aligned using multi-crop contrastive learning and auxiliary MAE regularization.
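
To make the Stage 3 objective more concrete, here is a minimal sketch of a symmetric contrastive (InfoNCE-style) loss over paired image and report embeddings. It is an illustration under assumptions, not the repository's loss: the function names, the temperature of 0.07, and the toy 2-dimensional embeddings are hypothetical, and the multi-crop aggregation and the auxiliary MAE term are left out.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def cross_entropy_diag(logits):
    """Mean cross-entropy where the correct class for row i is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        row_max = max(row)  # subtract the max for numerical stability
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_contrastive_loss(image_embs, text_embs, temperature=0.07):
    """Average of image-to-text and text-to-image cross-entropy losses.

    Assumption: image_embs[i] and text_embs[i] come from the same study.
    """
    imgs = [l2_normalize(v) for v in image_embs]
    txts = [l2_normalize(v) for v in text_embs]
    logits = [
        [sum(a * b for a, b in zip(img, txt)) / temperature for txt in txts]
        for img in imgs
    ]
    logits_t = [list(col) for col in zip(*logits)]  # text-to-image direction
    return 0.5 * (cross_entropy_diag(logits) + cross_entropy_diag(logits_t))


# Toy check: correctly paired embeddings should give a lower loss than swapped ones.
aligned = symmetric_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = symmetric_contrastive_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```

In this toy case the loss is close to zero when each image embedding points the same way as its own report embedding and large when the pairs are swapped, which is the behavior the alignment stage relies on.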

Repository structure:

    PSMA-FM/
    ├── code/
    │   ├── configs/vlp/
    │   │   ├── stage1_mae.yaml
    │   │   ├── stage2_text_lora.yaml
    │   │   └── stage3_lit_multicrop_mae.yaml
    │   ├── data/
    │   │   ├── mae_dataset.py
    │   │   ├── text_mlm_dataset.py
    │   │   └── vlp_dataset.py
    │   ├── experiments/vlp/
    │   │   ├── stage1_mae_pretraining.py
    │   │   ├── stage2_text_lora_mlm.py
    │   │   └── stage3_lit_multicrop_mae.py
    │   ├── models/
    │   │   └── vit3d_mae.py
    │   └── optimizers/
    │       └── lr_scheduler.py
    ├── notebooks/
    │   └── test_stage3_lit_multicrop_mae.ipynb
    ├── train.sh
    └── requirements.txt

Set up the environment:

    conda create -n psma-fm python=3.10
    conda activate psma-fm
    pip install -r requirements.txt

The dataset is described by a JSON file with `training` and `validation` splits:

    {
      "training": [
        {
          "case_id": "patient_001",
          "ct_image": "/path/to/ct.nii",
          "pt_image": "/path/to/pet.nii",
          "report": {
            "findings": "...",
            "impression": "...",
            "full_text": "...",
            "mi-T-stage": "T2",
            "mi-N-stage": "N1",
            "mi-M-stage": "M1",
            "miTNM_stage": "T2N1M1"
          }
        }
      ],
      "validation": [...]
    }

Update `json_path` in each YAML config before training.
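
Before launching training it can help to sanity-check the dataset file. The sketch below assumes the JSON layout shown above. The helper name `validate_cases` and the rule that `miTNM_stage` equals the three stage fields joined together are assumptions inferred from the example entry, not documented behavior.

```python
import json

REQUIRED_CASE_KEYS = ("case_id", "ct_image", "pt_image", "report")
REQUIRED_REPORT_KEYS = ("findings", "impression", "full_text")
STAGE_KEYS = ("mi-T-stage", "mi-N-stage", "mi-M-stage")


def validate_cases(cases):
    """Return a list of human-readable problems found in one split."""
    problems = []
    for idx, case in enumerate(cases):
        label = case.get("case_id", f"index {idx}")
        for key in REQUIRED_CASE_KEYS:
            if key not in case:
                problems.append(f"{label}: missing '{key}'")
        report = case.get("report", {})
        for key in REQUIRED_REPORT_KEYS:
            if key not in report:
                problems.append(f"{label}: report missing '{key}'")
        # Assumption: the combined stage is the T, N, and M fields concatenated.
        parts = [report.get(key) for key in STAGE_KEYS]
        combined = report.get("miTNM_stage")
        if combined is not None and all(parts) and "".join(parts) != combined:
            problems.append(f"{label}: miTNM_stage does not match T/N/M fields")
    return problems


# The example entry from this README, inlined so the sketch runs on its own.
example = json.loads("""
{"training": [{"case_id": "patient_001",
               "ct_image": "/path/to/ct.nii",
               "pt_image": "/path/to/pet.nii",
               "report": {"findings": "...", "impression": "...", "full_text": "...",
                          "mi-T-stage": "T2", "mi-N-stage": "N1", "mi-M-stage": "M1",
                          "miTNM_stage": "T2N1M1"}}],
 "validation": []}
""")
training_problems = validate_cases(example["training"])
```

For a real dataset, load the file referenced by `json_path` with `json.load` and run the same check on both splits. Checking that the image paths exist on disk is a natural extension.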

Stage 1 (MAE pre-training):

    sh train.sh code/experiments/vlp/stage1_mae_pretraining.py \
        code/configs/vlp/stage1_mae.yaml

Stage 2 (text encoder adaptation):

    sh train.sh code/experiments/vlp/stage2_text_lora_mlm.py \
        code/configs/vlp/stage2_text_lora.yaml

Stage 3 (vision-language alignment). Update `vision_encoder_ckpt` and `text_encoder_ckpt` in the config, then:

    sh train.sh code/experiments/vlp/stage3_lit_multicrop_mae.py \
        code/configs/vlp/stage3_lit_multicrop_mae.yaml

Open `notebooks/test_stage3_lit_multicrop_mae.ipynb` to evaluate:
- Cross-modal retrieval (R@1, R@5, R@10)
- Zero-shot miTNM binary classification
- Learned crop aggregation weights
- t-SNE embedding visualization
- Image-to-image retrieval (MAP@5)
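
As a reference for the retrieval metrics listed above, here is a minimal sketch of Recall@K computed from an image-to-report similarity matrix. It assumes that report i is the single correct match for image i. The function name and the toy scores are illustrative and are not taken from the notebook.

```python
def recall_at_k(similarity, k):
    """Fraction of queries whose matching item (same index) ranks in the top k."""
    hits = 0
    for query_idx, scores in enumerate(similarity):
        # Rank candidate indices by descending similarity to this query.
        ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        if query_idx in ranked[:k]:
            hits += 1
    return hits / len(similarity)


# Toy 3x3 similarity matrix: rows are images, columns are reports.
toy_similarity = [
    [0.9, 0.1, 0.3],  # image 0: its own report is ranked first
    [0.8, 0.5, 0.2],  # image 1: its own report is ranked second
    [0.7, 0.6, 0.4],  # image 2: its own report is ranked third
]
r_at_1 = recall_at_k(toy_similarity, 1)
r_at_2 = recall_at_k(toy_similarity, 2)
r_at_3 = recall_at_k(toy_similarity, 3)
```

Transposing the matrix gives the report-to-image direction. Retrieval results are often reported for both directions.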
| Component | Details |
|---|---|
| Vision encoder | ViT3D-B/8 (768-dim, 12 layers, 12 heads) |
| Text encoder | BiomedVLP-CXR-BERT-general |
| Input modalities | Dual-channel PET + CT (96×96×96 crops) |
| Axial crops | 4 (pelvis, abdomen, chest, head/neck) |
| Embedding space | 768-dim (direct alignment, no projection) |
| Parameters | ~85M (vision) + ~109M (text) |
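
The table implies a fixed token count per crop, which the following sketch works out and then uses to draw a random MAE mask. The crop size and channel count come from the table, and the patch size of 8 is read from the `/8` suffix following the usual ViT naming. The 75% mask ratio and the helper name `random_mask` are assumptions for illustration, since the mask ratio is not stated here.

```python
import random

CROP_SIZE = 96    # voxels per axis, from the table
PATCH_SIZE = 8    # assumed from the "/8" in ViT3D-B/8
IN_CHANNELS = 2   # PET + CT

patches_per_axis = CROP_SIZE // PATCH_SIZE       # patches along each axis
tokens_per_crop = patches_per_axis ** 3          # patch tokens per 3D crop
values_per_patch = IN_CHANNELS * PATCH_SIZE ** 3 # voxel values flattened per token


def random_mask(num_tokens, mask_ratio, seed=0):
    """Split token indices into sorted (visible, masked) lists for one crop."""
    rng = random.Random(seed)
    order = list(range(num_tokens))
    rng.shuffle(order)
    num_masked = int(num_tokens * mask_ratio)
    return sorted(order[num_masked:]), sorted(order[:num_masked])


# Hypothetical mask ratio; only the visible tokens would be fed to the encoder.
visible, masked = random_mask(tokens_per_crop, mask_ratio=0.75)
```

With these numbers each crop yields 12 patches per axis, 1728 tokens in total, and 1024 voxel values per token. A 75% mask would leave 432 visible tokens for the encoder.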
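
After Stage 3 training, checkpoints are saved to `outputs/PSMA-FM/stage3_vlp/v1/checkpoints/`.

Loading for inference:

    # Assumption: the code/ directory is on PYTHONPATH so that `experiments` is importable.
    from experiments.vlp.stage3_lit_multicrop_mae import PSMAFMTrainer

    # Restore the trained Stage 3 model and switch it to evaluation mode.
    model = PSMAFMTrainer.load_from_checkpoint('path/to/last.ckpt')
    model.eval()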
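
When several checkpoints accumulate in that directory, a small helper can choose which one to load. This sketch prefers `last.ckpt` and otherwise falls back to the most recently modified `.ckpt` file. The helper name `pick_checkpoint` and the fallback rule are assumptions, not part of the repository.

```python
from pathlib import Path


def pick_checkpoint(checkpoint_dir):
    """Return the checkpoint path to load, or None if none is found."""
    directory = Path(checkpoint_dir)
    last = directory / "last.ckpt"
    if last.is_file():
        return last
    # Fall back to the newest checkpoint by modification time.
    candidates = sorted(directory.glob("*.ckpt"), key=lambda p: p.stat().st_mtime)
    return candidates[-1] if candidates else None


# Returns None when the directory does not exist or holds no checkpoints.
ckpt_path = pick_checkpoint("outputs/PSMA-FM/stage3_vlp/v1/checkpoints")
```

If `ckpt_path` is not `None`, `str(ckpt_path)` can be passed to `PSMAFMTrainer.load_from_checkpoint` in place of the hard-coded path.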