From 44196ca4fef1120020d54db5e104e24c87ae2a72 Mon Sep 17 00:00:00 2001 From: Ahmed Taha <74939755+ahmedtaha100@users.noreply.github.com> Date: Tue, 9 Jun 2026 17:45:30 -0400 Subject: [PATCH] Add self-supervised losses --- docs/api/losses.rst | 4 + optax/losses/__init__.py | 4 + optax/losses/_self_supervised.py | 525 +++++++++++++++++++++++++ optax/losses/_self_supervised_test.py | 540 ++++++++++++++++++++++++++ 4 files changed, 1073 insertions(+) diff --git a/docs/api/losses.rst b/docs/api/losses.rst index fcef7b72d..1f3dbd684 100644 --- a/docs/api/losses.rst +++ b/docs/api/losses.rst @@ -6,13 +6,16 @@ Losses .. autosummary:: :toctree: generated/ + barlow_twins_loss binary_dice_loss + byol_loss convex_kl_divergence cosine_distance cosine_similarity ctc_loss ctc_loss_with_forward_probs dice_loss + dino_loss generalized_kl_divergence hinge_loss huber_loss @@ -32,6 +35,7 @@ Losses safe_softmax_cross_entropy sigmoid_binary_cross_entropy sigmoid_focal_loss + simsiam_loss smooth_labels softmax_cross_entropy softmax_cross_entropy_with_integer_labels diff --git a/optax/losses/__init__.py b/optax/losses/__init__.py index b124b5f10..5c789abaa 100644 --- a/optax/losses/__init__.py +++ b/optax/losses/__init__.py @@ -47,6 +47,10 @@ from optax.losses._segmentation import binary_dice_loss from optax.losses._segmentation import dice_loss from optax.losses._segmentation import multiclass_generalized_dice_loss +from optax.losses._self_supervised import barlow_twins_loss +from optax.losses._self_supervised import byol_loss +from optax.losses._self_supervised import dino_loss from optax.losses._self_supervised import ntxent +from optax.losses._self_supervised import simsiam_loss from optax.losses._self_supervised import triplet_margin_loss from optax.losses._smoothing import smooth_labels diff --git a/optax/losses/_self_supervised.py b/optax/losses/_self_supervised.py index 269a829c2..2208f1f00 100644 --- a/optax/losses/_self_supervised.py +++ b/optax/losses/_self_supervised.py @@ -14,6 +14,8 @@ # ============================================================================== """Self supervised losses.""" +from typing import Optional + import jax from jax import lax import jax.numpy as jnp @@ -195,3 +197,526 @@ def triplet_margin_loss( ) loss = jnp.maximum(positive_distance - negative_distance + margin, 0) return loss + + +def _check_optional_pair( + first: Optional[jax.typing.ArrayLike], + second: Optional[jax.typing.ArrayLike], + first_name: str, + second_name: str, +) -> bool: + """Returns whether an optional pair is present, validating completeness.""" + if (first is None) != (second is None): + raise ValueError( + f'`{first_name}` and `{second_name}` must either both be provided or ' + 'both be None.' + ) + return first is not None + + +def _check_same_shapes( + reference: jax.typing.ArrayLike, + reference_name: str, + *others: tuple[jax.typing.ArrayLike, str], +): + """Checks that all arrays have the same shape as a reference array.""" + reference_shape = jnp.shape(reference) + for array, name in others: + array_shape = jnp.shape(array) + if reference_shape != array_shape: + raise ValueError( + f'`{reference_name}` and `{name}` must have the same shape, found ' + f'{reference_shape} and {array_shape}.' + ) + + +def _negative_cosine_similarity( + predictions: jax.typing.ArrayLike, + targets: jax.typing.ArrayLike, + eps: jax.typing.ArrayLike, +) -> jax.Array: + """Computes negative cosine similarity with stopped target gradients.""" + targets = lax.stop_gradient(targets) + return -_regression.cosine_similarity( + predictions, targets, epsilon=eps, axis=-1 + ) + + +def byol_loss( + online_prediction_1: jax.typing.ArrayLike, + target_projection_2: jax.typing.ArrayLike, + online_prediction_2: Optional[jax.typing.ArrayLike] = None, + target_projection_1: Optional[jax.typing.ArrayLike] = None, + *, + eps: jax.typing.ArrayLike = 1e-6, +) -> jax.Array: + r"""Computes the Bootstrap Your Own Latent (BYOL) loss. + + BYOL regresses online-network predictions toward target-network projections + computed from another augmented view of the same examples. Targets are + treated as stop-gradient values inside this loss. + + For one direction, with online prediction :math:`q` and target projection + :math:`z`, this function computes the squared distance between + :math:`\ell_2`-normalized vectors, + + .. math:: + D(q, z) = \|\bar{q} - \bar{z}\|_2^2 = 2 - 2\cos(q, z). + + If `online_prediction_2` and `target_projection_1` are both provided, the + function returns the symmetric two-view BYOL objective: + + .. math:: + \frac{1}{2}\left(D(q_1, z_2) + D(q_2, z_1)\right). + + .. note:: + The BYOL paper minimizes the *sum* of the two directions, + :math:`D(q_1, z_2) + D(q_2, z_1)`. This function averages them instead, + for consistency with :func:`simsiam_loss + `; multiply the result by two to recover the + paper's objective. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> online_prediction = jnp.array([[1.0, 0.0], [0.0, 1.0]]) + >>> target_projection = jnp.array([[0.0, 1.0], [1.0, 0.0]]) + >>> print(optax.losses.byol_loss(online_prediction, target_projection)) + [2. 2.] + >>> print(optax.losses.byol_loss(online_prediction, online_prediction)) + [0. 0.] + + Args: + online_prediction_1: Online-network prediction for view 1, with shape + `[..., feature_dim]`. + target_projection_2: Target-network projection for view 2, with the same + shape as `online_prediction_1`. Gradients are stopped through this + argument. + online_prediction_2: Optional online-network prediction for view 2. If + provided, `target_projection_1` must also be provided. + target_projection_1: Optional target-network projection for view 1. + Gradients are stopped through this argument. If provided, + `online_prediction_2` must also be provided. + eps: Minimum squared norm enforced in the cosine-similarity denominator, + so the effective minimum norm is `sqrt(eps)`. + + Returns: + BYOL loss values for each example, with shape `[...]`. When both views + are provided, each value averages the two directions. Take the mean for + a scalar batch loss. + + References: + Grill et al, `Bootstrap Your Own Latent: A New Approach to + Self-Supervised Learning `_, 2020. + + .. versionadded:: 0.2.9 + """ + online_prediction_1 = jnp.asarray(online_prediction_1) + target_projection_2 = jnp.asarray(target_projection_2) + utils.check_subdtype(online_prediction_1, jnp.floating) + utils.check_subdtype(target_projection_2, jnp.floating) + _check_same_shapes( + online_prediction_1, + 'online_prediction_1', + (target_projection_2, 'target_projection_2'), + ) + + eps = jnp.asarray(eps, dtype=online_prediction_1.dtype) + loss_12 = 2.0 + 2.0 * _negative_cosine_similarity( + online_prediction_1, target_projection_2, eps + ) + + if not _check_optional_pair( + online_prediction_2, + target_projection_1, + 'online_prediction_2', + 'target_projection_1', + ): + return loss_12 + + online_prediction_2 = jnp.asarray(online_prediction_2) + target_projection_1 = jnp.asarray(target_projection_1) + utils.check_subdtype(online_prediction_2, jnp.floating) + utils.check_subdtype(target_projection_1, jnp.floating) + _check_same_shapes( + online_prediction_1, + 'online_prediction_1', + (online_prediction_2, 'online_prediction_2'), + (target_projection_1, 'target_projection_1'), + ) + + loss_21 = 2.0 + 2.0 * _negative_cosine_similarity( + online_prediction_2, target_projection_1, eps + ) + return 0.5 * (loss_12 + loss_21) + + +def simsiam_loss( + prediction_1: jax.typing.ArrayLike, + target_projection_2: jax.typing.ArrayLike, + prediction_2: Optional[jax.typing.ArrayLike] = None, + target_projection_1: Optional[jax.typing.ArrayLike] = None, + *, + eps: jax.typing.ArrayLike = 1e-6, +) -> jax.Array: + r"""Computes the SimSiam negative cosine similarity loss. + + SimSiam compares a predictor output (prediction) computed from one + augmented view with a stop-gradient projection computed from another view: + + .. math:: + D(p, z) = -\cos(p, \operatorname{stop\_gradient}(z)). + + If `prediction_2` and `target_projection_1` are both provided, the function + returns the symmetric two-view objective of the paper: + + .. math:: + \frac{1}{2} D(p_1, z_2) + \frac{1}{2} D(p_2, z_1). + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> prediction = jnp.array([[1.0, 0.0], [0.0, 1.0]]) + >>> print(optax.losses.simsiam_loss(prediction, prediction)) + [-1. -1.] + + Args: + prediction_1: Predictor output for view 1, with shape + `[..., feature_dim]`. + target_projection_2: Projection for view 2, used as the regression + target, with the same shape as `prediction_1`. Gradients are stopped + through this argument. + prediction_2: Optional predictor output for view 2. If provided, + `target_projection_1` must also be provided. + target_projection_1: Optional projection for view 1, used as the + regression target. Gradients are stopped through this argument. If + provided, `prediction_2` must also be provided. + eps: Minimum squared norm enforced in the cosine-similarity denominator, + so the effective minimum norm is `sqrt(eps)`. + + Returns: + SimSiam loss values for each example, with shape `[...]`. When both views + are provided, each value averages the two directions. Take the mean for a + scalar batch loss. + + References: + Chen and He, `Exploring Simple Siamese Representation Learning + `_, 2021. + + .. versionadded:: 0.2.9 + """ + prediction_1 = jnp.asarray(prediction_1) + target_projection_2 = jnp.asarray(target_projection_2) + utils.check_subdtype(prediction_1, jnp.floating) + utils.check_subdtype(target_projection_2, jnp.floating) + _check_same_shapes( + prediction_1, + 'prediction_1', + (target_projection_2, 'target_projection_2'), + ) + + eps = jnp.asarray(eps, dtype=prediction_1.dtype) + loss_12 = _negative_cosine_similarity( + prediction_1, target_projection_2, eps + ) + + if not _check_optional_pair( + prediction_2, + target_projection_1, + 'prediction_2', + 'target_projection_1', + ): + return loss_12 + + prediction_2 = jnp.asarray(prediction_2) + target_projection_1 = jnp.asarray(target_projection_1) + utils.check_subdtype(prediction_2, jnp.floating) + utils.check_subdtype(target_projection_1, jnp.floating) + _check_same_shapes( + prediction_1, + 'prediction_1', + (prediction_2, 'prediction_2'), + (target_projection_1, 'target_projection_1'), + ) + + loss_21 = _negative_cosine_similarity( + prediction_2, target_projection_1, eps + ) + return 0.5 * (loss_12 + loss_21) + + +def _positive_temperature( + temperature: jax.typing.ArrayLike, name: str +) -> jax.Array: + """Validates that a temperature is a scalar and, when concrete, positive.""" + temperature = jnp.asarray(temperature) + if jnp.ndim(temperature) != 0: + raise ValueError(f'`{name}` must be a scalar.') + try: + is_positive = bool(temperature > 0) + except jax.errors.TracerBoolConversionError: + # Traced temperatures cannot be validated at trace time. + return temperature + if not is_positive: + raise ValueError(f'`{name}` must be positive.') + return temperature + + +def _single_view_dino_loss( + student_logits: jax.Array, + teacher_logits: jax.Array, + student_temperature: jax.Array, + teacher_temperature: jax.Array, + teacher_center: jax.typing.ArrayLike, +) -> jax.Array: + """Computes one DINO teacher-to-student cross entropy term per example.""" + teacher_center = jnp.asarray(teacher_center) + utils.check_subdtype(teacher_center, jnp.floating) + teacher_center = teacher_center.astype(teacher_logits.dtype) + student_temperature = student_temperature.astype(student_logits.dtype) + teacher_temperature = teacher_temperature.astype(teacher_logits.dtype) + logits_shape = jnp.shape(teacher_logits) + center_error = ValueError( + '`teacher_center` must be broadcastable to the teacher logits shape, ' + f'found {jnp.shape(teacher_center)} and {logits_shape}.' + ) + try: + broadcast_shape = jnp.broadcast_shapes( + jnp.shape(teacher_center), logits_shape + ) + except ValueError as error: + raise center_error from error + if broadcast_shape != logits_shape: + raise center_error + + teacher_probs = lax.stop_gradient( + jax.nn.softmax( + (teacher_logits - teacher_center) / teacher_temperature, axis=-1 + ) + ) + student_log_probs = jax.nn.log_softmax( + student_logits / student_temperature, axis=-1 + ) + return -jnp.sum(teacher_probs * student_log_probs, axis=-1) + + +def dino_loss( + student_logits_1: jax.typing.ArrayLike, + teacher_logits_2: jax.typing.ArrayLike, + student_logits_2: Optional[jax.typing.ArrayLike] = None, + teacher_logits_1: Optional[jax.typing.ArrayLike] = None, + *, + student_temperature: jax.typing.ArrayLike = 0.1, + teacher_temperature: jax.typing.ArrayLike = 0.04, + teacher_center: jax.typing.ArrayLike = 0.0, +) -> jax.Array: + r"""Computes the DINO self-distillation loss. + + DINO trains a student distribution to match a centered and sharpened + teacher distribution computed from a different augmented view of the same + examples. The teacher branch is treated as stop-gradient. + + For one (student, teacher) pair of views, this function computes the cross + entropy + + .. math:: + -\sum_k p_t(k)\log p_s(k), + + where :math:`p_t` is `softmax((teacher_logits - teacher_center) / + teacher_temperature)` and :math:`p_s` is `softmax(student_logits / + student_temperature)`. + + If `student_logits_2` and `teacher_logits_1` are both provided, the + function returns the two-view specialization of the DINO multi-crop + objective, averaging the term for `student_logits_1` against + `teacher_logits_2` with the term for `student_logits_2` against + `teacher_logits_1`. DINO never compares student and teacher outputs of the + same view, which is why the argument indices are cross-paired, as in + :func:`byol_loss `. For more than two crops, call + the single-pair form once per (student view, teacher view) pair of + distinct views and average the results. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> student_logits = jnp.zeros((2, 4)) + >>> teacher_logits = jnp.zeros((2, 4)) + >>> loss = optax.losses.dino_loss(student_logits, teacher_logits) + >>> print(loss) # cross entropy of uniform distributions, i.e. log(4) + [1.3862944 1.3862944] + + Args: + student_logits_1: Student logits for view 1, with shape + `[..., num_classes]`. + teacher_logits_2: Teacher logits for view 2, with the same shape as + `student_logits_1`. Gradients are stopped through this argument. + student_logits_2: Optional student logits for view 2. If provided, + `teacher_logits_1` must also be provided. + teacher_logits_1: Optional teacher logits for view 1. Gradients are + stopped through this argument. If provided, `student_logits_2` must + also be provided. + student_temperature: Positive temperature for the student softmax. The + paper uses 0.1. Positivity can only be validated for concrete + (non-traced) values. + teacher_temperature: Positive temperature for the teacher softmax. The + paper warms this up from 0.04 to 0.07; pass the current schedule value. + Positivity can only be validated for concrete (non-traced) values. + teacher_center: Centering term subtracted from teacher logits before + temperature scaling. Must be broadcastable to the teacher logits shape, + and is cast to the teacher logits dtype. The paper maintains this + center as an exponential moving average of teacher outputs across + batches; the caller is responsible for updating it between steps. The + default 0.0 applies no centering. + + Returns: + DINO cross-entropy loss values for each example, with shape `[...]`. + When both views are provided, each value averages the two cross-view + terms. Take the mean for a scalar batch loss. + + References: + Caron et al, `Emerging Properties in Self-Supervised Vision Transformers + `_, 2021. + + .. versionadded:: 0.2.9 + """ + student_temperature = _positive_temperature( + student_temperature, 'student_temperature' + ) + teacher_temperature = _positive_temperature( + teacher_temperature, 'teacher_temperature' + ) + + student_logits_1 = jnp.asarray(student_logits_1) + teacher_logits_2 = jnp.asarray(teacher_logits_2) + utils.check_subdtype(student_logits_1, jnp.floating) + utils.check_subdtype(teacher_logits_2, jnp.floating) + _check_same_shapes( + student_logits_1, + 'student_logits_1', + (teacher_logits_2, 'teacher_logits_2'), + ) + + loss_12 = _single_view_dino_loss( + student_logits_1, + teacher_logits_2, + student_temperature, + teacher_temperature, + teacher_center, + ) + + if not _check_optional_pair( + student_logits_2, teacher_logits_1, 'student_logits_2', 'teacher_logits_1' + ): + return loss_12 + + student_logits_2 = jnp.asarray(student_logits_2) + teacher_logits_1 = jnp.asarray(teacher_logits_1) + utils.check_subdtype(student_logits_2, jnp.floating) + utils.check_subdtype(teacher_logits_1, jnp.floating) + _check_same_shapes( + student_logits_1, + 'student_logits_1', + (student_logits_2, 'student_logits_2'), + (teacher_logits_1, 'teacher_logits_1'), + ) + + loss_21 = _single_view_dino_loss( + student_logits_2, + teacher_logits_1, + student_temperature, + teacher_temperature, + teacher_center, + ) + return 0.5 * (loss_12 + loss_21) + + +def barlow_twins_loss( + projection_1: jax.typing.ArrayLike, + projection_2: jax.typing.ArrayLike, + *, + off_diagonal_scale: jax.typing.ArrayLike = 5e-3, + eps: jax.typing.ArrayLike = 1e-5, +) -> jax.Array: + r"""Computes the Barlow Twins redundancy-reduction loss. + + Barlow Twins compares two batches of projections for paired augmented views. + Each feature is standardized across the batch dimension (using the biased, + `1/batch_size`, variance, as in the batch-normalization-based reference + implementation), then a feature cross-correlation matrix :math:`C` is + computed. The objective pushes diagonal entries toward one and off-diagonal + entries toward zero: + + .. math:: + \sum_i (1 - C_{ii})^2 + + \lambda \sum_i \sum_{j \ne i} C_{ij}^2. + + This loss couples examples through batch statistics, so it returns a single + scalar and requires at least two examples. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> projections = jnp.array([[1.0, 1.0, 1.0], + ... [-1.0, 1.0, -1.0], + ... [1.0, -1.0, -1.0], + ... [-1.0, -1.0, 1.0]]) + >>> loss = optax.losses.barlow_twins_loss(projections, projections) + >>> print(f'{loss:.4f}') + 0.0000 + + Args: + projection_1: Rank-2 array of projections for view 1, with shape + `[batch_size, feature_dim]` and `batch_size >= 2`. + projection_2: Rank-2 array of projections for view 2, with the same shape + as `projection_1`. + off_diagonal_scale: Multiplicative scale :math:`\lambda` for + off-diagonal terms. + eps: Small value added to per-feature variances before taking square + roots. The default matches the batch-normalization epsilon used by the + reference implementation. + + Returns: + Scalar Barlow Twins loss. + + References: + Zbontar et al, `Barlow Twins: Self-Supervised Learning via Redundancy + Reduction `_, 2021. + + .. versionadded:: 0.2.9 + """ + projection_1 = jnp.asarray(projection_1) + projection_2 = jnp.asarray(projection_2) + utils.check_subdtype(projection_1, jnp.floating) + utils.check_subdtype(projection_2, jnp.floating) + _check_same_shapes( + projection_1, 'projection_1', (projection_2, 'projection_2') + ) + utils.check_rank(projection_1, 2) + + batch_size = projection_1.shape[0] + if batch_size < 2: + raise ValueError( + '`projection_1` and `projection_2` must contain at least two examples' + f' to compute per-feature statistics, found batch size {batch_size}.' + ) + + eps = jnp.asarray(eps, dtype=projection_1.dtype) + off_diagonal_scale = jnp.asarray(off_diagonal_scale, dtype=projection_1.dtype) + + projection_1 = projection_1 - jnp.mean(projection_1, axis=0, keepdims=True) + projection_2 = projection_2 - jnp.mean(projection_2, axis=0, keepdims=True) + projection_1 = projection_1 / jnp.sqrt( + jnp.mean(projection_1**2, axis=0, keepdims=True) + eps + ) + projection_2 = projection_2 / jnp.sqrt( + jnp.mean(projection_2**2, axis=0, keepdims=True) + eps + ) + + cross_correlation = (projection_1.T @ projection_2) / batch_size + on_diagonal = jnp.diag(cross_correlation) + on_diagonal_loss = jnp.sum((1.0 - on_diagonal) ** 2) + + off_diagonal_loss = jnp.maximum( + 0.0, jnp.sum(cross_correlation**2) - jnp.sum(on_diagonal**2) + ) + return on_diagonal_loss + off_diagonal_scale * off_diagonal_loss diff --git a/optax/losses/_self_supervised_test.py b/optax/losses/_self_supervised_test.py index a4238d527..a1c293cf1 100644 --- a/optax/losses/_self_supervised_test.py +++ b/optax/losses/_self_supervised_test.py @@ -14,12 +14,15 @@ # ============================================================================== """Tests for self-supervised losses in `optax.losses._self_supervised.py`.""" +import functools + from absl.testing import absltest from absl.testing import parameterized import jax import jax.numpy as jnp import numpy as np +from optax._src import utils from optax.losses import _self_supervised @@ -125,5 +128,542 @@ def test_vmap(self, anchor, positive, negative): , atol=1e-4) +class ByolLossTest(parameterized.TestCase): + + @parameterized.parameters(0, 1, 42) + def test_single_direction_zero_for_identical_inputs(self, seed): + key = jax.random.key(seed) + projections = jax.random.normal(key, (8, 16), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.byol_loss)(projections, projections) + + self.assertEqual(loss.shape, (8,)) + self.assertTrue(np.all(np.isfinite(loss))) + np.testing.assert_allclose(loss, jnp.zeros(8), atol=1e-5) + + @parameterized.parameters(0, 1, 42) + def test_single_direction_nonzero_for_opposite_inputs(self, seed): + key = jax.random.key(seed) + projections = jax.random.normal(key, (8, 16), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.byol_loss)(projections, -projections) + + self.assertTrue(np.all(np.isfinite(loss))) + self.assertGreater(float(loss.min()), 1.0) + + @parameterized.parameters(0, 1, 42) + def test_single_direction_invariant_to_positive_scaling(self, seed): + key = jax.random.key(seed) + key1, key2 = jax.random.split(key) + online = jax.random.normal(key1, (8, 16), dtype=jnp.float32) + target = jax.random.normal(key2, (8, 16), dtype=jnp.float32) + byol_jit = jax.jit(_self_supervised.byol_loss) + + base = byol_jit(online, target) + scaled = byol_jit(3.0 * online, 0.5 * target) + + np.testing.assert_allclose(base, scaled, atol=1e-5) + + @parameterized.parameters(0, 1, 42) + def test_two_view_equals_average_of_single_direction_calls(self, seed): + key = jax.random.key(seed) + key1, key2, key3, key4 = jax.random.split(key, 4) + online_1 = jax.random.normal(key1, (8, 16), dtype=jnp.float32) + target_2 = jax.random.normal(key2, (8, 16), dtype=jnp.float32) + online_2 = jax.random.normal(key3, (8, 16), dtype=jnp.float32) + target_1 = jax.random.normal(key4, (8, 16), dtype=jnp.float32) + byol_jit = jax.jit(_self_supervised.byol_loss) + + two_view = byol_jit(online_1, target_2, online_2, target_1) + expected = 0.5 * ( + byol_jit(online_1, target_2) + byol_jit(online_2, target_1) + ) + + np.testing.assert_allclose(two_view, expected, atol=1e-6) + + @parameterized.parameters(0, 1, 42) + def test_equivariant_to_batch_permutation(self, seed): + key = jax.random.key(seed) + key1, key2, key3 = jax.random.split(key, 3) + online = jax.random.normal(key1, (10, 7), dtype=jnp.float32) + target = jax.random.normal(key2, (10, 7), dtype=jnp.float32) + permutation = jax.random.permutation(key3, online.shape[0]) + byol_jit = jax.jit(_self_supervised.byol_loss) + + base = byol_jit(online, target) + shuffled = byol_jit(online[permutation], target[permutation]) + + np.testing.assert_allclose(shuffled, base[permutation], atol=1e-5) + + def test_stops_gradient_through_target_projection(self): + key1, key2 = jax.random.split(jax.random.key(0)) + online = jax.random.normal(key1, (4, 5), dtype=jnp.float32) + target = jax.random.normal(key2, (4, 5), dtype=jnp.float32) + + grad_online, grad_target = jax.grad( + lambda o, t: _self_supervised.byol_loss(o, t).mean(), argnums=(0, 1) + )(online, target) + + self.assertGreater(float(jnp.linalg.norm(grad_online)), 0.0) + np.testing.assert_allclose(grad_target, jnp.zeros_like(target)) + + def test_finite_value_and_gradient_for_zero_norm_inputs(self): + projections = jnp.zeros((4, 5), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.byol_loss)(projections, projections) + grad_fn = jax.grad( + lambda o: _self_supervised.byol_loss(o, projections).sum() + ) + grads = grad_fn(projections) + + self.assertTrue(np.all(np.isfinite(loss))) + self.assertTrue(np.all(np.isfinite(grads))) + + def test_raises_for_incomplete_or_mismatched_two_view_inputs(self): + projections = jnp.ones((4, 5), dtype=jnp.float32) + with self.assertRaises(ValueError): + _self_supervised.byol_loss( + projections, projections, online_prediction_2=projections + ) + with self.assertRaises(ValueError): + _self_supervised.byol_loss( + projections, projections, projections[:3], projections[:3] + ) + + def test_vmap(self): + key1, key2 = jax.random.split(jax.random.key(0)) + online = jax.random.normal(key1, (3, 8, 16), dtype=jnp.float32) + target = jax.random.normal(key2, (3, 8, 16), dtype=jnp.float32) + + vmapped = jax.jit(jax.vmap(_self_supervised.byol_loss))(online, target) + expected = jnp.stack([ + _self_supervised.byol_loss(online[i], target[i]) + for i in range(online.shape[0]) + ]) + + np.testing.assert_allclose(vmapped, expected, atol=1e-6) + + +class SimSiamLossTest(parameterized.TestCase): + + @parameterized.parameters(0, 1, 42) + def test_single_direction_minus_one_for_identical_inputs(self, seed): + key = jax.random.key(seed) + projections = jax.random.normal(key, (8, 16), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.simsiam_loss)(projections, projections) + + self.assertEqual(loss.shape, (8,)) + self.assertTrue(np.all(np.isfinite(loss))) + np.testing.assert_allclose(loss, -jnp.ones(8), atol=1e-5) + + @parameterized.parameters(0, 1, 42) + def test_single_direction_invariant_to_positive_scaling(self, seed): + key = jax.random.key(seed) + key1, key2 = jax.random.split(key) + prediction = jax.random.normal(key1, (8, 16), dtype=jnp.float32) + target = jax.random.normal(key2, (8, 16), dtype=jnp.float32) + simsiam_jit = jax.jit(_self_supervised.simsiam_loss) + + base = simsiam_jit(prediction, target) + scaled = simsiam_jit(2.0 * prediction, 0.25 * target) + + np.testing.assert_allclose(base, scaled, atol=1e-5) + + @parameterized.parameters(0, 1, 42) + def test_two_view_equals_average_of_single_direction_calls(self, seed): + key = jax.random.key(seed) + key1, key2, key3, key4 = jax.random.split(key, 4) + prediction_1 = jax.random.normal(key1, (8, 16), dtype=jnp.float32) + target_2 = jax.random.normal(key2, (8, 16), dtype=jnp.float32) + prediction_2 = jax.random.normal(key3, (8, 16), dtype=jnp.float32) + target_1 = jax.random.normal(key4, (8, 16), dtype=jnp.float32) + simsiam_jit = jax.jit(_self_supervised.simsiam_loss) + + two_view = simsiam_jit(prediction_1, target_2, prediction_2, target_1) + expected = 0.5 * ( + simsiam_jit(prediction_1, target_2) + + simsiam_jit(prediction_2, target_1) + ) + + np.testing.assert_allclose(two_view, expected, atol=1e-6) + + @parameterized.parameters(0, 1, 42) + def test_equivariant_to_batch_permutation(self, seed): + key = jax.random.key(seed) + key1, key2, key3 = jax.random.split(key, 3) + prediction = jax.random.normal(key1, (10, 7), dtype=jnp.float32) + target = jax.random.normal(key2, (10, 7), dtype=jnp.float32) + permutation = jax.random.permutation(key3, prediction.shape[0]) + simsiam_jit = jax.jit(_self_supervised.simsiam_loss) + + base = simsiam_jit(prediction, target) + shuffled = simsiam_jit(prediction[permutation], target[permutation]) + + np.testing.assert_allclose(shuffled, base[permutation], atol=1e-5) + + def test_stops_gradient_through_target_projection(self): + key1, key2 = jax.random.split(jax.random.key(0)) + prediction = jax.random.normal(key1, (4, 5), dtype=jnp.float32) + target = jax.random.normal(key2, (4, 5), dtype=jnp.float32) + + grad_prediction, grad_target = jax.grad( + lambda p, t: _self_supervised.simsiam_loss(p, t).mean(), argnums=(0, 1) + )(prediction, target) + + self.assertGreater(float(jnp.linalg.norm(grad_prediction)), 0.0) + np.testing.assert_allclose(grad_target, jnp.zeros_like(target)) + + def test_finite_value_and_gradient_for_zero_norm_inputs(self): + projections = jnp.zeros((4, 5), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.simsiam_loss)(projections, projections) + grad_fn = jax.grad( + lambda p: _self_supervised.simsiam_loss(p, projections).sum() + ) + grads = grad_fn(projections) + + self.assertTrue(np.all(np.isfinite(loss))) + self.assertTrue(np.all(np.isfinite(grads))) + + def test_raises_for_incomplete_or_mismatched_two_view_inputs(self): + projections = jnp.ones((4, 5), dtype=jnp.float32) + with self.assertRaises(ValueError): + _self_supervised.simsiam_loss( + projections, projections, prediction_2=projections + ) + with self.assertRaises(ValueError): + _self_supervised.simsiam_loss( + projections, projections, projections[:3], projections[:3] + ) + + def test_vmap(self): + key1, key2 = jax.random.split(jax.random.key(0)) + prediction = jax.random.normal(key1, (3, 8, 16), dtype=jnp.float32) + target = jax.random.normal(key2, (3, 8, 16), dtype=jnp.float32) + + vmapped = jax.jit(jax.vmap(_self_supervised.simsiam_loss))( + prediction, target + ) + expected = jnp.stack([ + _self_supervised.simsiam_loss(prediction[i], target[i]) + for i in range(prediction.shape[0]) + ]) + + np.testing.assert_allclose(vmapped, expected, atol=1e-6) + + +class DinoLossTest(parameterized.TestCase): + + @parameterized.parameters(0, 1, 42) + def test_single_pair_invariant_to_logit_translation(self, seed): + key = jax.random.key(seed) + key1, key2 = jax.random.split(key) + student_logits = jax.random.normal(key1, (4, 6), dtype=jnp.float32) + teacher_logits = jax.random.normal(key2, (4, 6), dtype=jnp.float32) + dino_jit = jax.jit(_self_supervised.dino_loss) + + base = dino_jit(student_logits, teacher_logits) + translated = dino_jit(student_logits + 7.0, teacher_logits + 7.0) + + self.assertEqual(base.shape, (4,)) + np.testing.assert_allclose(base, translated, atol=1e-5) + + @parameterized.parameters(0, 1, 42) + def test_two_view_equals_average_of_cross_view_single_calls(self, seed): + key = jax.random.key(seed) + key1, key2, key3, key4, key5 = jax.random.split(key, 5) + student_1 = jax.random.normal(key1, (4, 6), dtype=jnp.float32) + teacher_1 = jax.random.normal(key2, (4, 6), dtype=jnp.float32) + student_2 = jax.random.normal(key3, (4, 6), dtype=jnp.float32) + teacher_2 = jax.random.normal(key4, (4, 6), dtype=jnp.float32) + center = jax.random.normal(key5, (6,), dtype=jnp.float32) + dino_jit = jax.jit(_self_supervised.dino_loss) + + two_view = dino_jit( + student_1, teacher_2, student_2, teacher_1, teacher_center=center + ) + expected = 0.5 * ( + dino_jit(student_1, teacher_2, teacher_center=center) + + dino_jit(student_2, teacher_1, teacher_center=center) + ) + + np.testing.assert_allclose(two_view, expected, atol=1e-6) + + @parameterized.parameters(0, 1, 42) + def test_matching_distribution_not_worse_than_mismatch(self, seed): + key = jax.random.key(seed) + teacher_logits = jax.random.normal(key, (4, 6), dtype=jnp.float32) + dino_jit = jax.jit(_self_supervised.dino_loss) + + loss_match = dino_jit( + teacher_logits, + teacher_logits, + student_temperature=0.2, + teacher_temperature=0.2, + ) + loss_mismatch = dino_jit( + teacher_logits[:, ::-1], + teacher_logits, + student_temperature=0.2, + teacher_temperature=0.2, + ) + + self.assertTrue(np.all(np.isfinite(loss_match))) + self.assertTrue(np.all(np.isfinite(loss_mismatch))) + self.assertLessEqual( + float(loss_match.mean()), float(loss_mismatch.mean()) + 1e-6 + ) + + def test_stops_gradient_through_teacher_logits(self): + key1, key2 = jax.random.split(jax.random.key(0)) + student_logits = jax.random.normal(key1, (4, 6), dtype=jnp.float32) + teacher_logits = jax.random.normal(key2, (4, 6), dtype=jnp.float32) + + grad_student, grad_teacher = jax.grad( + lambda s, t: _self_supervised.dino_loss(s, t).mean(), argnums=(0, 1) + )(student_logits, teacher_logits) + + self.assertGreater(float(jnp.linalg.norm(grad_student)), 0.0) + np.testing.assert_allclose(grad_teacher, jnp.zeros_like(teacher_logits)) + + def test_raises_for_invalid_arguments(self): + student_logits = jnp.zeros((4, 6), dtype=jnp.float32) + teacher_logits = jnp.zeros((4, 6), dtype=jnp.float32) + with self.assertRaises(ValueError): + _self_supervised.dino_loss(student_logits, teacher_logits[:3]) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, teacher_logits, student_logits_2=student_logits + ) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, + teacher_logits, + student_logits[:3], + teacher_logits[:3], + ) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, teacher_logits, student_temperature=0.0 + ) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, + teacher_logits, + student_temperature=jnp.ones((2,), dtype=jnp.float32), + ) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, teacher_logits, teacher_temperature=-0.1 + ) + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, + teacher_logits, + teacher_center=jnp.zeros((5,), dtype=jnp.float32), + ) + with self.assertRaises(ValueError): + # Broadcasts with, but not to, the teacher logits shape. + _self_supervised.dino_loss( + student_logits, + teacher_logits, + teacher_center=jnp.zeros((2, 1, 1), dtype=jnp.float32), + ) + with self.assertRaises(TypeError): + _self_supervised.dino_loss( + student_logits, + teacher_logits, + teacher_center=jnp.zeros((6,), dtype=jnp.int32), + ) + + def test_temperature_validation_is_skipped_for_traced_values(self): + student_logits = jnp.zeros((4, 6), dtype=jnp.float32) + teacher_logits = jnp.zeros((4, 6), dtype=jnp.float32) + + with self.assertRaises(ValueError): + _self_supervised.dino_loss( + student_logits, teacher_logits, student_temperature=-0.1 + ) + + # Traced temperatures cannot be validated at trace time, so the same + # value compiles and runs when passed as a traced array. + def loss_fn(temperature): + return _self_supervised.dino_loss( + student_logits, teacher_logits, student_temperature=temperature + ) + + traced = jax.jit(loss_fn)(jnp.asarray(-0.1, dtype=jnp.float32)) + self.assertEqual(traced.shape, (4,)) + + def test_temperatures_use_logit_dtype_under_x64(self): + student_logits = jnp.zeros((4, 6), dtype=jnp.float32) + teacher_logits = jnp.zeros((4, 6), dtype=jnp.float32) + + with utils.x64_precision(True): + loss = jax.jit(_self_supervised.dino_loss)( + student_logits, + teacher_logits, + student_temperature=0.1, + teacher_temperature=0.04, + ) + + self.assertEqual(loss.dtype, jnp.float32) + + def test_vmap(self): + key1, key2, key3 = jax.random.split(jax.random.key(0), 3) + student_logits = jax.random.normal(key1, (3, 4, 6), dtype=jnp.float32) + teacher_logits = jax.random.normal(key2, (3, 4, 6), dtype=jnp.float32) + center = jax.random.normal(key3, (6,), dtype=jnp.float32) + + dino = functools.partial(_self_supervised.dino_loss, teacher_center=center) + vmapped = jax.jit(jax.vmap(dino))(student_logits, teacher_logits) + expected = jnp.stack([ + dino(student_logits[i], teacher_logits[i]) + for i in range(student_logits.shape[0]) + ]) + + np.testing.assert_allclose(vmapped, expected, atol=1e-6) + + +def _random_decorrelated_projections(key, batch_size, feature_dim): + """Returns random projections whose cross-correlation matrix is identity. + + Centering a random matrix makes every column orthogonal to the all-ones + vector, so the Q factor of its QR decomposition has orthonormal, zero-mean + columns; scaling by `sqrt(batch_size)` gives unit-variance, decorrelated + features. + """ + projections = jax.random.normal( + key, (batch_size, feature_dim), dtype=jnp.float32 + ) + projections = projections - projections.mean(axis=0, keepdims=True) + q, _ = jnp.linalg.qr(projections) + return q * jnp.sqrt(batch_size) + + +class BarlowTwinsLossTest(parameterized.TestCase): + + @parameterized.parameters(0, 1, 42) + def test_zero_for_decorrelated_inputs_and_nonzero_otherwise(self, seed): + projections = _random_decorrelated_projections( + jax.random.key(seed), batch_size=8, feature_dim=5 + ) + + loss_same = jax.jit(_self_supervised.barlow_twins_loss)( + projections, projections + ) + loss_opposite = jax.jit(_self_supervised.barlow_twins_loss)( + projections, -projections + ) + + self.assertTrue(np.isfinite(loss_same)) + self.assertTrue(np.isfinite(loss_opposite)) + np.testing.assert_allclose(loss_same, 0.0, atol=1e-6) + self.assertGreater(float(loss_opposite), 1.0) + + @parameterized.parameters(0, 1, 42) + def test_invariant_to_batch_permutation(self, seed): + key = jax.random.key(seed) + key1, key2, key3 = jax.random.split(key, 3) + projection_1 = jax.random.normal(key1, (10, 7), dtype=jnp.float32) + projection_2 = jax.random.normal(key2, (10, 7), dtype=jnp.float32) + permutation = jax.random.permutation(key3, projection_1.shape[0]) + barlow_jit = jax.jit(_self_supervised.barlow_twins_loss) + + base = barlow_jit(projection_1, projection_2) + shuffled = barlow_jit( + projection_1[permutation], projection_2[permutation] + ) + + np.testing.assert_allclose(base, shuffled, atol=1e-5) + + def test_constant_inputs_are_finite(self): + projections = jnp.ones((4, 3), dtype=jnp.float32) + + loss = jax.jit(_self_supervised.barlow_twins_loss)( + projections, projections + ) + + self.assertTrue(np.isfinite(loss)) + + @parameterized.parameters(5e-3, 0.5) + def test_off_diagonal_term_scales_with_off_diagonal_scale( + self, off_diagonal_scale + ): + feature = jax.random.normal(jax.random.key(0), (8,), dtype=jnp.float32) + # Duplicating one feature makes the standardized cross-correlation matrix + # all ones, so the on-diagonal loss is ~0 and the off-diagonal sum of + # squares is 2; the loss must equal off_diagonal_scale * 2. + projections = jnp.stack([feature, feature], axis=1) + + loss = jax.jit( + functools.partial( + _self_supervised.barlow_twins_loss, + off_diagonal_scale=off_diagonal_scale, + ) + )(projections, projections) + + np.testing.assert_allclose(loss, 2.0 * off_diagonal_scale, rtol=1e-3) + + def test_raises_for_invalid_shapes(self): + with self.assertRaises(ValueError): + _self_supervised.barlow_twins_loss( + jnp.zeros((4, 3), dtype=jnp.float32), + jnp.zeros((5, 3), dtype=jnp.float32), + ) + with self.assertRaises(ValueError): + _self_supervised.barlow_twins_loss( + jnp.zeros((2, 3, 4), dtype=jnp.float32), + jnp.zeros((2, 3, 4), dtype=jnp.float32), + ) + with self.assertRaises(ValueError): + _self_supervised.barlow_twins_loss( + jnp.zeros((0, 3), dtype=jnp.float32), + jnp.zeros((0, 3), dtype=jnp.float32), + ) + with self.assertRaises(ValueError): + # Per-feature statistics require at least two examples. + _self_supervised.barlow_twins_loss( + jnp.zeros((1, 3), dtype=jnp.float32), + jnp.zeros((1, 3), dtype=jnp.float32), + ) + + def test_vmap(self): + key1, key2 = jax.random.split(jax.random.key(0)) + projection_1 = jax.random.normal(key1, (3, 8, 5), dtype=jnp.float32) + projection_2 = jax.random.normal(key2, (3, 8, 5), dtype=jnp.float32) + + vmapped = jax.jit(jax.vmap(_self_supervised.barlow_twins_loss))( + projection_1, projection_2 + ) + expected = jnp.stack([ + _self_supervised.barlow_twins_loss(projection_1[i], projection_2[i]) + for i in range(projection_1.shape[0]) + ]) + + np.testing.assert_allclose(vmapped, expected, atol=1e-6) + + +class SelfSupervisedDtypeTest(parameterized.TestCase): + + @parameterized.named_parameters( + ('byol', _self_supervised.byol_loss), + ('simsiam', _self_supervised.simsiam_loss), + ('dino', _self_supervised.dino_loss), + ('barlow_twins', _self_supervised.barlow_twins_loss), + ) + def test_raises_for_non_float_inputs(self, loss_fn): + inputs = jnp.ones((4, 3), dtype=jnp.int32) + with self.assertRaises(TypeError): + loss_fn(inputs, inputs) + + def test_byol_raises_for_non_float_second_pair(self): + projections = jnp.ones((4, 3), dtype=jnp.float32) + bad = jnp.ones((4, 3), dtype=jnp.int32) + with self.assertRaises(TypeError): + _self_supervised.byol_loss(projections, projections, bad, bad) + + if __name__ == '__main__': absltest.main()