contrib: add sharding tests for contrib optimizers #1714
Add optax/contrib/_sharding_test.py mirroring the pattern in optax/_src/sharding_test.py. Tests 14 contrib optimizers across three checks: abstract-input init, sharding-type stability across init/update, and JIT-preservation of sharding types. Optimizers that require concrete parameter values at init (dadapt_adamw, dog, dowg, prodigy, schedule_free_adamw, schedule_free_sgd) are included in the sharding type tests but excluded from the abstract-init test, with a comment explaining why.
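To illustrate the distinction between the two groups, here is a minimal sketch; it is not the code from this PR. zeros_init and norm_init are hypothetical stand-ins, not optax functions. The first only needs shapes and dtypes, so it accepts a jax.ShapeDtypeStruct. The second reads parameter values, loosely like an init that records an initial norm, so it is expected to fail on abstract input.

```python
import jax
import jax.numpy as jnp


def zeros_init(params):
  # Hypothetical init that only uses the shape and dtype of each leaf.
  return jax.tree.map(jnp.zeros_like, params)


def norm_init(params):
  # Hypothetical init that reads parameter values (an initial norm per leaf).
  return jax.tree.map(lambda p: jnp.linalg.norm(p), params)


# Abstract parameters: shape and dtype only, no values.
abstract_params = {'w': jax.ShapeDtypeStruct((8, 4), jnp.float32)}

# Works, because zeros_like never looks at the values.
state = zeros_init(abstract_params)

try:
  norm_init(abstract_params)
  norm_init_failed = False
except Exception:  # the exact exception type is not assumed here
  norm_init_failed = True
```

Under these assumptions, an optimizer shaped like zeros_init can be covered by an abstract-init check, while one shaped like norm_init needs concrete parameters.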
rdyro
left a comment
Thanks, this looks like a promising PR!
class ContribShardingTest(parameterized.TestCase):

  @parameterized.named_parameters(ABSTRACT_INIT_OPTIMIZERS.items())
Given that a lot of optimizers here don't support abstract init, I'd omit this test case. This will simplify the optimizer list.
Done, dropped the abstract-init test. The optimizer list is now the same set used in the sharding test.
from optax._src import utils

os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
Can you set up a custom contextmanager with try/finally updating the jax.config.update("jax_cpu_device_num", ...) instead of using this env variable globally?
The actual config key is jax_num_cpu_devices (not jax_cpu_device_num), and it can only be set before the JAX backend initializes. A try/finally context manager would therefore fail on teardown with RuntimeError: jax_num_cpu_devices config should be updated before backends are initialized.
Replaced the os.environ global with jax.config.update('jax_num_cpu_devices', 8) at module level, before any JAX operations. This matches how optax/_src/sharding_test.py handles it (it uses os.environ['XLA_FLAGS'] at module level for the same reason), just using the JAX config API instead of env variables.
    self.assertIsNotNone(state)

  @parameterized.named_parameters(SHARDING_OPTIMIZERS.items())
  def test_state_sharding_type_init_match_update(self, optimizer):
Is the only difference between the two tests the application of jax.jit? Maybe we should keep only one of the two versions or make this test parametric; there's a lot of duplication at the moment.
Yes, merged into a single test_state_sharding_type_stable. It now checks all three cases in one pass: eager init/update type stability, jit init matches eager init, and jit update matches eager update. Less repetition, same coverage.
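Roughly, the three comparisons have the shape sketched below. This is a simplified illustration, not the test from this PR: toy_init and toy_update are hypothetical stand-ins for optimizer.init and optimizer.update, and "sharding type" is assumed here to mean the Python class of each leaf's .sharding attribute, which may differ from what the real test compares.

```python
import jax
import jax.numpy as jnp


def toy_init(params):
  # Hypothetical stand-in for optimizer.init: one momentum buffer per leaf.
  return jax.tree.map(jnp.zeros_like, params)


def toy_update(grads, state, params):
  # Hypothetical stand-in for optimizer.update: SGD with momentum.
  del params  # unused, kept to mirror the (grads, state, params) call shape
  new_state = jax.tree.map(lambda m, g: 0.9 * m + g, state, grads)
  updates = jax.tree.map(lambda m: -0.1 * m, new_state)
  return updates, new_state


def sharding_types(tree):
  # One sharding class per array leaf, in flattening order.
  return [type(leaf.sharding) for leaf in jax.tree.leaves(tree)]


def check_sharding_type_stable(init_fn, update_fn, params):
  # Eager init and update; params double as gradients for simplicity.
  state = init_fn(params)
  _, new_state = update_fn(params, state, params)
  # The same two calls under jit.
  jit_state = jax.jit(init_fn)(params)
  _, jit_new_state = jax.jit(update_fn)(params, state, params)
  return {
      'eager_init_vs_update':
          sharding_types(state) == sharding_types(new_state),
      'jit_init_vs_eager_init':
          sharding_types(jit_state) == sharding_types(state),
      'jit_update_vs_eager_update':
          sharding_types(jit_new_state) == sharding_types(new_state),
  }


params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros((4,))}
results = check_sharding_type_stable(toy_init, toy_update, params)
```

In a real test the parameters would be placed on a multi-device mesh. On a single device every leaf is expected to report the same sharding class, so the sketch only shows how the three checks fit into one pass.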
Replace module-level os.environ manipulation with
jax.config.update('jax_num_cpu_devices', 8). Also drop the abstract-init
test for contrib optimizers (several require concrete values during init)
and consolidate the two sharding checks into a single parameterized test.
Summary
- optax/contrib/_sharding_test.py, mirroring the pattern established in optax/_src/sharding_test.py
- 14 contrib optimizers covered: acprop, ademamix, adopt, cocob, dadapt_adamw, dog, dowg, galore, madgrad, muon, prodigy, schedule_free_adamw, schedule_free_sgd, simplified_ademamix
- test_init_with_abstract_input: verifies state can be created from jax.ShapeDtypeStruct (8 optimizers that support abstract init)
- test_state_sharding_type_init_match_update: verifies sharding type is preserved across a gradient step (JAX ≥ 0.7.2, all 14 optimizers)
- test_state_sharding_type_preserved_with_jit: verifies JIT does not alter sharding type (JAX ≥ 0.7.2, all 14 optimizers)

Optimizers that require concrete parameter values during init (e.g. dadapt_adamw, which records an initial norm, and dog/dowg/prodigy, which track distance from initialization) are included in the sharding type tests but excluded from the abstract-init test, with a comment explaining the distinction.

Test plan
- python -m pytest optax/contrib/_sharding_test.py -q passes locally (abstract-init tests confirmed; sharding-type tests skipped locally without JAX ≥ 0.7.2 multi-device setup)