# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental sharding utilities for Optax gradient transformations.

This module provides utilities for zero-redundancy sharding of Optax optimizer
state. The core idea is to shard optimizer state across more mesh axes than the
model parameters, reducing per-device memory usage of the optimizer without
changing the shapes of the state arrays.

This module draws on ideas from ``jax_privacy.sharding_utils``, adapting them
for use with arbitrary Optax gradient transformations.

.. admonition:: Assumptions

  1. **Explicit sharding API required.** This module assumes that the calling
     program uses JAX's explicit sharding API (i.e., "sharding and types"), as
     described at
     https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html.
     In particular, a mesh should be set via ``jax.sharding.set_mesh()`` and
     arrays should carry type-level sharding information.

  2. **Performance characteristics not yet evaluated.** While we provide test
     coverage ensuring that the shardings of intermediate optimizer state
     arrays work as intended, we have **not** yet evaluated the performance
     characteristics of these APIs. If you observe unexpected performance
     behaviour (e.g., slow compilation, excessive cross-device communication,
     or elevated memory usage), please raise an issue on GitHub.
"""

# NOTE: this module is exposed publicly as ``optax.experimental.sharding``;
# ``optax/experimental/__init__.py`` imports it with
# ``from . import _sharding as sharding`` and lists ``'sharding'`` in
# ``__all__``.

import math
from typing import Any, cast

import jax
from optax._src import base

P = jax.sharding.PartitionSpec


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


def _check_explicit_mesh(mesh: jax.sharding.Mesh) -> None:
  """Raise unless ``mesh`` has axes and all of them are ``AxisType.Explicit``.

  Args:
    mesh: The mesh whose ``axis_types`` are inspected.

  Raises:
    RuntimeError: If the mesh has no axes at all, or if any of its axis types
      differs from ``jax.sharding.AxisType.Explicit``.
  """
  axis_types = tuple(mesh.axis_types)
  # ``all()`` over an empty sequence is vacuously true, so a mesh without any
  # axes (i.e. no mesh was really set) has to be rejected explicitly;
  # otherwise the wrapper would silently do nothing.
  if not axis_types or any(
      axis_type != jax.sharding.AxisType.Explicit for axis_type in axis_types
  ):
    raise RuntimeError(
        'with_custom_sharding requires an explicit mesh. Please set the mesh '
        'using jax.sharding.set_mesh() with '
        'axis_types=jax.sharding.AxisType.Explicit.'
    )


def _get_mesh(*pytrees: Any) -> jax.sharding.Mesh:
  """Extract the mesh from the first leaf that carries a ``NamedSharding``.

  Args:
    *pytrees: Pytrees whose leaves are searched in order.

  Returns:
    The ``mesh`` attribute of the first ``NamedSharding`` found on the type
    (``jax.typeof``) of a leaf.
    NOTE(review): the value is only ``cast`` to ``jax.sharding.Mesh``; the
    type-level sharding may actually hold an abstract mesh — confirm.

  Raises:
    ValueError: If no leaf has a ``NamedSharding`` (including the case of no
      leaves at all).
  """
  # Flattening the tuple of pytrees visits leaves in the same order as
  # walking each pytree in turn.
  for leaf in jax.tree.leaves(pytrees):
    sharding = jax.typeof(leaf).sharding
    if isinstance(sharding, jax.sharding.NamedSharding):
      return cast(jax.sharding.Mesh, sharding.mesh)
  raise ValueError(
      'Could not extract mesh from any leaf. Ensure arrays carry type-level '
      'sharding information (see jax.sharding.set_mesh()).'
  )


def _to_struct(leaf: jax.Array) -> jax.ShapeDtypeStruct:
  """Convert an array to a ``ShapeDtypeStruct`` with its type-level sharding.

  Args:
    leaf: Array providing ``shape`` and ``dtype``.

  Returns:
    A ``jax.ShapeDtypeStruct`` with the same shape and dtype, whose sharding
    is the one reported by ``jax.typeof(leaf)``.
  """
  return jax.ShapeDtypeStruct(
      leaf.shape, leaf.dtype, sharding=jax.typeof(leaf).sharding
  )


def _maybe_reshard(leaf: jax.Array, abstract: jax.ShapeDtypeStruct):
  """Reshard *leaf* to *abstract*'s sharding; return it as-is if there is none.

  Args:
    leaf: The array to reshard.
    abstract: Abstract value whose ``sharding`` is the target layout.

  Returns:
    ``jax.reshard(leaf, abstract.sharding)`` when a sharding is present,
    otherwise ``leaf`` unchanged.
  """
  if abstract.sharding is None:
    return leaf
  return jax.reshard(leaf, abstract.sharding)


def _reshard_to_abstract(pytree: Any, abstract_pytree: Any) -> Any:
  """Reshard each leaf of *pytree* to the sharding in *abstract_pytree*.

  Args:
    pytree: Pytree of arrays.
    abstract_pytree: Pytree of ``ShapeDtypeStruct`` with the same structure.

  Returns:
    A pytree with the structure of *pytree* and resharded leaves.
  """
  return jax.tree.map(_maybe_reshard, pytree, abstract_pytree)


def _reshard_leaves_enhanced(pytree: Any) -> Any:
  """Reshard every leaf of *pytree* to its enhanced sharding.

  Args:
    pytree: Pytree of arrays.

  Returns:
    The same pytree with every leaf resharded to the layout computed by
    ``_enhance_abstract_state`` for it.
  """
  abstract = jax.tree.map(_to_struct, pytree)
  return _reshard_to_abstract(pytree, _enhance_abstract_state(abstract))


def _enhance_abstract_state(abstract_state: Any) -> Any:
  """Map abstract optimizer state to one with enhanced sharding annotations.

  Args:
    abstract_state: Pytree of ``ShapeDtypeStruct`` leaves.

  Returns:
    A pytree of the same structure. Leaves with a ``NamedSharding`` get a new
    ``NamedSharding`` on the same mesh with the spec returned by
    ``_compute_enhanced_pspec``; all other leaves are returned untouched.
  """

  def _enhance_leaf(leaf):
    sharding = leaf.sharding
    if not isinstance(sharding, jax.sharding.NamedSharding):
      return leaf
    mesh = cast(jax.sharding.Mesh, sharding.mesh)
    return jax.ShapeDtypeStruct(
        leaf.shape,
        leaf.dtype,
        sharding=jax.sharding.NamedSharding(
            mesh, _compute_enhanced_pspec(leaf)
        ),
    )

  return jax.tree.map(_enhance_leaf, abstract_state)


def _compute_enhanced_pspec(
    abstract_array: jax.ShapeDtypeStruct,
) -> jax.sharding.PartitionSpec:
  """Compute an enhanced ``PartitionSpec`` that also uses unused mesh axes.

  Greedy algorithm: unused mesh axes are visited in decreasing order of size,
  and each one is appended to the largest array dimension whose size is evenly
  divisible by that dimension's current shard count times the axis size. Axes
  that fit nowhere are skipped. The array shape is never changed.

  Args:
    abstract_array: Abstract array with a ``NamedSharding``.

  Returns:
    ``P()`` for scalars; otherwise a spec with one entry per dimension
    (``None``, a single axis name, or a tuple of axis names). Axes present in
    the current spec are kept, in order, ahead of newly assigned ones.

  Raises:
    TypeError: If the array is not a scalar and its sharding is not a
      ``NamedSharding``.
  """
  shape = abstract_array.shape
  if not shape:
    # Scalar: nothing to shard.
    return P()

  sharding = abstract_array.sharding
  if not isinstance(sharding, jax.sharding.NamedSharding):
    raise TypeError(
        'compute_enhanced_pspec requires a NamedSharding, got '
        f'{type(sharding)}.'
    )
  mesh = cast(jax.sharding.Mesh, sharding.mesh)

  # Parse the current spec into per-dimension axis lists. ``zip`` stops at the
  # shorter input, so a spec shorter than the rank leaves the trailing
  # dimensions unsharded and entries beyond the rank are ignored. Entries that
  # are neither a name nor a tuple of names (e.g. ``None``) contribute no axes.
  dim_axes: list[list[str]] = [[] for _ in shape]
  for axes, entry in zip(dim_axes, sharding.spec):
    if isinstance(entry, str):
      axes.append(entry)
    elif isinstance(entry, tuple):
      axes.extend(entry)
  used_axes = {ax for axes in dim_axes for ax in axes}

  # Number of shards each dimension is currently split into. Kept up to date
  # below instead of being recomputed for every (axis, dimension) pair.
  # ``math.prod`` of an empty iterable is 1.
  shard_counts = [
      math.prod(mesh.shape[ax] for ax in axes) for axes in dim_axes
  ]

  # Unused mesh axes, largest first. ``sorted`` is stable, so equally sized
  # axes keep their mesh order.
  unused_axes = sorted(
      (
          (name, mesh.shape[name])
          for name in mesh.axis_names
          if name not in used_axes
      ),
      key=lambda pair: pair[1],
      reverse=True,
  )

  for ax_name, ax_size in unused_axes:
    best_dim = None
    for i, dim_size in enumerate(shape):
      if dim_size % (shard_counts[i] * ax_size) != 0:
        continue
      # Strict comparison: on ties the earliest dimension wins.
      if best_dim is None or dim_size > shape[best_dim]:
        best_dim = i
    if best_dim is not None:
      dim_axes[best_dim].append(ax_name)
      shard_counts[best_dim] *= ax_size

  # Build the resulting PartitionSpec.
  entries: list[str | tuple[str, ...] | None] = []
  for axes in dim_axes:
    if not axes:
      entries.append(None)
    elif len(axes) == 1:
      entries.append(axes[0])
    else:
      entries.append(tuple(axes))
  return P(*entries)


# ---------------------------------------------------------------------------
# Public: with_custom_sharding wrapper
# ---------------------------------------------------------------------------


def with_custom_sharding(
    inner: base.GradientTransformation,
) -> base.GradientTransformation:
  """Wrap a gradient transformation with zero-redundancy state sharding.

  This wrapper modifies an existing Optax :class:`GradientTransformation` so
  that its optimizer state is sharded across *more* mesh axes than the model
  parameters. This reduces per-device memory usage of the optimizer state at
  the cost of additional resharding operations during the ``update`` step.

  Unlike the flattening approach in ``jax_privacy.sharding_utils``, this
  wrapper **preserves the shapes** of all optimizer-state arrays and only
  modifies their shardings. A greedy algorithm (the private helper
  ``_compute_enhanced_pspec``) assigns unused mesh axes to array dimensions
  wherever the dimension size is evenly divisible by the mesh-axis size.

  Example usage::

    import optax
    from optax.experimental import sharding
    tx = sharding.with_custom_sharding(optax.adam(1e-3))

  Args:
    inner: The base gradient transformation to wrap.

  Returns:
    A new :class:`GradientTransformation` whose optimizer state uses enhanced
    (zero-redundancy) sharding. Its ``init`` raises ``ValueError`` when no
    mesh can be extracted from ``params`` and ``RuntimeError`` when that mesh
    is empty or not fully explicit. The updates returned by ``update`` are in
    the enhanced sharding domain.
  """

  def init_fn(params):
    # Validate the mesh found in the params' type-level sharding info.
    _check_explicit_mesh(_get_mesh(params))

    # Materialise the optimizer state, then reshard to enhanced shardings.
    return _reshard_leaves_enhanced(inner.init(params))

  def update_fn(updates, state, params=None):
    # Reshard updates (and params, if given) into the enhanced sharding domain.
    enhanced_updates = _reshard_leaves_enhanced(updates)
    enhanced_params = (
        None if params is None else _reshard_leaves_enhanced(params)
    )

    # Delegate to the inner transform.
    new_updates, new_state = inner.update(
        enhanced_updates,
        state,
        enhanced_params,
    )

    # Re-pin the state to the enhanced layout so that it does not depend on
    # how shardings propagate through ``inner.update``. Leaves that already
    # have the enhanced layout are resharded to the sharding they already
    # have.
    new_state = _reshard_leaves_enhanced(new_state)

    return new_updates, new_state

  return base.GradientTransformation(init_fn, update_fn)
# Copyright 2025 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental optimizer-state sharding utilities."""

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
import jax.numpy as jnp
import numpy as np
from optax._src import alias
from optax._src import update as optax_update
from optax.experimental import _sharding

P = jax.sharding.PartitionSpec

# Best-effort: set 8 CPU devices before JAX backend initialisation. A
# RuntimeError is deliberately ignored here; tests skip themselves in setUp
# when fewer than _REQUIRED_DEVICES devices are available.
try:
  chex.set_n_cpu_devices(8)
except RuntimeError:
  pass

_REQUIRED_DEVICES = 8


def _make_explicit_mesh(shape, names):
  """Create a mesh for testing whose axes are all ``AxisType.Explicit``.

  Args:
    shape: Mesh shape, one size per axis.
    names: Axis names, one per axis.

  Returns:
    The mesh built by ``jax.make_mesh``.
  """
  axis_types = (jax.sharding.AxisType.Explicit,) * len(names)
  return jax.make_mesh(shape, names, axis_types=axis_types)


class ComputeEnhancedPspecTest(parameterized.TestCase):
  """Tests for _compute_enhanced_pspec."""

  def setUp(self):
    super().setUp()
    if jax.device_count() < _REQUIRED_DEVICES:
      self.skipTest(f'requires {_REQUIRED_DEVICES} devices')
    # Mesh: data=2, model=4 → 8 devices total.
    self.mesh = _make_explicit_mesh((2, 4), ('data', 'model'))

  def _abstract(self, shape, pspec, mesh=None):
    """Build a float32 ``ShapeDtypeStruct`` sharded by *pspec* on *mesh*."""
    if mesh is None:
      mesh = self.mesh
    sharding = jax.sharding.NamedSharding(mesh, pspec)
    return jax.ShapeDtypeStruct(shape, jnp.float32, sharding=sharding)

  def test_scalar_returns_empty(self):
    result = _sharding._compute_enhanced_pspec(self._abstract((), P()))
    self.assertEqual(result, P())

  def test_all_axes_already_used(self):
    """When all mesh axes are already assigned, nothing changes."""
    result = _sharding._compute_enhanced_pspec(
        self._abstract((8, 4), P('data', 'model'))
    )
    self.assertEqual(result, P('data', 'model'))

  def test_one_unused_axis_assigned(self):
    """One unused axis ('data', size 2) should be assigned to dim 0."""
    # 'model' (size 4) already on dim 1.
    # 'data' (size 2): dim 0 → 8 % 2 = 0 ✓, dim 1 → 4 % (4*2) ≠ 0 ✗.
    result = _sharding._compute_enhanced_pspec(
        self._abstract((8, 4), P(None, 'model'))
    )
    self.assertEqual(result, P('data', 'model'))

  def test_all_axes_unused_greedy(self):
    """Both axes unused: largest axis first, assigned to largest dim."""
    # 'model' (size 4) first → dim 0 (size 8, 8%4=0).
    # 'data' (size 2) next → dim 0 (8%(4*2)=0, size 8) vs dim 1 (4%2=0, size
    # 4). Picks dim 0 (larger).
    result = _sharding._compute_enhanced_pspec(
        self._abstract((8, 4), P(None, None))
    )
    self.assertEqual(result, P(('model', 'data'), None))

  def test_axis_does_not_fit(self):
    """Axes that don't fit any dimension are skipped."""
    # Shape (3, 5): 3 % 4 ≠ 0, 5 % 4 ≠ 0, 3 % 2 ≠ 0, 5 % 2 ≠ 0.
    result = _sharding._compute_enhanced_pspec(
        self._abstract((3, 5), P(None, None))
    )
    self.assertEqual(result, P(None, None))

  def test_partial_fit(self):
    """Only some unused axes fit."""
    # Shape (6, 3): 'model' (4) → 6%4≠0, 3%4≠0 → skip.
    # 'data' (2) → 6%2=0 → dim 0.
    result = _sharding._compute_enhanced_pspec(
        self._abstract((6, 3), P(None, None))
    )
    self.assertEqual(result, P('data', None))

  def test_1d_array(self):
    """Single-dimension array."""
    result = _sharding._compute_enhanced_pspec(self._abstract((8,), P(None)))
    # 'model' (4): 8%4=0. 'data' (2): 8%(4*2)=0.
    self.assertEqual(
        result,
        P(
            ('model', 'data'),
        ),
    )

  def test_prefers_largest_dimension(self):
    """When an axis fits in multiple dims, the largest dim is chosen."""
    # 3-axis mesh: a=2, b=2, c=2.
    mesh = _make_explicit_mesh((2, 2, 2), ('a', 'b', 'c'))
    # Shape (8, 4): pspec P('a', None).
    # 'b' (2): dim 0 (8%(2*2)=0, size 8) vs dim 1 (4%2=0, size 4) → dim 0.
    # 'c' (2): dim 0 (8%(4*2)=0, size 8) vs dim 1 (4%2=0, size 4) → dim 0.
    result = _sharding._compute_enhanced_pspec(
        self._abstract((8, 4), P('a', None), mesh)
    )
    self.assertEqual(result, P(('a', 'b', 'c'), None))

  def test_distributes_across_dims_when_needed(self):
    """When an axis can't stack on an existing dim, it goes to another."""
    # 3-axis mesh: a=2, b=2, c=2.
    mesh = _make_explicit_mesh((2, 2, 2), ('a', 'b', 'c'))
    # Shape (4, 6): pspec P('a', None).
    # 'b' (2): dim 0 (4%(2*2)=0, size 4) vs dim 1 (6%2=0, size 6) → dim 1.
    # 'c' (2): dim 0 (4%(2*2)=0, size 4) vs dim 1 (6%(2*2)≠0, nope).
    #   → dim 0 wins for 'c'.
    result = _sharding._compute_enhanced_pspec(
        self._abstract((4, 6), P('a', None), mesh)
    )
    self.assertEqual(result, P(('a', 'c'), 'b'))


class WithCustomShardingTest(absltest.TestCase):
  """Integration tests for with_custom_sharding."""

  def setUp(self):
    super().setUp()
    if jax.device_count() < _REQUIRED_DEVICES:
      self.skipTest(f'requires {_REQUIRED_DEVICES} devices')
    self.mesh = _make_explicit_mesh((2, 4), ('data', 'model'))
    # NOTE(review): this sets the mesh globally and nothing here restores the
    # previous one, so the setting presumably outlives each test — confirm
    # whether a scoped/context-manager form should be used instead.
    jax.sharding.set_mesh(self.mesh)

  def _make_params(self, shape=(8, 4), pspec=P(None, 'model')):
    """Create float32 all-ones params with the given shape and sharding."""
    sharding = jax.sharding.NamedSharding(self.mesh, pspec)
    return jax.device_put(jnp.ones(shape, dtype=jnp.float32), sharding)

  def _assert_data_axis_used(self, pytree, shape=(8, 4)):
    """Assert every leaf of *shape* mentions 'data' in its partition spec.

    Also requires that at least one leaf has that shape, so the check cannot
    pass vacuously when the tree contains no such leaf.
    """
    matching = [
        leaf for leaf in jax.tree.leaves(pytree) if leaf.shape == shape
    ]
    self.assertNotEmpty(matching, f'no leaf of shape {shape} to check')
    for leaf in matching:
      self.assertIn('data', str(jax.typeof(leaf).sharding.spec))

  def test_init_enhances_state_sharding(self):
    """State leaves matching param shape should get enhanced sharding."""
    params = self._make_params()
    tx = _sharding.with_custom_sharding(alias.adam(1e-3))
    state = tx.init(params)

    # The param-shaped state arrays should have enhanced sharding (using the
    # 'data' axis).
    self._assert_data_axis_used(state)

  def test_update_produces_enhanced_sharding(self):
    """Output updates should carry the enhanced (zero-redundancy) sharding."""
    params = self._make_params()
    tx = _sharding.with_custom_sharding(alias.adam(1e-3))
    state = tx.init(params)

    grads = self._make_params()  # Same shape/sharding as params.
    updates, _ = tx.update(grads, state, params)

    # Updates should be in the enhanced sharding domain.
    self._assert_data_axis_used(updates)

  def test_numerical_correctness(self):
    """Wrapped transform should produce same values as unwrapped."""
    params = self._make_params()
    grads = self._make_params()

    # Unwrapped.
    tx_plain = alias.adam(1e-3)
    state_plain = tx_plain.init(params)
    updates_plain, _ = tx_plain.update(grads, state_plain, params)

    # Wrapped.
    tx_wrapped = _sharding.with_custom_sharding(alias.adam(1e-3))
    state_wrapped = tx_wrapped.init(params)
    updates_wrapped, _ = tx_wrapped.update(grads, state_wrapped, params)

    chex.assert_trees_all_close(updates_plain, updates_wrapped, atol=1e-7)

  def test_multiple_update_steps(self):
    """Run several update steps to verify state evolution."""
    params = self._make_params()
    tx = _sharding.with_custom_sharding(alias.adam(1e-3))
    state = tx.init(params)

    rng = np.random.RandomState(42)
    for _ in range(5):
      grads_np = rng.randn(*params.shape).astype(np.float32)
      grads = jax.device_put(
          jnp.array(grads_np),
          jax.sharding.NamedSharding(self.mesh, P(None, 'model')),
      )
      updates, state = tx.update(grads, state, params)
      params = optax_update.apply_updates(params, updates)

    # Verify params are finite.
    for leaf in jax.tree.leaves(params):
      self.assertTrue(jnp.all(jnp.isfinite(leaf)))

  def test_sgd_with_momentum(self):
    """Test wrapping SGD with momentum (simpler state than Adam)."""
    params = self._make_params()
    tx = _sharding.with_custom_sharding(alias.sgd(1e-2, momentum=0.9))
    state = tx.init(params)

    grads = self._make_params()
    _, new_state = tx.update(grads, state, params)

    # SGD momentum state should have enhanced sharding.
    self._assert_data_axis_used(new_state)

  def test_no_unused_axes_is_identity(self):
    """If params already use all axes, wrapping should be a no-op."""
    # Params sharded across both axes: P('data', 'model').
    params = self._make_params(pspec=P('data', 'model'))

    tx_plain = alias.sgd(1e-2, momentum=0.9)
    tx_wrapped = _sharding.with_custom_sharding(alias.sgd(1e-2, momentum=0.9))

    plain_leaves = jax.tree.leaves(tx_plain.init(params))
    wrapped_leaves = jax.tree.leaves(tx_wrapped.init(params))

    # ``zip`` would silently drop unmatched leaves, so compare counts first.
    self.assertLen(wrapped_leaves, len(plain_leaves))

    # State shardings should match since no enhancement is possible.
    for plain_leaf, wrapped_leaf in zip(plain_leaves, wrapped_leaves):
      self.assertEqual(
          jax.typeof(plain_leaf).sharding.spec,
          jax.typeof(wrapped_leaf).sharding.spec,
      )


if __name__ == '__main__':
  absltest.main()