Skip to content

cuda+metal/scaled_masked_softmax: bool mask + post_softmax_mask#2285

Open
kali wants to merge 5 commits into
mainfrom
gpu/smms-bool-post-mask
Open

cuda+metal/scaled_masked_softmax: bool mask + post_softmax_mask#2285
kali wants to merge 5 commits into
mainfrom
gpu/smms-bool-post-mask

Conversation

@kali
Copy link
Copy Markdown
Collaborator

@kali kali commented May 26, 2026

Adds a sibling kernel that loads the mask as uchar/char, substitutes -inf at masked positions before softmax, and when post_softmax_mask is set scrubs fully-masked rows (sum == 0 / NaN) to 0 on write-back. Lifts the GpuScaledMaskedSoftmax guards so bool masks aren't rejected by output_facts, and drops the rule_if!(!post_softmax_mask) gate on both backends.

For the nemotron-streaming encoder this moves all 24 SMS nodes off CPU (--cuda matches the recorded io bundle at --approx very). Metal mirror matches structurally; CI's macOS nemotron harness covers the numeric check there.

@kali kali force-pushed the gpu/smms-bool-post-mask branch from 3eeeef3 to bf52f69 Compare May 27, 2026 08:15
kali added 5 commits May 27, 2026 11:04
Adds a sibling kernel that loads the mask as uchar/char, substitutes
-inf at masked positions before softmax, and when post_softmax_mask is
set scrubs fully-masked rows (sum == 0 / NaN) to 0 on write-back.
Lifts the GpuScaledMaskedSoftmax guards so bool masks aren't rejected
by output_facts, and drops the rule_if!(!post_softmax_mask) gate on
both backends.

For the nemotron-streaming encoder this moves all 24 SMS nodes off CPU
(--cuda matches the recorded io bundle at --approx very).  Metal mirror
matches structurally; CI's macOS nemotron harness covers the numeric
check there.
…, Reduce)

Both SMS (this branch) and DiagGather (already on main) now have CUDA +
Metal kernels.  IsNan and Reduce don't appear in any of the 4 streaming
models — IsNan never did, Reduce<Sum> is always F32 which both backends
handle.  Audit on --cuda confirms zero CPU instances; --metal mirrors the
same allowlist now that the Metal kernels exist.  Tight placement check:
any regression that puts one of these on CPU now fails CI.
The decoder is stepped one token at a time by the caller (external state
plumbed through the outer graph), so iters resolves to 1 and the Scan
body can be inlined.  Apply the existing core force_scan_external_state
transform on the decoder run; the two LSTM cells now land on GPU.

Drop Scan from the gpu allowlists — no model in the harness keeps a
Scan node on CPU after this.
Now that cuda + metal Gather kernels are on main, the decoder embedding
lookup runs on GPU.  Audit confirms zero CPU Gather across all 4 models
(decoder is run with -t force_scan_external_state so its embedding
input is fed directly to CudaGather/MetalGather).
@kali kali force-pushed the gpu/smms-bool-post-mask branch from 152ac43 to 3800d84 Compare May 27, 2026 11:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant