
Add cos/sin width guard to fused MLA RoPE kernels#5497

Open
ShauryaaSharma wants to merge 1 commit into
NVIDIA:mainfrom
ShauryaaSharma:fix/fused-mla-rope-cos-sin-width-guard

Conversation

@ShauryaaSharma

  • I, the PR author, have personally reviewed every line of this PR.

What does this PR do?

Add a width guard to the fused MLA RoPE kernels so an undersized cos/sin cache fails with a clear error instead of silently reading out of bounds.

The fused kernels (ApplyMLARotaryEmbQ / ApplyMLARotaryEmbKV) read emb_dim cos/sin values per token and assume the cache is emb_dim wide. When the rotary cache is built with rotary_percent < 1, the base RotaryEmbedding shrinks it to int(emb_dim * rotary_percent), which is narrower than emb_dim, so the kernel reads past the buffer, picks up uninitialized GPU memory, and produces NaN gradients. This is reachable on main: MultiLatentAttention forwards rotary_percent into the base RotaryEmbedding for rope_type="rope" (multi_latent_attention.py:184), and the fused branch (:870) calls the kernel regardless of rope type, with no width check today (fused_mla_yarn_rope_apply.py:251-255). The guard asserts cos.shape[-1] == emb_dim and sin.shape[-1] == emb_dim in both forward paths, turning a silent ~1e14 out-of-bounds read into an immediate, readable failure.
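For illustration, the check described above might look roughly like the sketch below. This is not the code from the PR diff: the helper name check_cos_sin_width, the error message, and the stand-in objects (anything exposing a .shape tuple, used here so the sketch runs without a GPU or PyTorch) are all assumptions.

```python
from types import SimpleNamespace


def check_cos_sin_width(cos, sin, emb_dim):
    """Hypothetical guard: the last dim of cos and sin must equal emb_dim."""
    for name, cache in (("cos", cos), ("sin", sin)):
        width = cache.shape[-1]
        if width != emb_dim:
            raise AssertionError(
                f"{name} cache width {width} != emb_dim {emb_dim}; "
                "was the rotary cache built with rotary_percent < 1?"
            )


emb_dim = 64

# Stand-ins for tensors: only the .shape attribute matters to the guard.
full = SimpleNamespace(shape=(8, 1, 1, emb_dim))
narrow = SimpleNamespace(shape=(8, 1, 1, int(emb_dim * 0.5)))

check_cos_sin_width(full, full, emb_dim)  # full-width cache passes

try:
    check_cos_sin_width(narrow, narrow, emb_dim)
except AssertionError as err:
    guard_message = str(err)  # half-width cache is rejected before any read
```

Checking the shape before launching the kernel keeps the cost to a couple of integer comparisons per forward call, and the failure names the offending cache instead of surfacing later as NaN gradients.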

Issue tracking

Linked issue: Related to #5317

This guard makes the failure mode in #5317 loud instead of silent; the root-cause fix for that issue is rotary_percent = 1.0 in the Megatron-Bridge DSv4 config mapping. Independent of #5412, which fixes a separate in-place aliasing bug.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

The fused MLA RoPE kernels read emb_dim cos/sin values per token and assume the
cache is emb_dim wide. When the rotary cache is built with rotary_percent < 1, the
base RotaryEmbedding shrinks it to int(emb_dim * rotary_percent), which is narrower
than emb_dim. The kernel then reads past the buffer, picks up uninitialized GPU
memory, and produces NaN gradients.
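The size of the overrun can be estimated with simple arithmetic. The sketch below assumes a contiguous cache and a reader that consumes exactly emb_dim values per token, as described above; the function name cache_shortfall and the example sizes are hypothetical.

```python
def cache_shortfall(num_tokens, emb_dim, rotary_percent):
    """Number of elements read beyond the end of an undersized cos/sin cache.

    Assumes the reader consumes emb_dim values per token while the cache
    only holds int(emb_dim * rotary_percent) values per token.
    """
    actual_width = int(emb_dim * rotary_percent)
    elements_available = num_tokens * actual_width
    elements_read = num_tokens * emb_dim
    return elements_read - elements_available


# With half the rotary dimensions cached, half of all reads fall out of bounds.
half_cache = cache_shortfall(num_tokens=4, emb_dim=64, rotary_percent=0.5)

# A full-width cache leaves nothing to overrun.
full_cache = cache_shortfall(num_tokens=4, emb_dim=64, rotary_percent=1.0)
```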

Add an explicit check in both ApplyMLARotaryEmbQ.forward and ApplyMLARotaryEmbKV.forward
that the cos/sin last dimension equals emb_dim, turning the silent out-of-bounds read
into a clear error. Add CPU-only regression tests for the query and key/value kernels.

Related to NVIDIA#5317.

Signed-off-by: ShauryaaSharma <shauryasofficial27@gmail.com>
@ShauryaaSharma ShauryaaSharma requested review from a team as code owners June 25, 2026 08:24
@copy-pr-bot

copy-pr-bot Bot commented Jun 25, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@svcnvidia-nemo-ci svcnvidia-nemo-ci marked this pull request as draft June 25, 2026 08:24
@github-actions

Contributor

This PR has been automatically converted to draft because all PRs must start as drafts.

When you are ready for review, click Ready for Review to begin the review process. This will:

  1. Add the oncall reviewer (optional reviewer)
  2. Add required review teams based on your changes

See the contribution guide for more details.


Labels

  • community-request
  • Final Review: PR is in the "final review" stage
  • waiting-on-maintainers: Waiting on maintainers to respond


3 participants