Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ using namespace metal;
instantiate_sdpa_vector(type, 64, 64) \
instantiate_sdpa_vector(type, 96, 96) \
instantiate_sdpa_vector(type, 128, 128) \
instantiate_sdpa_vector(type, 192, 192) \
instantiate_sdpa_vector(type, 256, 256) \
instantiate_sdpa_vector_aggregation(type, 64) \
instantiate_sdpa_vector_aggregation(type, 96) \
instantiate_sdpa_vector_aggregation(type, 128) \
instantiate_sdpa_vector_aggregation(type, 192) \
instantiate_sdpa_vector_aggregation(type, 256)

instantiate_sdpa_vector_heads(float)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
attention, dtype, bq, bk, bd, wm, wn, mtype, float)

#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 256, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 192, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \
instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype)
Expand Down
17 changes: 14 additions & 3 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ void sdpa_full_self_attention_metal(
bool do_causal_,
const std::optional<array>& mask,
const std::optional<array>& sinks) {
if (metal::is_nax_available() && q.shape(3) != 80 &&
// head_dim >= 192 only reaches the fused path for long sequences (see
// use_fallback); the NAX kernel family has no bd=192/256 instantiations,
// so route those shapes to the legacy steel kernel which does.
if (metal::is_nax_available() && q.shape(3) != 80 && q.shape(3) < 192 &&
(env::enable_tf32() || q.dtype() != float32)) {
return sdpa_full_self_attention_nax(
/* const Stream& s = */ s,
Expand Down Expand Up @@ -621,9 +624,17 @@ bool ScaledDotProductAttention::use_fallback(
const bool sdpa_vector_supported_head_dim =
query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128 ||
query_head_dim == 256);
query_head_dim == 192 || query_head_dim == 256);
// For head_dim >= 192, the fused full-attention kernel is slower than
// unfused for short sequences. Only route to fused when kL is large enough
// that the unfused path would exceed Metal buffer limits (the fused kernel
// tiles K/V so it scales to arbitrary sequence lengths).
const bool sdpa_full_large_hd_ok =
(query_head_dim == 192 || query_head_dim == 256) &&
key_sequence_length > 16384;
const bool sdpa_full_supported_head_dim = query_head_dim == value_head_dim &&
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128);
(query_head_dim == 64 || query_head_dim == 80 || query_head_dim == 128 ||
sdpa_full_large_hd_ok);

const bool sdpa_full_supported_mask = !has_mask || has_arr_mask ||
(query_sequence_length <= key_sequence_length && do_causal);
Expand Down
21 changes: 21 additions & 0 deletions python/tests/test_fast_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,27 @@ def mlx_primitives_sdpa(q, k, v, scale, mask=None):


class TestFastSDPA(mlx_tests.MLXTestCase):
def test_sdpa_large_head_dim_full_attention(self):
qL = 9
kL = 16385

for D in [192, 256]:
with self.subTest(head_dim=D):
q, k, v, scale, _ = prepare_inputs(
1, qL, kL, D, 1, 1, None, False, mx.float16
)
ref = mlx_primitives_sdpa(q, k, v, scale)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3))

def test_sdpa_vector_head_dim_192(self):
q, k, v, scale, _ = prepare_inputs(
1, 1, 257, 192, 1, 1, None, False, mx.float16
)
ref = mlx_primitives_sdpa(q, k, v, scale)
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)
self.assertTrue(mx.allclose(ref, out, atol=1e-3, rtol=1e-3))

def test_sdpa_vector_kv_transposed_head_seq(self):
D = 64
Nq = 4
Expand Down