From 1abbd5463a7d15771c89e491e80ab66347aef80e Mon Sep 17 00:00:00 2001 From: Goni Zahavy Date: Tue, 16 Jun 2026 11:59:33 +0300 Subject: [PATCH] [Vulkan] Fix sum_rows block size for n_cols < 32 --- mlx/backend/vulkan/kernels.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 1a74ecba9f..8c9271624e 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -2599,7 +2599,8 @@ void dispatch_sum_rows_op( const auto row_count = checked_u32(out.size(), "sum_rows output rows"); const uint32_t max_invocations = max_compute_work_group_invocations(); const uint32_t block_size = std::min( - push_constants.n_cols <= 32u ? 32u + push_constants.n_cols < 32u ? 1u + : push_constants.n_cols <= 32u ? 32u : push_constants.n_cols <= 64u ? 64u : push_constants.n_cols >= 4096u ? 1024u : 128u,