# --- emerging_optimizers/soap/soap.py: definitions appended by this patch ---
# NOTE: the same patch also adds "StackedSoap" to the module's ``__all__`` list.


def _stack_2d(x: torch.Tensor) -> torch.Tensor:
    """Flattens a 2D or 3D tensor to 2D, merging the batch dim into the smaller matrix edge.

    A 2D tensor is returned unchanged (the very same object). A 3D tensor ``(b, m, n)`` is merged into the
    smaller of its two matrix edges: ``(m, b * n)`` when ``n <= m``, otherwise ``(b * m, n)``.

    Args:
        x: A 2D matrix ``(m, n)`` or a 3D batched matrix ``(b, m, n)``.

    Returns:
        The 2D stacking of ``x``. The result is always contiguous for 3D inputs; it shares storage with
        ``x`` only when the reshape could be expressed as a contiguous view.

    Raises:
        ValueError: If ``x`` is neither 2D nor 3D.
    """
    if x.ndim == 2:
        return x
    if x.ndim != 3:
        # Fail with an actionable message instead of an opaque tuple-unpacking error below.
        raise ValueError(f"_stack_2d expects a 2D or 3D tensor, got {x.ndim}D tensor of shape {tuple(x.shape)}")
    b, m, n = x.shape
    if n <= m:
        # -> (m, b*n): move the batch next to the smaller edge, then merge.
        out = x.permute(1, 0, 2).reshape(m, b * n)
    else:
        # -> (b*m, n): contiguous merge into rows.
        out = x.reshape(b * m, n)
    # ``contiguous`` returns ``out`` itself when it is already contiguous, so this only copies when needed.
    return out.contiguous()


def _unstack(u: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """Inverse of :func:`_stack_2d`, restoring the original ``shape``.

    Args:
        u: The 2D stacking produced by :func:`_stack_2d`.
        shape: The shape of the tensor that was stacked, either ``(m, n)`` or ``(b, m, n)``.

    Returns:
        ``u`` itself when ``shape`` is 2D, otherwise ``u`` rearranged back to ``(b, m, n)``.

    Raises:
        ValueError: If ``shape`` is neither 2D nor 3D.
    """
    if len(shape) == 2:
        return u
    if len(shape) != 3:
        raise ValueError(f"_unstack expects a 2D or 3D target shape, got {tuple(shape)}")
    b, m, n = shape
    if n <= m:
        # Undo the permute-and-merge: split the merged columns, then move the batch back to the front.
        return u.reshape(m, b, n).permute(1, 0, 2).reshape(shape)
    return u.reshape(shape)


@registry.register_optimizer("stacked_soap")
class StackedSoap(SOAP):
    """Limited-memory SOAP for batched / 3D parameters via transient 2D stacking.

    Optimizes the real parameters directly: ``self.param_groups``, ``self.state``, and gradients are all
    keyed by the user's parameters, so learning-rate schedulers, gradient clipping, and ``state_dict``
    behave exactly as for plain :class:`SOAP`. Each 3D parameter ``(b, m, n)`` is flattened to 2D by merging
    its batch dim into the smaller matrix edge (see :func:`_stack_2d`) only for the duration of :meth:`step`:
    the parameter's ``data`` and ``grad`` are swapped to their 2D stackings, the inherited SOAP step runs,
    and the 2D update is unstacked back into the original storage. Because the swap happens before the
    inherited step, its lazy state initialization sizes the optimizer state to the stacked 2D shape
    automatically.

    Stacking on the smaller edge keeps both Kronecker factors small (the larger edge becomes a single
    shared factor) while reusing the full, unmodified SOAP machinery (KL-Shampoo + QR eigenbasis). For a
    contiguous parameter the stacking is a storage-sharing view except for the permute branch
    (``n <= m``), which allocates one transient 2D buffer per step; a non-contiguous parameter may be
    copied in either branch and is written back after the step. A plain 2D parameter is stacked as itself,
    so this is exactly stock SOAP.

    SOAP is configured with the fixed settings appropriate for this use: decoupled weight decay, no
    Nesterov, bias correction on, the QR eigenbasis path with 1 power-iteration step, KL-Shampoo on, and
    the default matmul precision.

    Args:
        params: Iterable of 2D or 3D parameters to optimize or dicts defining parameter groups.
        lr: The learning rate.
        betas: Inner Adam betas ``(b1, b2)``.
        shampoo_beta: Beta for the kronecker factor moving average.
        eps: Inner Adam epsilon.
        weight_decay: Decoupled weight decay coefficient.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float,
        betas: tuple[float, float] = (0.9, 0.95),
        shampoo_beta: float = 0.95,
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ) -> None:
        super().__init__(
            params,
            lr,
            betas=betas,
            shampoo_beta=shampoo_beta,
            eps=eps,
            weight_decay=weight_decay,
            weight_decay_method="decoupled",
            nesterov=False,
            correct_bias=True,
            use_eigh=False,
            power_iter_steps=1,
            use_kl_shampoo=True,
        )

    if TYPE_CHECKING:

        @overload
        def step(self, closure: None = ...) -> None: ...

        @overload
        def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Runs one SOAP step on the transient 2D stacking of every parameter that has a gradient.

        Args:
            closure: Must be ``None``; closures are not supported.

        Returns:
            Always ``None``.

        Raises:
            ValueError: If ``closure`` is given, or if a parameter with a gradient is neither 2D nor 3D
                (raised by :func:`_stack_2d`; already-swapped parameters are restored first).
        """
        if closure is not None:
            raise ValueError("closure is not supported")

        # Swap each parameter's data/grad to their 2D stacking, run the inherited SOAP step on the 2D
        # views (state is keyed by the real parameter and sized for the stacked shape), then unstack the
        # update back into the original storage. The restore runs in a finally so that an exception inside
        # super().step() (e.g. OOM, a NaN check) cannot leave parameters stuck in their 2D stacked shape.
        saved: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
        try:
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is None:
                        continue  # pragma: no cover
                    data, grad = p.data, p.grad
                    # Record before swapping so the finally block can always undo a partial swap.
                    saved.append((p, data, grad))
                    p.data = _stack_2d(data)
                    p.grad = _stack_2d(grad)

            super().step()
        finally:
            for p, data, grad in saved:
                stacked = p.data
                # Restore data first: assigning the original-shaped grad requires the original-shaped data.
                p.data = data
                p.grad = grad
                # Copy back only when stacking allocated an independent buffer (permute branch); the view
                # branches already wrote the update through to the original storage.
                if stacked.data_ptr() != data.data_ptr():
                    data.copy_(_unstack(stacked, data.shape))
        return None


# --- examples/stacked_soap_grouped_linear.py (new file added by this patch) ---
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os


# Set before transformer_engine is imported below.
# NOTE(review): presumably Transformer Engine reads this at import time -- confirm.
os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"

import torch
import transformer_engine.pytorch as te
from absl import app, flags

from emerging_optimizers.soap.soap import StackedSoap


FLAGS = flags.FLAGS

# Lower bounds reject zero/negative sizes at flag-parsing time instead of failing later with a
# ZeroDivisionError (num_experts=0) or a nonsensical token split.
flags.DEFINE_integer("num_experts", 8, "Number of experts (grouped GEMMs).", lower_bound=1)
flags.DEFINE_integer("in_features", 512, "Input feature dimension per expert.", lower_bound=1)
flags.DEFINE_integer("out_features", 1024, "Output feature dimension per expert.", lower_bound=1)
flags.DEFINE_integer("tokens", 256, "Total number of tokens routed across experts.", lower_bound=1)
flags.DEFINE_integer("steps", 5, "Number of optimization steps.", lower_bound=0)
flags.DEFINE_float("lr", 1e-3, "Learning rate.")


def main(argv: list[str]) -> None:
    """Build a Transformer Engine GroupedLinear with a single 3D weight and train it with StackedSoap.

    Args:
        argv: Positional command-line arguments left over by absl; unused.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    del argv
    if not torch.cuda.is_available():
        raise RuntimeError("This example requires a CUDA device (Transformer Engine is GPU-only).")

    device = torch.device("cuda")
    dtype = torch.bfloat16

    grouped_linear = te.GroupedLinear(
        FLAGS.num_experts,
        FLAGS.in_features,
        FLAGS.out_features,
        bias=False,
        single_grouped_weight=True,
        params_dtype=dtype,
        device=device,
    )

    weight = grouped_linear.weight
    print(f"Single expert weight tensor: shape={tuple(weight.shape)}, dtype={weight.dtype}")

    optimizer = StackedSoap(grouped_linear.parameters(), lr=FLAGS.lr, weight_decay=0.0)

    # MoE routing: split `tokens` rows across experts; m_splits must sum to the token count, so the
    # last expert absorbs the remainder of the integer division.
    base, remainder = divmod(FLAGS.tokens, FLAGS.num_experts)
    m_splits = [base] * FLAGS.num_experts
    m_splits[-1] += remainder

    for step in range(FLAGS.steps):
        x = torch.randn(FLAGS.tokens, FLAGS.in_features, device=device, dtype=dtype, requires_grad=True)
        out = grouped_linear(x, m_splits)
        loss = out.float().square().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"step {step}: loss={loss.item():.6f}")


if __name__ == "__main__":
    app.run(main)


# --- tests/test_soap.py: changes from this patch (imports shown are the ones visible in the diff) ---
from absl.testing import absltest, parameterized

from emerging_optimizers.soap import REKLS, SOAP, soap
from emerging_optimizers.soap.soap import StackedSoap, _clip_update_rms_in_place, _stack_2d, _unstack


flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on")


class StackedSoapTest(parameterized.TestCase):
    """Tests for the 2D stacking helpers and the StackedSoap optimizer."""

    def setUp(self):
        self.device = FLAGS.device

    def _reference_soap(self, params) -> SOAP:
        """Builds stock SOAP with the same fixed settings StackedSoap passes to its base class."""
        return SOAP(
            params,
            1e-2,
            weight_decay=0.01,
            weight_decay_method="decoupled",
            nesterov=False,
            correct_bias=True,
            use_eigh=False,
            power_iter_steps=1,
            use_kl_shampoo=True,
        )

    @parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6)])
    def test_smoke(self, shape) -> None:
        p = torch.nn.Parameter(torch.randn(shape, device=self.device))
        storage_ptr = p.data_ptr()
        opt = StackedSoap([p], lr=1e-2, weight_decay=0.01)
        for _ in range(3):
            p.grad = torch.randn_like(p)
            opt.step()
            # The 2D stacking must be transient: shape, grad shape and storage are back after each step.
            self.assertEqual(p.shape, torch.Size(shape))
            self.assertEqual(p.grad.shape, torch.Size(shape))
            self.assertEqual(p.data_ptr(), storage_ptr)
        self.assertTrue(torch.isfinite(p).all())

    @parameterized.product(shape=[(4, 6, 3), (4, 3, 6)])
    def test_failed_step_restores_original_shape(self, shape) -> None:
        """An exception inside the inherited step must not leave the parameter in its 2D stacking."""
        from unittest import mock

        p = torch.nn.Parameter(torch.randn(shape, device=self.device))
        before = p.detach().clone()
        opt = StackedSoap([p], lr=1e-2, weight_decay=0.01)
        p.grad = torch.randn_like(p)

        with mock.patch.object(SOAP, "step", side_effect=RuntimeError("boom")):
            with self.assertRaises(RuntimeError):
                opt.step()

        self.assertEqual(p.shape, torch.Size(shape))
        self.assertEqual(p.grad.shape, torch.Size(shape))
        assert_equal(p.detach(), before)

    @parameterized.product(shape=[(8, 5), (4, 6, 3), (4, 3, 6), (4, 5, 5)])
    def test_stack_unstack_shapes_and_roundtrip(self, shape) -> None:
        x = torch.randn(shape, device=self.device)

        if x.ndim == 2:
            expected_2d = shape
        else:
            b, m, n = shape
            expected_2d = (m, b * n) if n <= m else (b * m, n)

        stacked = _stack_2d(x)
        self.assertEqual(stacked.shape, torch.Size(expected_2d))

        restored = _unstack(stacked, x.shape)
        self.assertEqual(restored.shape, x.shape)
        assert_equal(restored, x)

    @parameterized.product(shape=[(7,), (2, 3, 4, 5)])
    def test_stack_rejects_unsupported_rank(self, shape) -> None:
        with self.assertRaises(ValueError):
            _stack_2d(torch.randn(shape, device=self.device))

    @parameterized.product(shape=[(8, 5), (16, 16), (5, 7)])
    def test_2d_input_7steps_matches_vanilla_soap(self, shape) -> None:
        x = torch.randn(shape, device=self.device)
        p_stacked = torch.nn.Parameter(x.clone())
        p_ref = torch.nn.Parameter(x.clone())

        opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
        opt_ref = self._reference_soap([p_ref])

        for _ in range(7):
            grad = torch.randn(shape, device=self.device)
            p_stacked.grad = grad.clone()
            p_ref.grad = grad.clone()
            opt_stacked.step()
            opt_ref.step()
            assert_equal(
                p_stacked.detach(),
                p_ref.detach(),
                msg=lambda m: f"StackedSoap must match stock SOAP exactly on 2D params.\n\n{m}",
            )

    @parameterized.product(shape=[(4, 6, 3), (4, 3, 6)])
    def test_3d_input_5steps_matches_vanilla_soap(self, shape) -> None:
        """StackedSoap on a 3D param must match vanilla SOAP run on the manually stacked 2D param."""
        x = torch.randn(shape, device=self.device)
        p_stacked = torch.nn.Parameter(x.clone())
        # Reference is vanilla SOAP on the 2D stacking of the same parameter.
        p_ref = torch.nn.Parameter(_stack_2d(x).clone())

        opt_stacked = StackedSoap([p_stacked], lr=1e-2, weight_decay=0.01)
        opt_ref = self._reference_soap([p_ref])

        for _ in range(5):
            grad = torch.randn(shape, device=self.device)
            p_stacked.grad = grad.clone()
            p_ref.grad = _stack_2d(grad)
            opt_stacked.step()
            opt_ref.step()
            assert_equal(
                _stack_2d(p_stacked.detach()),
                p_ref.detach(),
                msg=lambda m: f"StackedSoap on a 3D param must match vanilla SOAP on its 2D stacking.\n\n{m}",
            )


if __name__ == "__main__":
    absltest.main()