From 476be990d8d3f7aa0243d638a95ec1d6ef4ebd9b Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Fri, 26 Jun 2026 12:32:42 +0530 Subject: [PATCH 1/3] contrib: add sharding tests for contrib optimizers Add optax/contrib/_sharding_test.py mirroring the pattern in optax/_src/sharding_test.py. Tests 14 contrib optimizers across three checks: abstract-input init, sharding-type stability across init/update, and JIT-preservation of sharding types. Optimizers that require concrete parameter values at init (dadapt_adamw, dog, dowg, prodigy, schedule_free_adamw, schedule_free_sgd) are included in the sharding type tests but excluded from the abstract-init test, with a comment explaining why. --- optax/contrib/_sharding_test.py | 116 ++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 optax/contrib/_sharding_test.py diff --git a/optax/contrib/_sharding_test.py b/optax/contrib/_sharding_test.py new file mode 100644 index 000000000..e71e31358 --- /dev/null +++ b/optax/contrib/_sharding_test.py @@ -0,0 +1,116 @@ +# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sharding and dtype-stability tests for optax.contrib optimizers.""" + +import os + +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import optax +from optax._src import test_utils +from optax._src import utils + + +os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' + +# Optimizers that support initialization from abstract inputs +# (jax.ShapeDtypeStruct). Optimizers that require concrete parameter values +# during init (e.g. to record an initial norm) are excluded from this dict +# but tested in SHARDING_OPTIMIZERS below. +ABSTRACT_INIT_OPTIMIZERS = { + 'acprop': optax.contrib.acprop(1e-3), + 'ademamix': optax.contrib.ademamix(1e-3), + 'adopt': optax.contrib.adopt(1e-2), + 'cocob': optax.contrib.cocob(), + 'galore': optax.contrib.galore(1e-2, rank=4), + 'madgrad': optax.contrib.madgrad(1e-2), + 'muon': optax.contrib.muon(1e-2), + 'simplified_ademamix': optax.contrib.simplified_ademamix(1e-3), +} + +# All optimizers tested for sharding-type stability. Some (dadapt_adamw, dog, +# dowg, prodigy, schedule_free_*) require concrete params at init and are +# therefore omitted from ABSTRACT_INIT_OPTIMIZERS above. +SHARDING_OPTIMIZERS = { + **ABSTRACT_INIT_OPTIMIZERS, + 'dadapt_adamw': optax.contrib.dadapt_adamw(1e-1), + 'dog': optax.contrib.dog(1.0), + 'dowg': optax.contrib.dowg(1.0), + 'prodigy': optax.contrib.prodigy(1e-1), + 'schedule_free_adamw': optax.contrib.schedule_free_adamw( + 1e-2, warmup_steps=5000 + ), + 'schedule_free_sgd': optax.contrib.schedule_free_sgd( + 1e-2, warmup_steps=5000 + ), +} + + +class ContribShardingTest(parameterized.TestCase): + + @parameterized.named_parameters(ABSTRACT_INIT_OPTIMIZERS.items()) + def test_init_with_abstract_input(self, optimizer): + params = jax.ShapeDtypeStruct(shape=(2, 4, 8), dtype=jnp.float32) + state = optimizer.init(params) + self.assertIsNotNone(state) + + @parameterized.named_parameters(SHARDING_OPTIMIZERS.items()) + def test_state_sharding_type_init_match_update(self, optimizer): + if utils.parse_version(jax.__version__) < utils.parse_version('0.7.2'): + self.skipTest('Skipping sharding-in-types test') + mesh = jax.make_mesh( + (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) + ) + sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) + + with jax.set_mesh(mesh): + params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) + + state0 = optimizer.init(params) + _, state1 = optimizer.update(params, state0, params) + + type0 = jax.tree.map(jax.typeof, state0) + type1 = jax.tree.map(jax.typeof, state1) + test_utils.assert_trees_all_equal(type0, type1) + + @parameterized.named_parameters(SHARDING_OPTIMIZERS.items()) + def test_state_sharding_type_preserved_with_jit(self, optimizer): + if utils.parse_version(jax.__version__) < utils.parse_version('0.7.2'): + self.skipTest('Skipping sharding-in-types test') + mesh = jax.make_mesh( + (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) + ) + sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) + + with jax.set_mesh(mesh): + params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) + + state0 = optimizer.init(params) + state1 = jax.jit(optimizer.init)(params) + type0 = jax.tree.map(jax.typeof, state0) + type1 = jax.tree.map(jax.typeof, state1) + test_utils.assert_trees_all_equal(type0, type1) + + _, state2 = optimizer.update(params, state0, params) + _, state3 = jax.jit(optimizer.update)(params, state0, params) + type2 = jax.tree.map(jax.typeof, state2) + type3 = jax.tree.map(jax.typeof, state3) + test_utils.assert_trees_all_equal(type2, type3) + + +if __name__ == '__main__': + absltest.main() From 4cdd530cd941dba41801971c5936fd6f92b464ac Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Sat, 27 Jun 2026 12:14:16 +0530 Subject: [PATCH 2/3] contrib: use jax.config to set device count in sharding tests Replace module-level os.environ manipulation with jax.config.update('jax_num_cpu_devices', 8). Also drop the abstract-init test for contrib optimizers (several require concrete values during init) and consolidate the two sharding checks into a single parameterized test. --- optax/contrib/_sharding_test.py | 95 +++++++++++++-------------------- 1 file changed, 36 insertions(+), 59 deletions(-) diff --git a/optax/contrib/_sharding_test.py b/optax/contrib/_sharding_test.py index e71e31358..11f53a124 100644 --- a/optax/contrib/_sharding_test.py +++ b/optax/contrib/_sharding_test.py @@ -14,8 +14,6 @@ # ============================================================================== """Sharding and dtype-stability tests for optax.contrib optimizers.""" -import os - from absl.testing import absltest from absl.testing import parameterized import jax @@ -24,32 +22,21 @@ from optax._src import test_utils from optax._src import utils +# Set device count before the JAX backend is initialized so that +# jax.make_mesh((8,), ...) works in the tests below. +jax.config.update('jax_num_cpu_devices', 8) -os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8' - -# Optimizers that support initialization from abstract inputs -# (jax.ShapeDtypeStruct). Optimizers that require concrete parameter values -# during init (e.g. to record an initial norm) are excluded from this dict -# but tested in SHARDING_OPTIMIZERS below. -ABSTRACT_INIT_OPTIMIZERS = { +OPTIMIZERS = { 'acprop': optax.contrib.acprop(1e-3), 'ademamix': optax.contrib.ademamix(1e-3), 'adopt': optax.contrib.adopt(1e-2), 'cocob': optax.contrib.cocob(), - 'galore': optax.contrib.galore(1e-2, rank=4), - 'madgrad': optax.contrib.madgrad(1e-2), - 'muon': optax.contrib.muon(1e-2), - 'simplified_ademamix': optax.contrib.simplified_ademamix(1e-3), -} - -# All optimizers tested for sharding-type stability. Some (dadapt_adamw, dog, -# dowg, prodigy, schedule_free_*) require concrete params at init and are -# therefore omitted from ABSTRACT_INIT_OPTIMIZERS above. -SHARDING_OPTIMIZERS = { - **ABSTRACT_INIT_OPTIMIZERS, 'dadapt_adamw': optax.contrib.dadapt_adamw(1e-1), 'dog': optax.contrib.dog(1.0), 'dowg': optax.contrib.dowg(1.0), + 'galore': optax.contrib.galore(1e-2, rank=4), + 'madgrad': optax.contrib.madgrad(1e-2), + 'muon': optax.contrib.muon(1e-2), 'prodigy': optax.contrib.prodigy(1e-1), 'schedule_free_adamw': optax.contrib.schedule_free_adamw( 1e-2, warmup_steps=5000 @@ -57,59 +44,49 @@ 'schedule_free_sgd': optax.contrib.schedule_free_sgd( 1e-2, warmup_steps=5000 ), + 'simplified_ademamix': optax.contrib.simplified_ademamix(1e-3), } class ContribShardingTest(parameterized.TestCase): - @parameterized.named_parameters(ABSTRACT_INIT_OPTIMIZERS.items()) - def test_init_with_abstract_input(self, optimizer): - params = jax.ShapeDtypeStruct(shape=(2, 4, 8), dtype=jnp.float32) - state = optimizer.init(params) - self.assertIsNotNone(state) - - @parameterized.named_parameters(SHARDING_OPTIMIZERS.items()) - def test_state_sharding_type_init_match_update(self, optimizer): + @parameterized.named_parameters(OPTIMIZERS.items()) + def test_state_sharding_type_stable(self, optimizer): if utils.parse_version(jax.__version__) < utils.parse_version('0.7.2'): self.skipTest('Skipping sharding-in-types test') - mesh = jax.make_mesh( - (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) - ) - sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) - - with jax.set_mesh(mesh): - params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) - state0 = optimizer.init(params) - _, state1 = optimizer.update(params, state0, params) - - type0 = jax.tree.map(jax.typeof, state0) - type1 = jax.tree.map(jax.typeof, state1) - test_utils.assert_trees_all_equal(type0, type1) - - @parameterized.named_parameters(SHARDING_OPTIMIZERS.items()) - def test_state_sharding_type_preserved_with_jit(self, optimizer): - if utils.parse_version(jax.__version__) < utils.parse_version('0.7.2'): - self.skipTest('Skipping sharding-in-types test') mesh = jax.make_mesh( (8,), ('x',), axis_types=(jax.sharding.AxisType.Explicit,) ) sharding = jax.sharding.NamedSharding(mesh, jax.P(None, 'x')) with jax.set_mesh(mesh): - params = jnp.zeros((2, 8, 4), dtype=jnp.float16, out_sharding=sharding) - - state0 = optimizer.init(params) - state1 = jax.jit(optimizer.init)(params) - type0 = jax.tree.map(jax.typeof, state0) - type1 = jax.tree.map(jax.typeof, state1) - test_utils.assert_trees_all_equal(type0, type1) - - _, state2 = optimizer.update(params, state0, params) - _, state3 = jax.jit(optimizer.update)(params, state0, params) - type2 = jax.tree.map(jax.typeof, state2) - type3 = jax.tree.map(jax.typeof, state3) - test_utils.assert_trees_all_equal(type2, type3) + params = jnp.zeros( + (2, 8, 4), dtype=jnp.float16, out_sharding=sharding + ) + + # Eager init and update should have matching sharding types. + state_eager = optimizer.init(params) + _, state_after_update = optimizer.update(params, state_eager, params) + test_utils.assert_trees_all_equal( + jax.tree.map(jax.typeof, state_eager), + jax.tree.map(jax.typeof, state_after_update), + ) + + # JIT-compiled init and update should match their eager counterparts. + state_jit = jax.jit(optimizer.init)(params) + test_utils.assert_trees_all_equal( + jax.tree.map(jax.typeof, state_eager), + jax.tree.map(jax.typeof, state_jit), + ) + + _, state_update_jit = jax.jit(optimizer.update)( + params, state_eager, params + ) + test_utils.assert_trees_all_equal( + jax.tree.map(jax.typeof, state_after_update), + jax.tree.map(jax.typeof, state_update_jit), + ) if __name__ == '__main__': From 5889eaa9dd62f65fe87ba4a9e02423fb898c35d4 Mon Sep 17 00:00:00 2001 From: Aditi Ramakrishnan Date: Sun, 28 Jun 2026 00:21:57 +0530 Subject: [PATCH 3/3] contrib: drop jit sharding checks, eager init/update is sufficient --- optax/contrib/_sharding_test.py | 22 +++------------------- 1 file changed, 3 insertions(+), 19 deletions(-) diff --git a/optax/contrib/_sharding_test.py b/optax/contrib/_sharding_test.py index 11f53a124..9eb278782 100644 --- a/optax/contrib/_sharding_test.py +++ b/optax/contrib/_sharding_test.py @@ -65,29 +65,13 @@ def test_state_sharding_type_stable(self, optimizer): (2, 8, 4), dtype=jnp.float16, out_sharding=sharding ) - # Eager init and update should have matching sharding types. - state_eager = optimizer.init(params) - _, state_after_update = optimizer.update(params, state_eager, params) + state = optimizer.init(params) + _, state_after_update = optimizer.update(params, state, params) test_utils.assert_trees_all_equal( - jax.tree.map(jax.typeof, state_eager), + jax.tree.map(jax.typeof, state), jax.tree.map(jax.typeof, state_after_update), ) - # JIT-compiled init and update should match their eager counterparts. - state_jit = jax.jit(optimizer.init)(params) - test_utils.assert_trees_all_equal( - jax.tree.map(jax.typeof, state_eager), - jax.tree.map(jax.typeof, state_jit), - ) - - _, state_update_jit = jax.jit(optimizer.update)( - params, state_eager, params - ) - test_utils.assert_trees_all_equal( - jax.tree.map(jax.typeof, state_after_update), - jax.tree.map(jax.typeof, state_update_jit), - ) - if __name__ == '__main__': absltest.main()