Skip to content

Add stacked soap#235

Open
skyw wants to merge 8 commits into
mainfrom
skyw/stacked-soap
Open

Add stacked soap#235
skyw wants to merge 8 commits into
mainfrom
skyw/stacked-soap

Conversation

@skyw

@skyw skyw commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Save memory on grouped linear layers for MoE. Accuracy impact is yet to be tested.

skyw added 2 commits June 24, 2026 14:45
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw skyw requested a review from mkhona-nvidia June 24, 2026 22:31
@copy-pr-bot

copy-pr-bot Bot commented Jun 24, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 24, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces StackedSoap, a memory-efficient SOAP variant for 3D parameters (e.g. grouped-linear MoE weights). It transiently reshapes each 3D (b, m, n) parameter to 2D for the duration of the optimizer step — merging the batch dim into the smaller matrix edge — then unstacks the update back into original storage, allowing stock SOAP's Kronecker-factor machinery to operate on a single shared 2D view instead of b independent factor sets.

  • Core helpers _stack_2d / _unstack are a correct inverse pair: the view branch (n > m) shares storage with the original tensor (update writes through), while the permute branch (n ≤ m) allocates a transient buffer and copies back via data_ptr() comparison.
  • Exception safety is handled with a try/finally that always restores p.data and p.grad to their original 3D shapes regardless of how super().step() exits.
  • Tests cover smoke runs, roundtrip shape correctness for both branches, exact numerical match with vanilla SOAP for 2D inputs, and exact match against manually-stacked vanilla SOAP for 3D inputs.

Confidence Score: 5/5

Safe to merge — the stacking/unstacking logic is a correct inverse pair, exception safety is properly handled, and the new class is covered by numerical equivalence tests against vanilla SOAP.

The core swap-step-restore pattern is sound: the data_ptr() check correctly distinguishes view branches (shared storage, no copy needed) from permute branches (independent buffer, copy-back required), and the try/finally guarantees parameters are never left in a 2D stacked state. Tests confirm exact numerical equivalence with stock SOAP for both 2D and 3D inputs.

No files require special attention.

Important Files Changed

Filename Overview
emerging_optimizers/soap/soap.py Adds _stack_2d, _unstack helpers and the StackedSoap subclass. Stacking/unstacking logic is a correct inverse pair, the data_ptr() check correctly distinguishes view branches (n>m, shared storage) from permute branches (n≤m, independent buffer), and the try/finally guarantees parameter shapes are always restored. No issues found.
tests/test_soap.py Adds StackedSoapTest with smoke tests, roundtrip shape checks (both n≤m and n>m branches), 2D equivalence with vanilla SOAP, and 3D equivalence against manually-stacked vanilla SOAP. Good coverage of both stacking branches.
examples/stacked_soap_grouped_linear.py New runnable example using TE GroupedLinear with a single 3D weight and StackedSoap. Env var is correctly set before the TE import, zero_grad ordering is valid, and m_splits construction is correct.

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant User
    participant SS as StackedSoap.step()
    participant Stack as _stack_2d / _unstack
    participant SOAP as super().step() (SOAP)

    User->>SS: step()
    SS->>SS: for each param p with grad

    SS->>Stack: _stack_2d(p.data)  shape (b,m,n)→2D
    Stack-->>SS: "stacked_data (view if n>m, new buf if n≤m)"

    SS->>Stack: _stack_2d(p.grad)  shape (b,m,n)→2D
    Stack-->>SS: stacked_grad

    SS->>SS: "p.data = stacked_data"
    SS->>SS: "p.grad = stacked_grad"
    SS->>SS: save (p, original_data, original_grad)

    SS->>SOAP: super().step()  [operates on 2D stacked tensors]
    SOAP-->>SS: done (state keyed by p, sized for 2D)

    Note over SS: finally block (always runs)
    SS->>SS: for each (p, data, grad) in saved
    SS->>SS: "stacked = p.data"
    SS->>SS: "p.data = original_data (restore 3D)"
    SS->>SS: "p.grad = original_grad (restore 3D)"

    alt "stacked.data_ptr() != data.data_ptr()  [permute branch: n≤m]"
        SS->>Stack: _unstack(stacked, original_shape)
        Stack-->>SS: 3D update tensor
        SS->>SS: original_data.copy_(update)
    else "view branch: n>m"
        Note over SS: update already written through shared storage
    end

    SS-->>User: None
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
sequenceDiagram
    participant User
    participant SS as StackedSoap.step()
    participant Stack as _stack_2d / _unstack
    participant SOAP as super().step() (SOAP)

    User->>SS: step()
    SS->>SS: for each param p with grad

    SS->>Stack: _stack_2d(p.data)  shape (b,m,n)→2D
    Stack-->>SS: "stacked_data (view if n>m, new buf if n≤m)"

    SS->>Stack: _stack_2d(p.grad)  shape (b,m,n)→2D
    Stack-->>SS: stacked_grad

    SS->>SS: "p.data = stacked_data"
    SS->>SS: "p.grad = stacked_grad"
    SS->>SS: save (p, original_data, original_grad)

    SS->>SOAP: super().step()  [operates on 2D stacked tensors]
    SOAP-->>SS: done (state keyed by p, sized for 2D)

    Note over SS: finally block (always runs)
    SS->>SS: for each (p, data, grad) in saved
    SS->>SS: "stacked = p.data"
    SS->>SS: "p.data = original_data (restore 3D)"
    SS->>SS: "p.grad = original_grad (restore 3D)"

    alt "stacked.data_ptr() != data.data_ptr()  [permute branch: n≤m]"
        SS->>Stack: _unstack(stacked, original_shape)
        Stack-->>SS: 3D update tensor
        SS->>SS: original_data.copy_(update)
    else "view branch: n>m"
        Note over SS: update already written through shared storage
    end

    SS-->>User: None
Loading

Reviews (6): Last reviewed commit: "add grouped linear example" | Re-trigger Greptile

Comment thread emerging_optimizers/soap/soap.py Outdated
Comment thread emerging_optimizers/soap/soap.py Outdated
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 24, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 283d7ed

@github-actions

github-actions Bot commented Jun 24, 2026

Copy link
Copy Markdown

Test Results

   81 files  ± 0    153 suites  +2   1m 45s ⏱️ ±0s
1 178 tests +12  1 178 ✅ +12  0 💤 ±0  0 ❌ ±0 
2 738 runs  +24  2 738 ✅ +24  0 💤 ±0  0 ❌ ±0 

Results for commit 612b85f. ± Comparison against base commit 46eda5a.

♻️ This comment has been updated with latest results.

@codecov

codecov Bot commented Jun 24, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

skyw added 2 commits June 24, 2026 20:58
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 25, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test fdb8c27

skyw and others added 3 commits June 26, 2026 10:08
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw

skyw commented Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

/ok to test 612b85f

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant