Skip to content

Add self-supervised losses#1695

Open
ahmedtaha100 wants to merge 1 commit into
google-deepmind:mainfrom
ahmedtaha100:feat/self-supervised-losses
Open

Add self-supervised losses#1695
ahmedtaha100 wants to merge 1 commit into
google-deepmind:mainfrom
ahmedtaha100:feat/self-supervised-losses

Conversation

@ahmedtaha100

Copy link
Copy Markdown

Closes #1528.

cc @rdyro

What changed

This PR adds four loss functions to optax.losses: byol_loss, simsiam_loss, dino_loss, and barlow_twins_loss.

The code lives in optax.losses._self_supervised. The functions are exported from optax.losses and listed in docs/api/losses.rst.

Behavior

byol_loss, simsiam_loss, and dino_loss return one loss value per example. This matches the usual Optax pattern for losses that do not couple examples. Users can call .mean() when they need one scalar for a batch.

barlow_twins_loss returns one scalar because it uses batch statistics and a cross correlation matrix.

The two view APIs use optional second view arguments instead of a boolean flag. That keeps normal jax.jit use simple because callers do not need a static flag.

The implementation stops gradients on the target or teacher branches inside the loss functions. DINO uses cross view student and teacher pairs. Barlow Twins standardizes each feature with biased batch variance and eps=0.00001, matching the reference batch normalization behavior.

Validation

The functions check shapes, floating dtypes, DINO temperatures, DINO teacher center broadcasting, Barlow Twins rank, and the minimum Barlow Twins batch size.

Tests

The new tests cover numerical properties, seeded random inputs, shape and dtype errors, two view composition, jax.jit, jax.vmap, stop gradient behavior, degenerate inputs, DINO translation invariance, x64 temperature dtype behavior, and Barlow Twins off diagonal scaling.

Verification

python -m pytest optax/losses/_self_supervised_test.py

The private validation branch passed GitHub Actions, including pytest across the configured Python and JAX matrix, ruff, flake8, pylint, pyrefly, pre commit checks, docs, and doctests.

@ahmedtaha100

Copy link
Copy Markdown
Author

@vz415 @emilyfertig @rdyro Hello, gentle ping on this. CI is green and it closes #1528, so it should be ready to review whenever you have bandwidth. Happy to make any changes, and if it's easier to review in smaller pieces I can split it into one PR per loss. Just let me know what works. Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Feature request: add BYOL, SimSiam, DINO and Barlow Twins losses to optax.losses

1 participant