Add support for training trackastra with SAM2 features#61

Open
anwai98 wants to merge 11 commits into
weigertlab:mainfrom
anwai98:add-training-support-with-sam2-feats

Conversation

@anwai98

@anwai98 anwai98 commented Apr 1, 2026

Contributor

Hi @C-Achard,

Here are my minimal changes to make training work with SAM2 features.

Let me know how it looks!

PS. In case it helps, here's my yaml config file to train trackastra:

yaml config
# Trackastra finetuning config file for TOIAM dataset (using SAM2 features)
# Run: python /mnt/vast-nhr/home/archit/u12090/trackastra/scripts/train.py -c train_config.yaml

name: toiam_sam2_features
outdir: ./runs

# Data
ndim: 2
input_train:
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/00
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/01
input_val:
  - /mnt/vast-nhr/projects/cidas/cca/data/toiam/data/04
detection_folders:
  - TRA
  - SEG

# Feature backbone (aligned to pretrained model)
features: pretrained_feats_aug
pretrained_feats_model: facebook/sam2.1-hiera-base-plus
pretrained_feats_mode: mean_patches_exact
pretrained_feats_additional_props: regionprops_small
pretrained_n_augs: 15
reduced_pretrained_feat_dim: 128
rotate_features: true

# Finetuning from pretrained
model: /user/archit/u12090/.local/share/trackastra/models/general_2d_w_SAM2_features

# Model architecture (matching pretrained)
d_model: 256
num_encoder_layers: 4
num_decoder_layers: 4
dropout: 0.05
window: 4
attn_dist_mode: v1
causal_norm: none

# Training hyperparameters
epochs: 500
warmup_epochs: 5
train_samples: 32000
batch_size: 16
max_tokens: 2048
weight_decay: 0.01
weight_by_dataset: true

# Augmentation
crop_size:
  - 320
  - 320

# Caching
cachedir: ./runs/.cache

# Logging and other misc. stuff
logger: tensorboard
seed: 42

@anwai98

anwai98 commented Apr 1, 2026

Contributor Author

Poof, some ruff linting structures are funny haha. All should be working now. Lemme know how it looks @C-Achard

@C-Achard C-Achard left a comment

Contributor


Thanks @anwai98, I had a quick look and the approach seems reasonable. If you end up requiring changes on the pretrained_feats repo, I'm happy to have a look as well.

One thing I noticed: in train.py, if no model path is given (training from scratch), it will load the basic model from Trackastra rather than the one from pretrained_feats, because in the inference-only version create() is called only from TrackingTransformer.from_folder. It would then likely crash due to the extra args.
You did mention you wanted to fine-tune only, but it may be best to add some error handling in case anyone tries to train a pretrained_feats model from scratch, since the resulting exception will likely be unclear if no guard is added.
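
A minimal sketch of what such a guard could look like, assuming the check runs on the parsed config before any model is built. The function name and the prefix test on `features` are hypothetical and not taken from train.py; only the `pretrained_feats_aug` value comes from the config posted above.

```python
from __future__ import annotations


def require_checkpoint_for_pretrained_feats(features: str, model_path: str | None) -> None:
    """Fail early when a pretrained-features run has no checkpoint to fine-tune from.

    Hypothetical helper: the name and the prefix test are illustration only.
    """
    # Assumption: feature modes backed by pretrained_feats share this prefix,
    # as in the "pretrained_feats_aug" value from the config in this thread.
    uses_pretrained_feats = features.startswith("pretrained_feats")

    if uses_pretrained_feats and not model_path:
        raise ValueError(
            f"features={features!r} requires a pretrained model to fine-tune from, "
            "but no model path was given. Training a pretrained_feats model from "
            "scratch is not supported; set 'model' to a checkpoint folder."
        )
```

Calling this right after argument parsing would turn the crash from the extra args into a single readable error, while leaving runs that do pass a model path untouched.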

Otherwise, I noticed some slightly misleading help strings in the CLI. Perhaps have a look at the manuscript for better context on what these options do (I added comments on these with recommended defaults).

Finally, if your next step is to train a model, those previous configs may come in handy for that.

I hope this helps. I'm afraid I cannot test this extensively right now, but I'm happy to help further if anything is unclear in the review.

Best,
Cyril

Comment thread scripts/train.py Outdated
Comment thread scripts/train.py Outdated
Comment thread scripts/train.py Outdated
Comment thread scripts/train.py Outdated
Comment thread scripts/train.py
anwai98 and others added 3 commits April 7, 2026 09:18
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
Co-authored-by: Cyril Achard <cyril.achard@epfl.ch>
@anwai98

anwai98 commented Apr 7, 2026

Contributor Author

Hi @C-Achard,

Thank you so much for the detailed feedback. I'll check them out later in the evening and come back to you!

@anwai98 anwai98 requested a review from C-Achard May 14, 2026 18:19
@anwai98

anwai98 commented May 14, 2026

Contributor Author

Hi @C-Achard,

Sorry for the super late follow-up. I only managed to come back to the PR this week, after a couple of busy weeks.

I took care of the comments you left. What do you think about the current state now?

Comment thread trackastra/data/data.py
FeatureExtractor,
WRPretrainedFeatures,
)
if self.pretrained_n_augs != 3:

Contributor


Does it actually generate the augmented copies? That would be important for performance, but it might not be available in the pretrained_feats repo API.
If you're seeing too much overfitting, that would be the first suspect. I can port the augmentation part to pretrained_feats if needed, lmk @anwai98

Contributor Author


Hi @C-Achard,

Your hunch is right, the augmented copies are not being generated in the current state (and I raise a warning in data.py that shows pretrained_n_augs is not yet wired into FeatureExtractor).

Let's do it this way (if you agree): I'll run a quick training and see how the model performs, compared to the general SAM2 model. If I see overfitting, I'll take you up on the offer to port the augmentation API. Thanks! ;)
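
For readers following along, the kind of warning described above could be sketched roughly as follows. This is an illustration under assumptions, not the code in data.py: the helper name and message are hypothetical, and the default of 3 is inferred only from the `!= 3` check visible in the diff excerpt.

```python
import warnings

# Assumption: 3 is the value for which no warning is needed, inferred from the
# `if self.pretrained_n_augs != 3:` line in the diff excerpt.
DEFAULT_N_AUGS = 3


def warn_if_n_augs_ignored(pretrained_n_augs: int, default: int = DEFAULT_N_AUGS) -> bool:
    """Warn when a non-default pretrained_n_augs is requested but has no effect.

    Hypothetical helper. Returns True if a warning was emitted.
    """
    if pretrained_n_augs != default:
        warnings.warn(
            f"pretrained_n_augs={pretrained_n_augs} was requested, but it is not yet "
            "wired into FeatureExtractor, so no augmented copies are generated.",
            UserWarning,
            stacklevel=2,
        )
        return True
    return False
```

With the config from this PR (`pretrained_n_augs: 15`), a check like this would fire once at dataset construction, which makes the missing augmentation visible in the training log.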

@C-Achard

Contributor

Nice, thanks! This looks good as far as I can tell, definitely curious to see how this performs on your data. If it overfits too much I would look into the augmentation API (I can help) but this is likely not needed for a first check.

Let me know if I can help with anything else, thanks again
