Fused SDPA full-attention for head_dim=192/256 (revival of #3293) #3660
Open
hojin12312 wants to merge 5 commits into
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path, which materializes the full score matrix as a single matmul.
At 32K+ context this creates single allocations of 8+ GB that crash
Metal's buffer allocator.
Add head_dim=256 to the dispatch gate and instantiate the steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via a template parameter, so no kernel code changes are needed.
Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
The fused steel_attention kernel with bd=256 is ~30% slower than the unfused
(matmul + softmax + matmul) path. Route head_dim=256 to unfused by default and
only use the fused kernel when key_sequence_length > 16384, where unfused would
exceed Metal buffer limits.
Benchmark (M2 Ultra, H=64, qL=2048, float16):
- kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
- kL=32768: fused only (unfused crashes)
Vector path (qL<=8, decode) is unaffected: it already supports head_dim=256.
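The size of the score matrix that the unfused path materializes can be estimated with simple arithmetic. The sketch below is only an illustration: it assumes batch size 1 and a score tensor of shape (H, qL, kL) stored in float16, using the benchmark shape quoted above. The helper name `score_matrix_bytes` is hypothetical and not part of MLX.

```python
GIB = 1024 ** 3


def score_matrix_bytes(num_heads, q_len, k_len, bytes_per_elem=2, batch=1):
    """Bytes needed to hold a (batch, heads, q_len, k_len) score tensor.

    bytes_per_elem=2 models float16/bfloat16 storage (an assumption here).
    """
    return batch * num_heads * q_len * k_len * bytes_per_elem


# Benchmark shape from the commit message: H=64, qL=2048, float16.
for k_len in (16384, 32768):
    size = score_matrix_bytes(num_heads=64, q_len=2048, k_len=k_len)
    print(f"kL={k_len}: {size / GIB:.1f} GiB")
```

Under these assumptions the transient doubles from 4 GiB at kL=16384 to 8 GiB at kL=32768, which is consistent with the 8+ GB single allocation described in the first commit message.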
Same pattern as head_dim=256: unfused by default for short sequences, fused when kL > 16384 (where unfused would exceed Metal buffer limits). Adds vector kernel instantiations for decode path. Fixes ml-explore#3312.
- Add missing bd=192 steel_attention instantiation (use_fallback routes
  head_dim=192 to the fused full kernel, but only bd=256 was instantiated)
- Exclude head_dim >= 192 from the NAX dispatch branch: the NAX kernel family
  only instantiates bd=64/128, so those shapes go to the legacy steel kernel,
  which has the instantiations
Co-authored-by: Thump604 <thump@cosmiccooler.org>
Author
I updated the PR description to correct the Gemma attribution and make the validation scope explicit.
The earlier description incorrectly treated Gemma 4's sliding head_dim=256 shape as its long-context full-attention shape. Its long-context global layers use head_dim=512, so this PR's 192/256 full-attention route is not selected there. I removed the Gemma serving results as evidence for this PR, corrected the bf16/fp32 description, and linked the stock mlx-lm A/B and detailed correction in #3658.
I also added focused regression tests in c9f8a15 for the head_dim=192/256 full-attention route above kL=16384 and the newly instantiated head_dim=192 vector route. Both focused tests and pre-commit on the changed files pass.
Proposed changes
Revival of #3293: add head_dim=192/256 instantiations to the fused
steel_attention full-attention kernel and route prefill SDPA to them when
kL > 16384. Below that dispatch threshold the unfused path remains selected
for lower latency while its score-shaped transient is still considered
acceptable.
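The routing rule described above can be summarized as a small predicate. This is a simplified sketch, not the MLX dispatch code: the function name and constants are hypothetical, and it ignores everything the description does not mention (masks, dtypes, the NAX branch, and the decode vector path for qL<=8).

```python
# Head dims the commit message lists as already handled by the fused kernel.
FUSED_ALWAYS = {64, 80, 128}
# Head dims this PR adds, used by the fused kernel only for long kL.
FUSED_LONG_ONLY = {192, 256}
# Dispatch threshold stated in the PR: fused is selected only when kL > 16384.
LONG_KL_THRESHOLD = 16384


def full_attention_route(head_dim: int, k_len: int) -> str:
    """Return "fused" or "unfused" for a prefill full-attention call (sketch)."""
    if head_dim in FUSED_ALWAYS:
        return "fused"
    if head_dim in FUSED_LONG_ONLY and k_len > LONG_KL_THRESHOLD:
        return "fused"
    # Everything else, including head_dim=512, stays on the unfused fallback.
    return "unfused"


for head_dim, k_len in [(256, 16384), (256, 16385), (192, 32768), (512, 131072)]:
    print(head_dim, k_len, full_attention_route(head_dim, k_len))
```

Note that the threshold comparison is strict, so kL=16384 itself still takes the unfused path, and a head_dim=512 layer is never routed to the new instantiations.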
Decode's sdpa_vector path already supported head_dim=256. This PR also adds
the missing head_dim=192 vector and vector-aggregation instantiations, while
closing the full-attention/prefill dispatch gap for both 192 and 256.
The three original commits by @Thump604 are preserved with their authorship.
Additional commits port the original work to the code structure at submission
time and add focused regression coverage:
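- add the missing bd=192 steel instantiation already targeted by routing
- exclude head_dim >= 192 from the NAX dispatch branch, which lacks the
  192/256 instantiations
- add focused regression tests for the head_dim=192/256 full-attention route
  above kL=16384 and the newly instantiated head_dim=192 vector route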
Scope correction
The original version of this PR description incorrectly used Gemma 4 26B as
end-to-end evidence for this change. That attribution was wrong.
Gemma 4 26B/31B use global_head_dim=512 in the full-attention layers that
see the long context. Their head_dim=256 layers are sliding-window layers with
bounded kL, so the new 192/256 large-kL route in this PR is not selected
for their long-context attention. A stock mlx-lm A/B subsequently showed
identical Gemma peak memory and prompt throughput between vanilla MLX and this
branch.
The previously reported Gemma 156K -> 184K admission increase came from a
serving-side estimator change in oMLX, not from this kernel. The fallback score
tensor was also previously described as fp32; for the measured bf16 Gemma
workload the materialized array is bf16, with fp32 softmax accumulation.
Therefore this PR should be evaluated for models whose actual long-context
full-attention head dimension is 192 or 256. Gemma 4 26B/31B are not evidence
for or against that kernel.
Detailed correction and the stock mlx-lm measurements:
#3658 (comment)
Kernel microbenchmarks
Synthetic GQA shape, not labeled as a Gemma model shape:
n_q=16, n_kv=8, head_dim=256, fp16, M4 Max 36 GB. Peak memory includes
inputs.
For this shape, fused overhead above the input allocation stays approximately
8-16 MB, while the unfused score-shaped transient grows linearly with kL.
The fused path is slower per kernel above the current dispatch threshold:
approximately 0.74-0.79x the unfused throughput on this machine. The threshold
preserves the lower-latency unfused route for shorter contexts. Error values
in the table are shown at the benchmark's printed precision.
Stock mlx-lm validation
After a maintainer request, the comparison was rerun through
python -m mlx_lm.benchmark, without oMLX or a serving scheduler.
Gemma 4 26B
Gemma 4 26B UD-MLX 3-bit and 4-bit were tested from 32K through 131K. Peak
memory was identical between the two builds to the benchmark's 0.001 GB
precision, and prompt throughput matched within 0.1%. Static dispatch analysis
confirms both builds use the same unfused fallback for the long-context
head_dim=512 global layers.
Qwen3.6 35B
An independent stock mlx-lm smoke by @Thump604 showed the expected
memory/speed tradeoff on a genuine head_dim=256 workload:
The 131K PR run aborted with an interactivity-impacting command-buffer error,
so that row is unresolved and is not presented as a valid A/B result.
These stock benchmark arms were not a strict single-commit comparison because
the tested builds were separated by other main-branch commits. The static
dispatch change and focused route tests establish what this PR changes; the
stock benchmarks are supporting workload evidence.
Independent validation:
#3658 (comment)
Validation
- python/tests/test_fast_sdpa.py: qL=9, kL=16385
- python/tests/test_fast_sdpa.py: 18 tests run, 1 skipped
- python/tests/test_fast.py: 23 passed before the focused tests were added;
  the kernel source is unchanged by the test-only follow-up commit
- 0.0000 at the benchmark's printed precision
- mlx-lm correction run: Gemma 4 26B 3-bit/4-bit, 32K-131K
Tradeoff
This is intentionally not a blanket performance optimization. For the tested
head_dim=256 prefill shapes, the fused kernel trades prompt throughput for a
bounded-memory attention path. The dispatch threshold keeps the unfused path
where it is faster and reserves the fused path for long-kL workloads where
the growing score transient is the limiting factor.
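To illustrate why a fused path can keep memory bounded while an unfused path cannot, the sketch below compares two ways of computing attention for a single query in pure Python. It is a generic streaming-softmax illustration under stated assumptions, not the steel_attention kernel: the function names, the block size, and the random test data are all hypothetical. The first function holds one score per key at once; the second only ever holds one block of scores plus a running maximum, a running sum, and an output accumulator.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def attend_unfused(q, keys, values, scale):
    """Reference: materialize every score, then softmax, then weight the values."""
    scores = [scale * dot(q, k) for k in keys]  # grows linearly with kL
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(len(values[0]))]


def attend_streaming(q, keys, values, scale, block=4):
    """Process keys block by block, rescaling running statistics as the max changes."""
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block):
        scores = [scale * dot(q, k) for k in keys[start:start + block]]
        new_max = max(running_max, max(scores))
        # Rescale earlier contributions; exp(-inf) is 0.0 on the first block.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, values[start:start + block]):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


rng = random.Random(0)
head_dim, k_len = 8, 37  # tiny illustrative sizes, not a model shape
q = [rng.uniform(-1, 1) for _ in range(head_dim)]
keys = [[rng.uniform(-1, 1) for _ in range(head_dim)] for _ in range(k_len)]
values = [[rng.uniform(-1, 1) for _ in range(head_dim)] for _ in range(k_len)]
scale = 1.0 / math.sqrt(head_dim)

ref = attend_unfused(q, keys, values, scale)
out = attend_streaming(q, keys, values, scale)
max_err = max(abs(a - b) for a, b in zip(ref, out))
print("outputs agree:", max_err < 1e-9)
```

The streaming variant does extra rescaling work per block, which is one plausible reason a bounded-memory path can be slower per call even though it produces the same result.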
Credit: the kernel work and original routing are @Thump604's (#3293). This PR
preserves those commits and authorship, ports the work to the submission-time
code structure, and adds the small completeness fixes and regression tests
described above.
Checklist
pre-commit run --all-files to format my code / installed pre-commit prior to committing changes