diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..ca2fd358ef 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -32,10 +32,12 @@ using namespace metal; instantiate_sdpa_vector(type, 64, 64) \ instantiate_sdpa_vector(type, 96, 96) \ instantiate_sdpa_vector(type, 128, 128) \ + instantiate_sdpa_vector(type, 192, 192) \ instantiate_sdpa_vector(type, 256, 256) \ instantiate_sdpa_vector_aggregation(type, 64) \ instantiate_sdpa_vector_aggregation(type, 96) \ instantiate_sdpa_vector_aggregation(type, 128) \ + instantiate_sdpa_vector_aggregation(type, 192) \ instantiate_sdpa_vector_aggregation(type, 256) instantiate_sdpa_vector_heads(float) diff --git a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal index 0ff9d91b00..d6e7770880 100644 --- a/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +++ b/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal @@ -12,6 +12,8 @@ attention, dtype, bq, bk, bd, wm, wn, mtype, float) #define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \ + instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index d387a5c08c..37c0a204c6 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -174,7 +174,10 @@ void sdpa_full_self_attention_metal( bool do_causal_, const std::optional& mask, const std::optional& sinks) { - if (metal::is_nax_available() && q.shape(3) != 80 && + // head_dim >= 192 only reaches the fused path for long sequences (see + // use_fallback); the NAX kernel family has no bd=192/256 instantiations, + // so route those shapes to the legacy steel kernel which does. + if (metal::is_nax_available() && q.shape(3) != 80 && q.shape(3) < 192 && (env::enable_tf32() || q.dtype() != float32)) { return sdpa_full_self_attention_nax( /* const Stream& s = */ s, @@ -621,9 +624,17 @@ bool ScaledDotProductAttention::use_fallback( const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 || - query_head_dim == 256); + query_head_dim == 192 || query_head_dim == 256); + // For head_dim >= 192, the fused full-attention kernel is slower than + // unfused for short sequences. Only route to fused when kL is large enough + // that the unfused path would exceed Metal buffer limits (the fused kernel + // tiles K/V so it scales to arbitrary sequence lengths). + const bool sdpa_full_large_hd_ok = + (query_head_dim == 192 || query_head_dim == 256) && + key_sequence_length > 16384; const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim && - (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128); + (query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 || + sdpa_full_large_hd_ok); const bool sdpa_full_supported_mask = !has_mask || has_arr_mask || (query_sequence_length <= key_sequence_length && do_causal); diff --git a/python/tests/test_fast_sdpa.py b/python/tests/test_fast_sdpa.py index 7bd867084e..299141ac7a 100644 --- a/python/tests/test_fast_sdpa.py +++ b/python/tests/test_fast_sdpa.py @@ -117,6 +117,27 @@ def mlx_primitives_sdpa(q, k, v, scale, mask=None): class TestFastSDPA(mlx_tests.MLXTestCase): + def test_sdpa_large_head_dim_full_attention(self): + qL = 9 + kL = 16385 + + for D in [192, 256]: + with self.subTest(head_dim=D): + q, k, v, scale, _ = prepare_inputs( + 1, qL, kL, D, 1, 1, None, False, mx.float16 + ) + ref = mlx_primitives_sdpa(q, k, v, scale) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3)) + + def test_sdpa_vector_head_dim_192(self): + q, k, v, scale, _ = prepare_inputs( + 1, 1, 257, 192, 1, 1, None, False, mx.float16 + ) + ref = mlx_primitives_sdpa(q, k, v, scale) + out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale) + self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3)) + def test_sdpa_vector_kv_transposed_head_seq(self): D = 64 Nq = 4