From 84812c1cabdde26f0e13b8aef88d1387ec62d03a Mon Sep 17 00:00:00 2001 From: Thump604 Date: Sat, 21 Mar 2026 23:27:48 -0500 Subject: [PATCH 1/5] fix: add head_dim=256 to fused SDPA full attention kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sdpa_full_supported_head_dim only included {64, 80, 128}. Models with head_dim=256 (Qwen3.5 family) fell back to the unfused naive attention path which materializes the full score matrix as a single matmul. At 32K+ context this creates 8+ GB single allocations that crash Metal's buffer allocator. Add head_dim=256 to the dispatch gate and instantiate steel_attention kernel with bd=256. The Metal kernel template handles arbitrary BD via template parameter — no kernel code changes needed. Verified: 32K, 64K, 128K context on M2 Ultra with Qwen3.5-122B-A10B. --- .../metal/kernels/steel/attn/kernels/steel_attention.metal | 1 + mlx/backend/metal/scaled_dot_product_attention.cpp | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0ff9d91b00..4a67826951 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -12,6 +12,7 @@ attention, dtype, bq, bk, bd, wm, wn, mtype, float) #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d387a5c08c..cf7a8544da 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -623,7 +623,8 @@ bool ScaledDotProductAttention::use_fallback( (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || + query_head_dim == 256); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); From 780f0702616baed83e20d3b2e93cea84b162a04d Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 24 Mar 2026 17:26:29 -0500 Subject: [PATCH 2/5] perf: route head_dim=256 to unfused SDPA for short sequences MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fused steel_attention kernel with bd=256 is ~30% slower than the unfused (matmul + softmax + matmul) path. Route head_dim=256 to unfused by default and only use the fused kernel when key_sequence_length > 16384, where unfused would exceed Metal buffer limits. Benchmark (M2 Ultra, H=64, qL=2048, float16): kL=16384: unfused 124ms vs fused 249ms (2.0x faster with routing) kL=32768: fused only (unfused crashes) Vector path (qL<=8, decode) is unaffected — already supports head_dim=256. --- mlx/backend/metal/scaled_dot_product_attention.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index cf7a8544da..27272f2795 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -622,9 +622,14 @@ bool ScaledDotProductAttention::use_fallback( query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || query_head_dim == 256); + // For head_dim=256, the fused full-attention kernel is ~30% slower than + // unfused. Only route to fused when kL is large enough that unfused would + // exceed Metal buffer limits (the fused kernel tiles K/V so it scales). + const bool sdpa_full_256_ok = + query_head_dim == 256 && key_sequence_length > 16384; const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || - query_head_dim == 256); + sdpa_full_256_ok); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); From f4f6c23dc926b52be9e544dde21f9074a0396c77 Mon Sep 17 00:00:00 2001 From: Thump604 Date: Tue, 24 Mar 2026 19:55:48 -0500 Subject: [PATCH 3/5] feat: add head_dim=192 to fused SDPA kernel support Same pattern as head_dim=256: unfused by default for short sequences, fused when kL > 16384 (where unfused would exceed Metal buffer limits). Adds vector kernel instantiations for decode path. Fixes #3312. --- .../kernels/scaled_dot_product_attention.metal | 2 ++ .../metal/scaled_dot_product_attention.cpp | 16 +++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..ca2fd358ef 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -32,10 +32,12 @@ using namespace metal; instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 128, 128) \ + instantiate_sdpa_vector(type, 192, 192) \ instantiate_sdpa_vector(type, 256, 256) \ instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 128) \ + instantiate_sdpa_vector_aggregation(type, 192) \ instantiate_sdpa_vector_aggregation(type, 256) instantiate_sdpa_vector_heads(float) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 27272f2795..0b1f8d1f70 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -621,15 +621,17 @@ bool ScaledDotProductAttention::use_fallback( const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || - query_head_dim == 256); - // For head_dim=256, the fused full-attention kernel is ~30% slower than - // unfused. Only route to fused when kL is large enough that unfused would - // exceed Metal buffer limits (the fused kernel tiles K/V so it scales). - const bool sdpa_full_256_ok = - query_head_dim == 256 && key_sequence_length > 16384; + query_head_dim == 192 || query_head_dim == 256); + // For head_dim >= 192, the fused full-attention kernel is slower than + // unfused for short sequences. Only route to fused when kL is large enough + // that the unfused path would exceed Metal buffer limits (the fused kernel + // tiles K/V so it scales to arbitrary sequence lengths). + const bool sdpa_full_large_hd_ok = + (query_head_dim == 192 || query_head_dim == 256) && + key_sequence_length > 16384; const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || - sdpa_full_256_ok); + sdpa_full_large_hd_ok); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); From 0aa50825acca3be22c8f02c5581d17aa442ba742 Mon Sep 17 00:00:00 2001 From: hojin12312 Date: Thu, 11 Jun 2026 20:13:51 +0900 Subject: [PATCH 4/5] fix: complete head_dim=192/256 fused routing for current main - Add missing bd=192 steel_attention instantiation (use_fallback routes head_dim=192 to the fused full kernel, but only bd=256 was instantiated) - Exclude head_dim >= 192 from the NAX dispatch branch: the NAX kernel family only instantiates bd=64/128, so those shapes go to the legacy steel kernel which has the instantiations Co-authored-by: Thump604 --- .../metal/kernels/steel/attn/kernels/steel_attention.metal | 1 + mlx/backend/metal/scaled_dot_product_attention.cpp | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 4a67826951..d6e7770880 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -13,6 +13,7 @@ #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 0b1f8d1f70..37c0a204c6 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -174,7 +174,10 @@ void sdpa_full_self_attention_metal( bool do_causal_, const std::optional& mask, const std::optional& sinks) { - if (metal::is_nax_available() && q.shape(3) != 80 && + // head_dim >= 192 only reaches the fused path for long sequences (see + // use_fallback); the NAX kernel family has no bd=192/256 instantiations, + // so route those shapes to the legacy steel kernel which does. + if (metal::is_nax_available() && q.shape(3) != 80 && q.shape(3) < 192 && (env::enable_tf32() || q.dtype() != float32)) { return sdpa_full_self_attention_nax( /* const Stream& s = */ s, From c9f8a154a95ee46dbdb43b85335c281376bc758e Mon Sep 17 00:00:00 2001 From: hojin12312 Date: Sun, 14 Jun 2026 08:58:52 +0900 Subject: [PATCH 5/5] test: cover large head-dim SDPA routes --- python/tests/test_fast_sdpa.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7bd867084e..299141ac7a 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -117,6 +117,27 @@ def mlx_primitives_sdpa(q, k, v, scale, mask=None): class TestFastSDPA(mlx_tests.MLXTestCase): + def test_sdpa_large_head_dim_full_attention(self): + qL = 9 + kL = 16385 + + for D in [192, 256]: + with self.subTest(head_dim=D): + q, k, v, scale, _ = prepare_inputs( + 1, qL, kL, D, 1, 1, None, False, mx.float16 + ) + ref = mlx_primitives_sdpa(q, k, v, scale) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3)) + + def test_sdpa_vector_head_dim_192(self): + q, k, v, scale, _ = prepare_inputs( + 1, 1, 257, 192, 1, 1, None, False, mx.float16 + ) + ref = mlx_primitives_sdpa(q, k, v, scale) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3)) + def test_sdpa_vector_kv_transposed_head_seq(self): D = 64 Nq = 4