sc/matmul: guard Triton launches with the input's CUDA device (multi-GPU fix) #17
Open
heroarmor wants to merge 1 commit into
Validation ✅ (2× RTX Pro 6000, gl1809)
Unit (the new regression test, run standalone since pytest isn't in the test env): Each runs
End-to-end: Llama-3.1-70B-Instruct, sharded across 2 GPUs via
(SC ~70× fp16, in line with the 8B reference. Output coherent at sl=256.)
The Triton kernels launch on torch.cuda.current_device() (cuda:0 by
default). Under device_map="auto" model sharding (e.g. a 70B split across
2 GPUs), a layer's tensors can live on cuda:1 while the kernel still
launches in the cuda:0 context — dereferencing cuda:1 pointers from a
cuda:0 launch, which fails with:
CUDA error: an illegal memory access was encountered
at _sc_matmul_bipolar_mlp_chunked (down_proj)
Single-GPU never trips this (current device == data device), so it only
surfaced when scaling past what fits on one card.
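The mismatch described above can be checked directly. The following is a minimal sketch, not code from this PR: `launch_device_mismatch` is a hypothetical helper name, and it only compares the tensor's device index with `torch.cuda.current_device()`.

```python
import torch


def launch_device_mismatch(t: torch.Tensor) -> bool:
    """Hypothetical helper: True if a kernel launched right now would run
    in a different CUDA device context than the one holding `t`."""
    if not t.is_cuda:
        # No CUDA context is involved for CPU tensors.
        return False
    return t.device.index != torch.cuda.current_device()


if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
    # Tensor on the second GPU while the current device is left untouched.
    x = torch.ones(4, device="cuda:1")
    print("mismatch:", launch_device_mismatch(x))
    # Inside the guard, the current device follows the tensor.
    with torch.cuda.device(x.device):
        print("mismatch under guard:", launch_device_mismatch(x))
else:
    print("mismatch:", launch_device_mismatch(torch.ones(4)))
```

On a machine with fewer than two GPUs the sketch only exercises the CPU branch, which mirrors why the bug stayed hidden on single-GPU runs.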
Fix: wrap the public sc_matmul body in `torch.cuda.device(a.device)`
(no-op on single-GPU / CPU). This makes every kernel launch — and the
per-head / batched torch.cuda.stream contexts — target the tensors' GPU.
The heavy implementation is unchanged; it's renamed to _sc_matmul_impl
and called through the device-guarded wrapper.
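A minimal sketch of the wrapper pattern follows. It is an illustration under stated assumptions, not the PR's code: the two-argument signature is hypothetical, the body of `_sc_matmul_impl` is replaced by a plain matmul stand-in, and the explicit `is_cuda` branch is an assumption of this sketch so that CPU inputs bypass the guard entirely.

```python
import torch


def _sc_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real Triton-backed implementation (assumption:
    # a plain matmul is used here only so the sketch runs anywhere).
    return a @ b


def sc_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Thin device-guarded entry point (hypothetical signature)."""
    if a.is_cuda:
        # Make a's GPU the current device for every launch inside the
        # block; the previous device is restored when the block exits.
        with torch.cuda.device(a.device):
            return _sc_matmul_impl(a, b)
    # CPU inputs: no CUDA device context to set.
    return _sc_matmul_impl(a, b)


a = torch.arange(6, dtype=torch.float32).reshape(2, 3)
b = torch.ones(3, 2)
out = sc_matmul(a, b)
print(out.shape, out.device)
```

Keeping the guard in a thin wrapper means the implementation body does not need to know which device it runs on.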
Test: parametrized regression in test_sc_smoke.py runs sc_matmul on
cuda:1 inputs (current device still cuda:0) across all granularities,
asserts no fault, output lands on the input GPU, and the result is
bit-identical to the cuda:0 result. Skipped when < 2 CUDA devices.
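The checks listed above could look roughly like the sketch below. It is written as a standalone function rather than a pytest test, and everything in it is illustrative: `guarded_matmul` is a stand-in for `sc_matmul`, `check_non_default_device` is a hypothetical name, and the shapes are arbitrary small values rather than the ones used in test_sc_smoke.py.

```python
import torch


def guarded_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Stand-in for sc_matmul (assumption): plain matmul under a device guard.
    if a.is_cuda:
        with torch.cuda.device(a.device):
            return a @ b
    return a @ b


def check_non_default_device(op) -> str:
    """Run `op` on cuda:1 inputs while cuda:0 stays the current device."""
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return "skipped"  # needs at least two CUDA devices
    torch.cuda.set_device(0)
    gen = torch.Generator().manual_seed(0)
    a = torch.randn(8, 16, generator=gen)
    b = torch.randn(16, 4, generator=gen)
    ref = op(a.to("cuda:0"), b.to("cuda:0"))
    out = op(a.to("cuda:1"), b.to("cuda:1"))
    # Output lands on the input GPU.
    assert out.device == torch.device("cuda:1")
    # The caller's current device is restored after the call.
    assert torch.cuda.current_device() == 0
    # Same bits as the cuda:0 run (copying to CPU also synchronizes).
    assert torch.equal(out.cpu(), ref.cpu())
    return "passed"


print(check_non_default_device(guarded_matmul))
```

The bit-identical comparison assumes both GPUs are the same model, as in the 2× RTX Pro 6000 setup reported in this PR.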
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
c945e4a to
5008c07
Compare
Problem
`sc_matmul`'s Triton kernels launch on `torch.cuda.current_device()` (cuda:0 by default). Under `device_map="auto"` model sharding (e.g. loading Llama-3.1-70B across 2 GPUs), a decoder layer's tensors can live on cuda:1 while Triton still launches in the cuda:0 context. That dereferences cuda:1 pointers from a cuda:0 launch and crashes:
CUDA error: an illegal memory access was encountered
at _sc_matmul_bipolar_mlp_chunked (down_proj)
Single-GPU never trips this (current device == data device), which is why it only appeared when scaling past what fits on one card.
How it was localized
- `sc_matmul` at the exact 70B `down_proj` shapes (M=8192, D=28672, N∈{1,4}): all pass. So it's not a wide-D/M kernel bug.
- `device_map="auto"` sharding across cuda:0/cuda:1.
- `grep` confirms no `torch.cuda.device(...)` guard anywhere in `sc/` around the Triton launches (only `torch.cuda.stream(...)`, which sets the stream, not the device).
Fix
Wrap the public `sc_matmul` body in `torch.cuda.device(a.device)` (a no-op on single-GPU and on CPU inputs). Every kernel launch, plus the per-head/batched `torch.cuda.stream` contexts, now targets the tensors' GPU. The heavy implementation is unchanged; it's renamed to `_sc_matmul_impl` and called through the thin device-guarded wrapper, so the full API/docstring is preserved.
Test
`tests/test_sc_smoke.py`: parametrized regression that runs `sc_matmul` on cuda:1 inputs (current device left at cuda:0) across all granularities, including the `chunk_d=128` MLP path that crashed. Asserts no fault, output lands on the input GPU, the device is restored, and the result is bit-identical to the cuda:0 result. Skipped when `< 2` CUDA devices.
Validation
(Both validated via Slurm job 50713401 — results appended below.)
🤖 Generated with Claude Code