Skip to content

onnx: add com.microsoft GroupQueryAttention handler#2292

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/onnx-gqa
Open

onnx: add com.microsoft GroupQueryAttention handler#2292
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feature/onnx-gqa

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

GroupQueryAttention (prefill): causal grouped-query attention over unpacked query/key/value, lowered onto tract's Sdpa (which already handles the kv_num_heads < num_heads head sharing), with present_key/present_value outputs. Decode-step KV cache (past_key/past_value), internal rotary (do_rotary), local-window attention and softcap are rejected with clear errors.

Validated bit-close vs onnxruntime: attention output ≤3.6e-7 and present_key/present_value bit-exact, across head_size 8/16/64, several num_heads/kv_num_heads ratios (incl. multi-query kv_num_heads=1) and batch>1. Full ONNX node suite: 4428 passed / 0 failed (op is purely additive). clippy + fmt clean.

Note for reviewers: ORT's GroupQueryAttention prefill is standard causal grouped-query attention — its seqlens_k input is the 0-indexed position of the last token (total_sequence_length - 1), not the token count.

Part of com.microsoft contrib-op coverage for ORT-exported LLMs (sibling to #2287#2291).

🤖 Generated with Claude Code

Prefill-only GroupQueryAttention lowered onto tract Sdpa: reshapes Q/K/V to
4D, applies an explicit lower-triangular causal mask, and returns
present_key/present_value (the reshaped K/V). Sdpa handles the grouped-query
head sharing (kv_num_heads < num_heads). Decode-step KV cache, internal
rotary (do_rotary), local-window attention and softcap are rejected with
clear errors.

Validated against onnxruntime across head_size 8/16/64, several
num_heads/kv_num_heads ratios (incl. multi-query kv=1) and batch>1: attention
output matches to <=3.6e-7 and present_key/present_value are bit-exact.

ORT's GroupQueryAttention prefill is standard causal grouped-query attention;
the seqlens_k input is the 0-indexed position of the last token
(total_sequence_length - 1), not the token count.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant