Add self-supervised losses #1695
Open
ahmedtaha100 wants to merge 1 commit into
Author
@vz415 @emilyfertig @rdyro Hello, gentle ping on this. CI is green and it closes #1528, so it should be ready to review whenever you have bandwidth. Happy to make any changes, and if it's easier to review in smaller pieces I can split it into one PR per loss. Just let me know what works. Thank you!
Closes #1528.
cc @rdyro
What changed
This PR adds four loss functions to optax.losses: byol_loss, simsiam_loss, dino_loss, and barlow_twins_loss.
The code lives in optax.losses._self_supervised. The functions are exported from optax.losses and listed in docs/api/losses.rst.
Behavior
byol_loss, simsiam_loss, and dino_loss return one loss value per example. This matches the usual Optax pattern for losses that do not couple examples. Users can call .mean() when they need one scalar for a batch.
barlow_twins_loss returns one scalar because it uses batch statistics and a cross-correlation matrix.
The two-view APIs use optional second-view arguments instead of a boolean flag. That keeps normal jax.jit use simple because callers do not need a static flag.
The implementation stops gradients on the target or teacher branches inside the loss functions. DINO uses cross-view student and teacher pairs. Barlow Twins standardizes each feature with biased batch variance and eps=0.00001, matching the reference batch normalization behavior.
Validation
The functions check shapes, floating dtypes, DINO temperatures, DINO teacher center broadcasting, Barlow Twins rank, and the minimum Barlow Twins batch size.
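For reviewers who want a concrete picture of the Barlow Twins behavior and checks described above, here is a minimal standalone sketch. It is not the code in this PR: the name barlow_twins_sketch, the off_diagonal_scale parameter and its 5e-3 default, the error messages, and the minimum batch size of 2 are all assumptions made for illustration. Only the biased batch variance, eps=0.00001, the rank check, and the scalar return value come from the description.

```python
import jax
import jax.numpy as jnp


def barlow_twins_sketch(z_a, z_b, off_diagonal_scale=5e-3, eps=1e-5):
  """Hypothetical stand-in, not the PR's barlow_twins_loss."""
  # Shapes and dtypes are static under jax.jit, so plain Python checks work.
  if z_a.ndim != 2 or z_a.shape != z_b.shape:
    raise ValueError('expected two arrays of identical shape [batch, features]')
  if not (jnp.issubdtype(z_a.dtype, jnp.floating)
          and jnp.issubdtype(z_b.dtype, jnp.floating)):
    raise ValueError('expected floating point inputs')
  batch_size = z_a.shape[0]
  if batch_size < 2:  # Assumed minimum; the PR does not state the value here.
    raise ValueError('batch statistics need at least two examples')

  def standardize(z):
    mean = jnp.mean(z, axis=0, keepdims=True)
    var = jnp.var(z, axis=0, keepdims=True)  # Biased variance (ddof=0).
    return (z - mean) / jnp.sqrt(var + eps)

  # Cross-correlation matrix between the two views, shape [features, features].
  c = standardize(z_a).T @ standardize(z_b) / batch_size
  diagonal = jnp.diagonal(c)
  on_diagonal = jnp.sum((diagonal - 1.0) ** 2)
  off_diagonal = jnp.sum(c ** 2) - jnp.sum(diagonal ** 2)
  return on_diagonal + off_diagonal_scale * off_diagonal


features = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
loss_value = jax.jit(barlow_twins_sketch)(features, features)  # One scalar.
```

Because the loss couples all examples through the batch mean, variance, and cross-correlation matrix, a per-example value is not well defined, which is why this sketch returns a single scalar.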
Tests
The new tests cover numerical properties, seeded random inputs, shape and dtype errors, two-view composition, jax.jit, jax.vmap, stop-gradient behavior, degenerate inputs, DINO translation invariance, x64 temperature dtype behavior, and Barlow Twins off-diagonal scaling.
Verification
python -m pytest optax/losses/_self_supervised_test.py
The private validation branch passed GitHub Actions, including pytest across the configured Python and JAX matrix, ruff, flake8, pylint, pyrefly, pre-commit checks, docs, and doctests.
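To make the per-example, two-view, and stop-gradient properties easier to reason about during review, here is a rough standalone sketch of a BYOL-style loss with those properties. It does not reproduce the PR's signatures: byol_style_loss, its argument names, and the normalization eps are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def _l2_normalize(x, eps=1e-12):
  norm = jnp.linalg.norm(x, axis=-1, keepdims=True)
  return x / jnp.maximum(norm, eps)


def _one_direction(predictions, targets):
  # The target branch is treated as a constant inside the loss.
  targets = jax.lax.stop_gradient(targets)
  cosine = jnp.sum(_l2_normalize(predictions) * _l2_normalize(targets), axis=-1)
  return 2.0 - 2.0 * cosine  # One value per example.


def byol_style_loss(predictions, targets, predictions_2=None, targets_2=None):
  """Hypothetical stand-in, not the PR's byol_loss."""
  loss = _one_direction(predictions, targets)
  # Optional second view: None is resolved at trace time, so jax.jit needs
  # no static boolean flag.
  if predictions_2 is not None and targets_2 is not None:
    loss = loss + _one_direction(predictions_2, targets_2)
  return loss


key_p, key_t = jax.random.split(jax.random.PRNGKey(0))
p = jax.random.normal(key_p, (8, 16))
t = jax.random.normal(key_t, (8, 16))

per_example = jax.jit(byol_style_loss)(p, t)      # Shape (8,).
batch_scalar = per_example.mean()                 # Caller reduces if needed.
two_view = jax.jit(byol_style_loss)(p, t, t, p)   # Symmetric two-view form.
target_grad = jax.grad(lambda tt: byol_style_loss(p, tt).mean())(t)
```

In this sketch target_grad is all zeros, because the gradient is stopped on the target branch inside the loss rather than by the caller.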