Skip to content

Fused SDPA full-attention for head_dim=192/256 (revival of #3293)#3660

Open
hojin12312 wants to merge 5 commits into
ml-explore:mainfrom
hojin12312:sdpa-256-revival
Open

Fused SDPA full-attention for head_dim=192/256 (revival of #3293)#3660
hojin12312 wants to merge 5 commits into
ml-explore:mainfrom
hojin12312:sdpa-256-revival

Conversation

@hojin12312

@hojin12312 hojin12312 commented Jun 11, 2026

Copy link
Copy Markdown

Proposed changes

Revival of #3293: add head_dim=192/256 instantiations to the fused
steel_attention full-attention kernel and route prefill SDPA to them when
kL > 16384. Below that dispatch threshold the unfused path remains selected
for lower latency while its score-shaped transient is still considered
acceptable.

Decode's sdpa_vector path already supported head_dim=256. This PR also adds
the missing head_dim=192 vector and vector-aggregation instantiations, while
closing the full-attention/prefill dispatch gap for both 192 and 256.

The three original commits by @Thump604 are preserved with their authorship.
Additional commits port the original work to the code structure at submission
time and add focused regression coverage:

  • add the missing bd=192 steel instantiation already targeted by routing
  • exclude head_dim >= 192 from the NAX branch, whose kernel family has no
    192/256 instantiations
  • exercise the 192/256 full-attention route above the dispatch threshold and
    the newly instantiated head_dim=192 vector route

Scope correction

The original version of this PR description incorrectly used Gemma 4 26B as
end-to-end evidence for this change. That attribution was wrong.

Gemma 4 26B/31B use global_head_dim=512 in the full-attention layers that
see the long context. Their head_dim=256 layers are sliding-window layers with
bounded kL, so the new 192/256 large-kL route in this PR is not selected
for their long-context attention. A stock mlx-lm A/B subsequently showed
identical Gemma peak memory and prompt throughput between vanilla MLX and this
branch.

The previously reported Gemma 156K -> 184K admission increase came from a
serving-side estimator change in oMLX, not from this kernel. The fallback score
tensor was also previously described as fp32; for the measured bf16 Gemma
workload the materialized array is bf16, with fp32 softmax accumulation.

Therefore this PR should be evaluated for models whose actual long-context
full-attention head dimension is 192 or 256. Gemma 4 26B/31B are not evidence
for or against that kernel.

Detailed correction and the stock mlx-lm measurements:
#3658 (comment)

Kernel microbenchmarks

Synthetic GQA shape, not labeled as a Gemma model shape:

n_q=16, n_kv=8, head_dim=256, fp16, M4 Max 36 GB. Peak memory includes
inputs.

qL kL inputs fused peak unfused peak max abs err
1024 32768 264M 272M 1,824M 0.0
2048 65536 528M 544M 5,696M 0.0
2048 131072 1,040M 1,056M 11,328M 0.0

For this shape, fused overhead above the input allocation stays approximately
8-16 MB, while the unfused score-shaped transient grows linearly with kL.
The fused path is slower per kernel above the current dispatch threshold:
approximately 0.74-0.79x the unfused throughput on this machine. The threshold
preserves the lower-latency unfused route for shorter contexts. Error values
in the table are shown at the benchmark's printed precision.

Stock mlx-lm validation

After a maintainer request, the comparison was rerun through
python -m mlx_lm.benchmark, without oMLX or a serving scheduler.

Gemma 4 26B

Gemma 4 26B UD-MLX 3-bit and 4-bit were tested from 32K through 131K. Peak
memory was identical between the two builds to the benchmark's 0.001 GB
precision, and prompt throughput matched within 0.1%. Static dispatch analysis
confirms both builds use the same unfused fallback for the long-context
head_dim=512 global layers.

Qwen3.6 35B

An independent stock mlx-lm smoke by @Thump604 showed the expected
memory/speed tradeoff on a genuine head_dim=256 workload:

context main prompt_tps PR prompt_tps main peak GB PR peak GB
32K 1005.743 871.489 41.316 39.882
64K 802.929 624.513 44.201 40.068

The 131K PR run aborted with an interactivity-impacting command-buffer error,
so that row is unresolved and is not presented as a valid A/B result.

These stock benchmark arms were not a strict single-commit comparison because
the tested builds were separated by other main-branch commits. The static
dispatch change and focused route tests establish what this PR changes; the
stock benchmarks are supporting workload evidence.

Independent validation:
#3658 (comment)

Validation

  • focused route tests added in python/tests/test_fast_sdpa.py:
    • head_dim 192/256 full attention with qL=9, kL=16385
    • head_dim 192 vector attention
  • current python/tests/test_fast_sdpa.py: 18 tests run, 1 skipped
  • python/tests/test_fast.py: 23 passed before the focused tests were added;
    the kernel source is unchanged by the test-only follow-up commit
  • head_dim=256 fp16 fused/unfused microbenchmark errors were reported as
    0.0000 at the benchmark's printed precision
  • pre-commit hooks passed on all files changed by this PR
  • original fix: add head_dim=256 to fused SDPA full attention kernel #3293 work was also tested on M2 Ultra and M3 Ultra
  • stock mlx-lm correction run: Gemma 4 26B 3-bit/4-bit, 32K-131K

Tradeoff

This is intentionally not a blanket performance optimization. For the tested
head_dim=256 prefill shapes, the fused kernel trades prompt throughput for a
bounded-memory attention path. The dispatch threshold keeps the unfused path
where it is faster and reserves the fused path for long-kL workloads where
the growing score transient is the limiting factor.

Credit: the kernel work and original routing are @Thump604's (#3293). This PR
preserves those commits and authorship, ports the work to the submission-time
code structure, and adds the small completeness fixes and regression tests
described above.

Checklist

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
    • pre-commit hooks were run successfully on every file changed by this PR
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (not applicable; no public API or user documentation changed)

Thump604 and others added 4 commits June 11, 2026 20:10
sdpa_full_supported_head_dim only included {64, 80, 128}. Models with
head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention
path which materializes the full score matrix as a single matmul.
At 32K+ context this creates 8+ GB single allocations that crash
Metal's buffer allocator.

Add head_dim=256 to the dispatch gate and instantiate steel_attention
kernel with bd=256. The Metal kernel template handles arbitrary BD
via template parameter — no kernel code changes needed.

Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B.
The fused steel_attention kernel with bd=256 is ~30% slower than the
unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused
by default and only use the fused kernel when key_sequence_length > 16384,
where unfused would exceed Metal buffer limits.

Benchmark (M2 Ultra, H=64, qL=2048, float16):
  kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing)
  kL=32768: fused only (unfused crashes)

Vector path (qL<=8, decode) is unaffected — already supports head_dim=256.
Same pattern as head_dim=256: unfused by default for short sequences,
fused when kL > 16384 (where unfused would exceed Metal buffer limits).

Adds vector kernel instantiations for decode path.

Fixes ml-explore#3312.
- Add missing bd=192 steel_attention instantiation (use_fallback routes
  head_dim=192 to the fused full kernel, but only bd=256 was instantiated)
- Exclude head_dim >= 192 from the NAX dispatch branch: the NAX kernel
  family only instantiates bd=64/128, so those shapes go to the legacy
  steel kernel which has the instantiations

Co-authored-by: Thump604 <thump@cosmiccooler.org>
@hojin12312

Copy link
Copy Markdown
Author

I updated the PR description to correct the Gemma attribution and make the validation scope explicit.

The earlier description incorrectly treated Gemma 4's sliding head_dim=256 shape as its long-context full-attention shape. Its long-context global layers use head_dim=512, so this PR's 192/256 full-attention route is not selected there. I removed the Gemma serving results as evidence for this PR, corrected the bf16/fp32 description, and linked the stock mlx-lm A/B and detailed correction in #3658.

I also added focused regression tests in c9f8a15 for the head_dim=192/256 full-attention route above kL=16384 and the newly instantiated head_dim=192 vector route. Both focused tests and pre-commit on the changed files pass.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants