sc/matmul: guard Triton launches with the input's CUDA device (multi-GPU fix)#17

Open
heroarmor wants to merge 1 commit into main from fix/multi-gpu-device-guard
@heroarmor
Collaborator

Problem

sc_matmul's Triton kernels launch on torch.cuda.current_device() (cuda:0 by default). Under device_map="auto" model sharding — e.g. loading Llama-3.1-70B across 2 GPUs — a decoder layer's tensors can live on cuda:1 while Triton still launches in the cuda:0 context. That dereferences cuda:1 pointers from a cuda:0 launch and crashes:

torch.AcceleratorError: CUDA error: an illegal memory access was encountered
  at sc/kernels.py _sc_matmul_bipolar_mlp_chunked  (via down_proj)

Single-GPU never trips this (current device == data device), which is why it only appeared when scaling past what fits on one card.
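
The mismatch described above can be detected by comparing a tensor's device index with the current CUDA device. The helper below is a minimal sketch for illustration; `launch_device_mismatch` is a hypothetical name and is not part of sc/.

```python
import torch


def launch_device_mismatch(t: torch.Tensor) -> bool:
    """Return True when `t` lives on a CUDA device other than the current one.

    Hypothetical diagnostic helper: a kernel launched without a device guard
    would run in the current device's context, not in `t`'s.
    """
    if not t.is_cuda:
        # CPU tensors never reach a CUDA launch, so there is nothing to compare.
        return False
    return t.device.index != torch.cuda.current_device()
```

On a single-GPU machine this returns False for every CUDA tensor, which matches the observation that the crash needs at least two devices.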

How it was localized

  • 1-GPU sc_matmul at the exact 70B down_proj shapes (M=8192, D=28672, N∈{1,4}) — all pass. So it's not a wide-D/M kernel bug.
  • The crash only occurs with device_map="auto" sharding across cuda:0/cuda:1.
  • grep confirms no torch.cuda.device(...) guard anywhere in sc/ around the Triton launches (only torch.cuda.stream(...), which sets the stream, not the device).

Fix

Wrap the public sc_matmul body in torch.cuda.device(a.device) (a no-op on single-GPU and on CPU inputs). Every kernel launch — plus the per-head/batched torch.cuda.stream contexts — now targets the tensors' GPU. The heavy implementation is unchanged; it's renamed to _sc_matmul_impl and called through the thin device-guarded wrapper, so the full API/docstring is preserved.
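
A minimal sketch of the wrapper shape described above, under stated assumptions: the real `sc_matmul` signature is not shown in this PR, so the two-argument form and the plain-matmul body of `_sc_matmul_impl` are placeholders. The sketch picks `contextlib.nullcontext()` for non-CUDA inputs; that is one possible way to get a no-op on CPU, not a description of the actual patch.

```python
import contextlib

import torch


def _sc_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Placeholder body (assumption): the real implementation launches
    # Triton kernels; a plain matmul stands in for it here.
    return a @ b


def sc_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Enter the CUDA device that owns `a` so every launch inside the body
    # targets that GPU. Non-CUDA inputs get a context that does nothing.
    guard = torch.cuda.device(a.device) if a.is_cuda else contextlib.nullcontext()
    with guard:
        return _sc_matmul_impl(a, b)
```

Because `torch.cuda.device` is a context manager, the previous current device is restored on exit, including when the body raises.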

Test

tests/test_sc_smoke.py: parametrized regression test that runs sc_matmul on cuda:1 inputs (with the current device left at cuda:0) across all granularities, including the chunk_d=128 MLP path that crashed. It asserts that no fault occurs, the output lands on the input GPU, the current device is restored, and the result is bit-identical to the cuda:0 result. Skipped when fewer than 2 CUDA devices are available.
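
The checks listed above could be sketched as follows. This is an illustration, not the test in tests/test_sc_smoke.py: `check_device_guard` is a hypothetical name, the function under test is passed in as `fn`, and the shapes are arbitrary.

```python
import torch


def check_device_guard(fn, m: int = 4, k: int = 8, n: int = 4) -> str:
    """Run `fn(a, b)` on cuda:1 inputs while the current device is cuda:0."""
    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
        return "skipped"  # needs two GPUs to exercise the mismatch

    torch.cuda.set_device(0)
    a0 = torch.randn(m, k, device="cuda:0")
    b0 = torch.randn(k, n, device="cuda:0")
    ref = fn(a0, b0)  # reference result computed on cuda:0

    out = fn(a0.to("cuda:1"), b0.to("cuda:1"))
    torch.cuda.synchronize(1)  # surface any asynchronous launch fault here

    assert out.device == torch.device("cuda:1")  # output on the input GPU
    assert torch.cuda.current_device() == 0      # current device restored
    assert torch.equal(out.cpu(), ref.cpu())     # bit-identical to cuda:0
    return "ok"
```

Returning a status string instead of calling a test-framework skip keeps the sketch runnable as a standalone script.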

Validation

  • Unit: the new multi-GPU test (run on a 2-GPU node).
  • End-to-end: Llama-3.1-70B-Instruct under SC, sharded across 2× RTX Pro 6000 (which previously crashed), now runs to completion.

(Both validated via Slurm job 50713401 — results appended below.)

🤖 Generated with Claude Code

@heroarmor
Collaborator Author

Validation ✅ (2× RTX Pro 6000, gl1809)

Unit (the new regression test, run standalone since pytest isn't in the test env):

[OK] per_tensor chunk_d=0   2d
[OK] per_row    chunk_d=0   2d
[OK] per_row    chunk_d=128 2d     # the MLP path that crashed on 70B
[OK] per_row    chunk_d=0   3d
[OK] per_head   chunk_d=0   3d
ALL PASS (5/5)

Each runs sc_matmul on cuda:1 while the current device stays cuda:0; output lands on cuda:1, device is restored to cuda:0, and the result is bit-identical to the cuda:0 result.

End-to-end: Llama-3.1-70B-Instruct, sharded across 2 GPUs via device_map="auto" (which crashed with illegal memory access before this fix), now completes:

| config | ms/tok |
|---|---|
| fp16 baseline | 124 |
| SC sc_prec=8 stoc_len=256 per_row | 8647 |
| SC halve + bitrev per_row | 8361 |
(SC runs about 67-70× slower than fp16 (8361/124 ≈ 67.4, 8647/124 ≈ 69.7), in line with the 8B reference. Output is coherent at sl=256.)

The Triton kernels launch on torch.cuda.current_device() (cuda:0 by
default). Under device_map="auto" model sharding (e.g. a 70B split across
2 GPUs), a layer's tensors can live on cuda:1 while the kernel still
launches in the cuda:0 context — dereferencing cuda:1 pointers from a
cuda:0 launch, which fails with:

    CUDA error: an illegal memory access was encountered
    at _sc_matmul_bipolar_mlp_chunked (down_proj)

Single-GPU never trips this (current device == data device), so it only
surfaced when scaling past what fits on one card.

Fix: wrap the public sc_matmul body in `torch.cuda.device(a.device)`
(no-op on single-GPU / CPU). This makes every kernel launch — and the
per-head / batched torch.cuda.stream contexts — target the tensors' GPU.
The heavy implementation is unchanged; it's renamed to _sc_matmul_impl
and called through the device-guarded wrapper.

Test: parametrized regression in test_sc_smoke.py runs sc_matmul on
cuda:1 inputs (current device still cuda:0) across all granularities,
asserts no fault, output lands on the input GPU, and the result is
bit-identical to the cuda:0 result. Skipped when < 2 CUDA devices.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@heroarmor force-pushed the fix/multi-gpu-device-guard branch from c945e4a to 5008c07 on May 23, 2026, 21:24