cuda+metal/scaled_masked_softmax: bool mask + post_softmax_mask#2285
Open
kali wants to merge 5 commits into
3eeeef3 to bf52f69
Adds a sibling kernel that loads the mask as uchar/char, substitutes -inf at masked positions before softmax, and when post_softmax_mask is set scrubs fully-masked rows (sum == 0 / NaN) to 0 on write-back. Lifts the GpuScaledMaskedSoftmax guards so bool masks aren't rejected by output_facts, and drops the rule_if!(!post_softmax_mask) gate on both backends. For the nemotron-streaming encoder this moves all 24 SMS nodes off CPU (--cuda matches the recorded io bundle at --approx very). Metal mirror matches structurally; CI's macOS nemotron harness covers the numeric check there.
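As a rough CPU-side illustration of the semantics described above (not the CUDA or Metal kernel itself), the per-row computation might look like the sketch below. The function name and signature are hypothetical, and the sketch assumes that `true` in the mask means "keep this position"; the actual kernels may use the opposite convention or a different layout.

```rust
/// Reference sketch for one row of a scaled masked softmax with a bool mask.
/// Hypothetical helper, not part of the PR. Assumption: mask[i] == true keeps
/// position i, mask[i] == false masks it out.
fn scaled_masked_softmax_row(
    x: &[f32],
    mask: &[bool],
    scale: f32,
    post_softmax_mask: bool,
) -> Vec<f32> {
    assert_eq!(x.len(), mask.len());

    // Scale the kept inputs and substitute -inf at masked positions.
    let logits: Vec<f32> = x
        .iter()
        .zip(mask)
        .map(|(&v, &keep)| if keep { v * scale } else { f32::NEG_INFINITY })
        .collect();

    // Numerically stable softmax: subtract the row max before exponentiating.
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();

    // A fully masked row yields -inf - (-inf) = NaN, so its sum is NaN.
    // With post_softmax_mask set, such rows are written back as zeros.
    let dead_row = sum == 0.0 || sum.is_nan();
    if post_softmax_mask && dead_row {
        return vec![0.0; x.len()];
    }
    exps.iter().map(|&e| e / sum).collect()
}
```

Under these assumptions, a row with at least one kept position sums to 1 with exact zeros at masked positions, while a fully masked row comes back as all zeros only when `post_softmax_mask` is set, and as NaN otherwise.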
…, Reduce) Both SMS (this branch) and DiagGather (already on main) now have CUDA + Metal kernels. IsNan and Reduce don't appear in any of the 4 streaming models — IsNan never did, Reduce<Sum> is always F32 which both backends handle. Audit on --cuda confirms zero CPU instances; --metal mirrors the same allowlist now that the Metal kernels exist. Tight placement check: any regression that puts one of these on CPU now fails CI.
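The placement check described above can be pictured as a simple allowlist filter. The sketch below is illustrative only: the `Device` enum, the function name, and the op names used in the usage note are hypothetical and do not reflect the harness's real data structures.

```rust
use std::collections::HashSet;

/// Hypothetical device tag for a node after GPU translation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Device {
    Cpu,
    Gpu,
}

/// Returns the op names that ended up on CPU without being in the allowlist.
/// A CI check would fail whenever this list is non-empty.
fn cpu_placement_violations<'a>(
    nodes: &[(&'a str, Device)],
    cpu_allowlist: &HashSet<&str>,
) -> Vec<&'a str> {
    nodes
        .iter()
        .filter(|&&(op, dev)| dev == Device::Cpu && !cpu_allowlist.contains(op))
        .map(|&(op, _)| op)
        .collect()
}
```

With an allowlist that no longer contains an op such as `ScaledMaskedSoftmax`, any model that still places that op on CPU would show up in the returned list, which is what makes the check tight.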
The decoder is stepped one token at a time by the caller (external state plumbed through the outer graph), so iters resolves to 1 and the Scan body can be inlined. Apply the existing core force_scan_external_state transform on the decoder run; the two LSTM cells now land on GPU. Drop Scan from the gpu allowlists — no model in the harness keeps a Scan node on CPU after this.
Now that cuda + metal Gather kernels are on main, the decoder embedding lookup runs on GPU. Audit confirms zero CPU Gather across all 4 models (decoder is run with -t force_scan_external_state so its embedding input is fed directly to CudaGather/MetalGather).
152ac43 to 3800d84