diff --git a/mlx/backend/vulkan/kernels.cpp b/mlx/backend/vulkan/kernels.cpp index 1a74ecba9f..8c9271624e 100644 --- a/mlx/backend/vulkan/kernels.cpp +++ b/mlx/backend/vulkan/kernels.cpp @@ -2599,7 +2599,8 @@ void dispatch_sum_rows_op( const auto row_count = checked_u32(out.size(), "sum_rows output rows"); const uint32_t max_invocations = max_compute_work_group_invocations(); const uint32_t block_size = std::min( - push_constants.n_cols <= 32u ? 32u + push_constants.n_cols < 32u ? 1u + : push_constants.n_cols <= 32u ? 32u : push_constants.n_cols <= 64u ? 64u : push_constants.n_cols >= 4096u ? 1024u : 128u,