contrib: add sharding tests for contrib optimizers #1714

Open
irhyl wants to merge 2 commits into google-deepmind:main from irhyl:contrib/sharding-tests

irhyl commented Jun 26, 2026

Summary

  • Adds optax/contrib/_sharding_test.py, mirroring the pattern established in optax/_src/sharding_test.py
  • Tests 14 contrib optimizers: acprop, ademamix, adopt, cocob, dadapt_adamw, dog, dowg, galore, madgrad, muon, prodigy, schedule_free_adamw, schedule_free_sgd, simplified_ademamix
  • Three test cases per optimizer:
    • test_init_with_abstract_input — verifies state can be created from jax.ShapeDtypeStruct (8 optimizers that support abstract init)
    • test_state_sharding_type_init_match_update — verifies sharding type is preserved across a gradient step (JAX ≥ 0.7.2, all 14 optimizers)
    • test_state_sharding_type_preserved_with_jit — verifies JIT does not alter sharding type (JAX ≥ 0.7.2, all 14 optimizers)

Optimizers that require concrete parameter values during init (e.g. dadapt_adamw which records an initial norm, dog/dowg/prodigy which track distance from initialization) are included in the sharding type tests but excluded from the abstract-init test, with a comment explaining the distinction.
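
To illustrate the distinction above, here is a minimal sketch with two toy init functions. Both are hypothetical and are not optax code, and the helper name `supports_abstract_init` is likewise an assumption made for this example. The first init only reads shape and dtype, so it accepts `jax.ShapeDtypeStruct` inputs. The second reads parameter values (a squared norm), so it needs concrete arrays.

```python
import jax
import jax.numpy as jnp


def zeros_state_init(params):
  # Hypothetical init: only shape and dtype are read, so abstract inputs work.
  return jax.tree.map(lambda p: jnp.zeros(p.shape, p.dtype), params)


def norm_state_init(params):
  # Hypothetical init: reads parameter values, so it needs concrete arrays.
  # jax.ShapeDtypeStruct defines no arithmetic, so `p * p` raises TypeError.
  return jax.tree.map(lambda p: jnp.sum(p * p), params)


def supports_abstract_init(init_fn, abstract_params):
  # Returns True if init_fn runs on shape/dtype placeholders alone.
  try:
    init_fn(abstract_params)
  except TypeError:
    return False
  return True


abstract_params = {'w': jax.ShapeDtypeStruct((4, 3), jnp.float32)}
concrete_params = {'w': jnp.ones((4, 3), jnp.float32)}
```

Under these assumptions, `supports_abstract_init(zeros_state_init, abstract_params)` should hold, while the norm-based init only works on `concrete_params`. The real optimizers may fail on abstract inputs for different reasons than this toy one does.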

Test plan

  • python -m pytest optax/contrib/_sharding_test.py -q passes locally (abstract-init tests confirmed; sharding-type tests skipped locally without JAX ≥ 0.7.2 multi-device setup)
  • CI passes all lint and test jobs

Add optax/contrib/_sharding_test.py mirroring the pattern in
optax/_src/sharding_test.py. Tests 14 contrib optimizers across three
checks: abstract-input init, sharding-type stability across init/update,
and JIT-preservation of sharding types.

Optimizers that require concrete parameter values at init (dadapt_adamw,
dog, dowg, prodigy, schedule_free_adamw, schedule_free_sgd) are included
in the sharding type tests but excluded from the abstract-init test, with
a comment explaining why.

rdyro (Collaborator) left a comment

Thanks, this looks like a promising PR!

Comment thread optax/contrib/_sharding_test.py Outdated

class ContribShardingTest(parameterized.TestCase):

  @parameterized.named_parameters(ABSTRACT_INIT_OPTIMIZERS.items())

Collaborator

Given that a lot of the optimizers here don't support an abstract init, I'd omit this test case. This'll simplify the optimizer list.

Author

Done, dropped the abstract-init test. The optimizer list is now the same set used in the sharding test.

Comment thread optax/contrib/_sharding_test.py Outdated
from optax._src import utils


os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'

Collaborator

Can you set up a custom context manager that wraps jax.config.update("jax_cpu_device_num", ...) in try/finally, instead of setting this env variable globally?

Author

The actual config key is jax_num_cpu_devices (not jax_cpu_device_num), and it can only be set before the JAX backend initializes. A try/finally context manager would therefore error on teardown with RuntimeError: jax_num_cpu_devices config should be updated before backends are initialized.

Replaced the os.environ global with jax.config.update('jax_num_cpu_devices', 8) at module level, before any JAX operations. This matches how optax/_src/sharding_test.py handles it (it uses os.environ['XLA_FLAGS'] at module level for the same reason), just using the JAX config API instead of env variables.

Comment thread optax/contrib/_sharding_test.py Outdated
    self.assertIsNotNone(state)

  @parameterized.named_parameters(SHARDING_OPTIMIZERS.items())
  def test_state_sharding_type_init_match_update(self, optimizer):

Collaborator

Is the only difference between the two tests the application of jax.jit? Maybe we should keep only one of the two versions or make this test parametric; there's a lot of duplication at the moment.

irhyl (Author), Jun 27, 2026

Yes, merged into a single test_state_sharding_type_stable. It now checks all three cases in one pass: eager init/update type stability, jit init matches eager init, and jit update matches eager update. Less repetition, same coverage.
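
The shape of such a consolidated check can be sketched roughly as below. This is not the test from the PR: `toy_init`, `toy_update`, `sharding_types`, and `sharding_type_stable` are hypothetical names, the toy optimizer stands in for a contrib optimizer, and the sketch compares the runtime `.sharding` class of concrete arrays on whatever devices are available, which may differ from what the actual test inspects.

```python
import jax
import jax.numpy as jnp


def toy_init(params):
  # Hypothetical optimizer state: one momentum buffer per parameter.
  return {'mu': jax.tree.map(jnp.zeros_like, params)}


def toy_update(grads, state):
  # Hypothetical update: exponential moving average of the gradients.
  mu = jax.tree.map(lambda m, g: 0.9 * m + 0.1 * g, state['mu'], grads)
  updates = jax.tree.map(lambda m: -0.01 * m, mu)
  return updates, {'mu': mu}


def sharding_types(tree):
  # Maps every array leaf to the class of its sharding.
  return jax.tree.map(lambda x: type(x.sharding), tree)


def sharding_type_stable(init_fn, update_fn, params, grads):
  # Eager init is the reference; the other three states must match it.
  eager_init = init_fn(params)
  _, eager_update = update_fn(grads, eager_init)
  jit_init = jax.jit(init_fn)(params)
  _, jit_update = jax.jit(update_fn)(grads, jit_init)
  expected = sharding_types(eager_init)
  return all(
      sharding_types(state) == expected
      for state in (eager_update, jit_init, jit_update)
  )


params = {'w': jnp.ones((4, 3)), 'b': jnp.zeros((3,))}
grads = jax.tree.map(jnp.ones_like, params)
```

Calling `sharding_type_stable(toy_init, toy_update, params, grads)` covers the three comparisons in one pass, so a parameterized test would only need to loop over optimizers rather than over eager and jit variants.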

Replace module-level os.environ manipulation with
jax.config.update('jax_num_cpu_devices', 8). Also drop the abstract-init
test for contrib optimizers (several require concrete values during init)
and consolidate the two sharding checks into a single parameterized test.