From 99f1bc5af64ec44e4492a1f7bda7a5e1c9158f75 Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Thu, 11 Jun 2026 22:39:26 +0100 Subject: [PATCH] Multi-row GEMV decode kernels (skinny_mm, default-on, ABI v9) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit MLX's Metal matmul drops from ~161 GB/s (GEMV, M=1) to ~56 GB/s (tiled GEMM) the moment M reaches 2 (ml-explore/mlx#3661) — exactly the continuous-batching decode shape, where every per-token step is weight-bandwidth-bound. This was the ~2.6x per-row decode cliff between B=1 and B=2 that capped batched throughput. model/skinny_matmul adds two custom fast::metal_kernel kernels for the dense-fp16 B in [2,16], L==1 decode shape: each simdgroup reads a weight row once (coalesced half4) and keeps the batch's activations as register accumulators — a one-column-per-simdgroup variant for B <= 4 (~GEMV bandwidth) and a two-column variant for 5..16 (halves the redundant activation reads; still ahead of the tiled GEMM at 16). linear() falls back to mx::matmul past 16, where the GEMM wins. Plumbing: EngineConfig::skinny_mm (default on) -> Worker -> applied via DecoderModel::set_skinny_mm post-load (raw models keep stock numerics); ABI v9 appends opts2.skinny_mm (struct_size-gated, no new symbols); --skinny-mm / MLXFORGE_SKINNY_MM / "skinny_mm"; node skinnyMm. Gates (255/255 green, default-on): pure-kernel allclose grid across both variants, M in {2..16}, D in {128,1024}, odd O (tail guard) + applies() shape rejection; worker-level row-for-row token equality of a B=4 kernel-on batch vs the stock-matmul batch; C ABI on/off/default output agreement at B=2; ServerConfig parse. check-abi.sh clean. Accumulation is fp32 in a different order than mx::matmul (fp16-noise logit drift, same class as the decode-vs-recompute gap) — hence token-level gates. 
Bench (Qwen3-0.6B-bf16, M1 Pro, pure defaults, vs llama.cpp b8470 same run): N=8 agg 286 vs 218 tok/s (+31%), N=16 288 vs 301 (-4%, was -35% at the start of the day), N=1 parity, TTFT lead at every level. Per-row decode B=2: 39 -> 86 tok/s. Co-Authored-By: Claude Fable 5 --- CMakeLists.txt | 1 + apps/mlxforge.cpp | 4 +- bindings/node/index.d.ts | 6 ++ bindings/node/src/addon.cc | 2 + doc/applications.md | 1 + doc/architecture.md | 14 ++++ src/capi/mlxforge.cpp | 3 + src/capi/mlxforge.h | 16 +++- src/model/decoder_model.cpp | 8 +- src/model/decoder_model.h | 8 ++ src/model/skinny_matmul.cpp | 114 +++++++++++++++++++++++++++++ src/model/skinny_matmul.h | 37 ++++++++++ src/runtime/engine.cpp | 2 +- src/runtime/engine.h | 6 ++ src/runtime/worker.cpp | 6 +- src/runtime/worker.h | 10 ++- src/server/config.cpp | 6 +- src/server/config.h | 7 +- tests/CMakeLists.txt | 1 + tests/capi/capi_test.cpp | 37 ++++++++++ tests/model/skinny_matmul_test.cpp | 66 +++++++++++++++++ tests/scheduler/worker_test.cpp | 54 ++++++++++++++ tests/server/hardening_test.cpp | 8 ++ 23 files changed, 404 insertions(+), 13 deletions(-) create mode 100644 src/model/skinny_matmul.cpp create mode 100644 src/model/skinny_matmul.h create mode 100644 tests/model/skinny_matmul_test.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a3e166..a2b5018 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -57,6 +57,7 @@ add_library(mlxforge_core STATIC src/core/model_source.cpp src/model/decoder_model.cpp src/model/sdpa.cpp + src/model/skinny_matmul.cpp src/model/qwen3.cpp src/model/qwen3_moe.cpp src/model/qwen3_5.cpp diff --git a/apps/mlxforge.cpp b/apps/mlxforge.cpp index 5996dc1..36b8afc 100644 --- a/apps/mlxforge.cpp +++ b/apps/mlxforge.cpp @@ -80,6 +80,7 @@ void print_help() { " --kv-spill-bytes spill-dir disk budget in bytes, 0 = unbounded (default 0)\n" " --prefill-chunk interleaved-prefill chunk size in tokens, 0 = monolithic\n" " (default 256: decode keeps streaming during prefills)\n" + " 
--skinny-mm <0|1> multi-row GEMV decode kernels for small batches (default 1)\n" " -h, --help show this help and exit\n" "\n" "The model may be given via -m or the config file's \"model\" key.\n" @@ -87,7 +88,7 @@ void print_help() { "Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, " "MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS, MLXFORGE_PREFIX_CACHE, MLXFORGE_KV_BLOCK, " "MLXFORGE_KV_POOL, MLXFORGE_KV_SPILL_DIR, MLXFORGE_KV_SPILL_BYTES, " - "MLXFORGE_PREFILL_CHUNK."); + "MLXFORGE_PREFILL_CHUNK, MLXFORGE_SKINNY_MM."); std::fflush(stdout); } @@ -146,6 +147,7 @@ int main(int argc, char** argv) { ec.kv_spill_dir = sc.kv_spill_dir; ec.kv_spill_bytes = sc.kv_spill_bytes; ec.prefill_chunk = sc.prefill_chunk; + ec.skinny_mm = sc.skinny_mm; engine = std::make_unique(std::move(ec)); } catch (const std::exception& e) { mlxforge::log::error("model error: {}", e.what()); diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts index 4d0d78a..2d775ab 100644 --- a/bindings/node/index.d.ts +++ b/bindings/node/index.d.ts @@ -37,6 +37,12 @@ export interface EngineOptions { * or queued prefills. Default 256 (on); 0 = monolithic prefill (off). */ prefillChunk?: number; + /** + * Multi-row GEMV decode kernels: small batched-decode matmuls (2-16 rows) + * bypass MLX's tiled GEMM, roughly doubling per-row decode throughput at + * small batch sizes. Default true. + */ + skinnyMm?: boolean; } export interface SamplingOptions { diff --git a/bindings/node/src/addon.cc b/bindings/node/src/addon.cc index 1a6f2cb..0d84699 100644 --- a/bindings/node/src/addon.cc +++ b/bindings/node/src/addon.cc @@ -279,6 +279,8 @@ class EngineWrap : public Napi::ObjectWrap { const int chunk = o.Get("prefillChunk").As().Int32Value(); opts.prefill_chunk = chunk <= 0 ? -1 : chunk; } + if (o.Has("skinnyMm") && o.Get("skinnyMm").IsBoolean()) + opts.skinny_mm = o.Get("skinnyMm").As().Value() ? 
1 : -1; } char* err = nullptr; diff --git a/doc/applications.md b/doc/applications.md index bc6a3d5..fc667a4 100644 --- a/doc/applications.md +++ b/doc/applications.md @@ -64,6 +64,7 @@ with environment-variable fallbacks (`server/config`): | `--kv-spill-dir` | `MLXFORGE_KV_SPILL_DIR` | off | SSD spill directory: RAM-evicted prefix blocks persist here and survive restarts. | | `--kv-spill-bytes` | `MLXFORGE_KV_SPILL_BYTES` | `0` (unbounded) | Disk budget for the spill directory. | | `--prefill-chunk` | `MLXFORGE_PREFILL_CHUNK` | `256` (on) | Chunked-prefill interleaving: admissions prefill this many tokens per engine step with a decode step in between, so in-flight requests keep streaming during long or queued prefills (+25–35% batched throughput, up to 60% lower TTFT under load). `0` = monolithic prefill per admission. | +| `--skinny-mm` | `MLXFORGE_SKINNY_MM` | `1` (on) | Multi-row GEMV decode kernels: dense fp16 matmuls of the batched-decode shape (2–16 rows) bypass MLX's tiled GEMM, which runs at a fraction of GEMV bandwidth there ([mlx#3661](https://github.com/ml-explore/mlx/issues/3661)) — roughly 2× per-row decode throughput at small batch sizes. | ### Logging diff --git a/doc/architecture.md b/doc/architecture.md index e8d8855..4b345d0 100644 --- a/doc/architecture.md +++ b/doc/architecture.md @@ -161,6 +161,20 @@ In `Worker::decode_step()`: proof that one eval covers the whole batch. 5. Read the chosen ids back to the host, push each row's token, mark finished rows. +**Multi-row GEMV kernels** (`skinny_mm`, default on): a decode step is +weight-bandwidth-bound, and MLX's tiled GEMM drops to ~1/3 of GEMV bandwidth +the moment the batch reaches 2 rows (ml-explore/mlx#3661) — historically a +~2.6× per-row decode cliff between B=1 and B=2. 
`model/skinny_matmul` provides +custom `fast::metal_kernel` kernels for the B∈[2,16] dense-fp16 decode shape: +each simdgroup reads a weight row once and keeps the batch's activations as +register accumulators (a one-column variant for B≤4 at ~GEMV bandwidth, a +two-column variant for 5–16). Past B=16 the tiled GEMM wins and `linear()` +falls back. Accumulation is fp32 in a different order than `mx::matmul`, so +logits differ at fp16-noise scale; the gate is row-for-row token equality of a +kernel-on batch against the stock-matmul batch +(`tests/scheduler/worker_test.cpp`), plus a pure-kernel `allclose` grid +(`tests/model/skinny_matmul_test.cpp`). + ### Batch-size bucketing If the active batch shape changed every step, MLX would re-trace/re-compile the diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp index 3b76736..9c4e6d6 100644 --- a/src/capi/mlxforge.cpp +++ b/src/capi/mlxforge.cpp @@ -176,6 +176,9 @@ mlxforge_engine* mlxforge_engine_create2(const char* model_spec, * (256, on); negative explicitly disables (monolithic prefill). */ if (covered(&opts->prefill_chunk + 1) && opts->prefill_chunk != 0) cfg.prefill_chunk = opts->prefill_chunk < 0 ? 0 : opts->prefill_chunk; + /* v9: multi-row GEMV decode kernels. Zero-init keeps the default (on). */ + if (covered(&opts->skinny_mm + 1) && opts->skinny_mm != 0) + cfg.skinny_mm = opts->skinny_mm > 0; auto handle = std::make_unique(); handle->model_name = model_spec; diff --git a/src/capi/mlxforge.h b/src/capi/mlxforge.h index f2e34c6..c0f16c1 100644 --- a/src/capi/mlxforge.h +++ b/src/capi/mlxforge.h @@ -48,8 +48,10 @@ extern "C" { * kv_pool_bytes, kv_spill_dir, kv_spill_bytes) — appended, struct_size- * gated; no new symbols. * v8: mlxforge_engine_opts2.prefill_chunk (chunked-prefill interleaving, + * default-on) — appended, struct_size-gated; no new symbols. + * v9: mlxforge_engine_opts2.skinny_mm (multi-row GEMV decode kernels, * default-on) — appended, struct_size-gated; no new symbols. 
*/ -#define MLXFORGE_ABI_VERSION 8 +#define MLXFORGE_ABI_VERSION 9 typedef struct mlxforge_engine mlxforge_engine; typedef struct mlxforge_request mlxforge_request; @@ -143,7 +145,14 @@ mlxforge_engine* mlxforge_engine_create(const char* model_spec, * this many tokens per worker iteration with a decode step in between, so * in-flight requests keep streaming during long or queued prefills. On by * default (256). 0 keeps the default; < 0 disables it (monolithic prefill - * per admission, the pre-v8 behavior). */ + * per admission, the pre-v8 behavior). + * + * skinny_mm (v9+) toggles the multi-row GEMV decode kernels: dense fp16 + * matmuls of the batched-decode shape (2-16 rows, one token each) bypass + * MLX's tiled GEMM, which runs at a fraction of GEMV bandwidth there + * (ml-explore/mlx#3661) — roughly 2x per-row decode throughput at small + * batch sizes. On by default. 0 keeps the default; < 0 disables (stock + * matmul); > 0 enables. */ typedef struct { size_t struct_size; /* caller sets sizeof(mlxforge_engine_opts2) */ int max_waiting; /* max queued requests; <= 0 => default (256) */ @@ -158,6 +167,9 @@ typedef struct { /* ---- v8 ---- */ int prefill_chunk; /* tokens per interleaved prefill chunk; 0 => default (256); < 0 => monolithic prefill (off) */ + /* ---- v9 ---- */ + int skinny_mm; /* multi-row GEMV decode kernels; 0 => default (on); + < 0 => off (stock matmul); > 0 => on */ } mlxforge_engine_opts2; /* Create an engine with extended options (v6+). 
Identical contract to diff --git a/src/model/decoder_model.cpp b/src/model/decoder_model.cpp index f7a659a..48dada1 100644 --- a/src/model/decoder_model.cpp +++ b/src/model/decoder_model.cpp @@ -8,6 +8,7 @@ #include "core/logging.h" #include "model/sdpa.h" +#include "model/skinny_matmul.h" #include "mlx/fast.h" #include "mlx/ops.h" @@ -121,7 +122,12 @@ mx::array DecoderModel::linear(const mx::array& x, const std::string& weight_key w_.at(base + ".biases"), /*transpose=*/true, qp.group_size, qp.bits); } - return mx::matmul(x, mx::transpose(w_.at(weight_key))); // weight is (out, in) + const mx::array& w = w_.at(weight_key); + // The batched-decode shape (B in [2, 16], L == 1) takes the multi-row GEMV + // kernels when enabled — MLX's tiled GEMM runs at a fraction of GEMV + // bandwidth there (mlx#3661). See set_skinny_mm(). + if (skinny_mm_ && skinny_matmul_applies(x, w)) return skinny_matmul(x, w); + return mx::matmul(x, mx::transpose(w)); // weight is (out, in) } mx::array DecoderModel::embed(const mx::array& tokens) const { diff --git a/src/model/decoder_model.h b/src/model/decoder_model.h index 87e00d2..d381ae4 100644 --- a/src/model/decoder_model.h +++ b/src/model/decoder_model.h @@ -99,6 +99,13 @@ class DecoderModel { // 0 where a key is causally valid and not left-padding, -inf otherwise. mx::array batch_mask(int prev_idx, int n_query, const mx::array& left_padding) const; + // Enable the multi-row GEMV decode kernels (model/skinny_matmul): linear() + // routes dense fp16 matmuls of the batched-decode shape (B in [2, 16], + // L == 1) through them instead of MLX's tiled GEMM. Off by default so raw + // models (single-stream tools, golden references) keep stock numerics; the + // Worker turns it on per EngineConfig::skinny_mm after loading. + void set_skinny_mm(bool on) { skinny_mm_ = on; } + protected: // Abstract base: construct a concrete subclass (or via create_model()). 
DecoderModel(ModelConfig config, Weights weights); @@ -131,6 +138,7 @@ class DecoderModel { ModelConfig cfg_; Weights w_; mx::array rope_freqs_; + bool skinny_mm_ = false; // see set_skinny_mm() }; } // namespace mlxforge diff --git a/src/model/skinny_matmul.cpp b/src/model/skinny_matmul.cpp new file mode 100644 index 0000000..9d3eef4 --- /dev/null +++ b/src/model/skinny_matmul.cpp @@ -0,0 +1,114 @@ +#include "model/skinny_matmul.h" + +#include +#include + +#include "mlx/fast.h" +#include "mlx/ops.h" + +namespace mlxforge { + +namespace { + +// One simdgroup per output column o: 32 lanes split D with half4 loads, so the +// weight row is read once (coalesced) for all M activation rows; per-lane fp32 +// accumulators reduce via simd_sum. M is a compile-time template arg so the +// accumulators stay in registers. Best for M in [2, 4]. +constexpr const char* kSourceOneCol = R"( + const uint lane = thread_position_in_grid.x; // 0..31 + const uint o = thread_position_in_grid.y; // output column (w row) + const int D = w_shape[1]; + const int O = w_shape[0]; + + const device half4* wrow = (const device half4*)(w + (size_t)o * D); + float acc[M]; + for (int m = 0; m < M; ++m) acc[m] = 0.0f; + + for (int k = lane; k < D / 4; k += 32) { + half4 wv = wrow[k]; + for (int m = 0; m < M; ++m) { + half4 xv = ((const device half4*)(x + (size_t)m * D))[k]; + acc[m] += (float)wv.x * (float)xv.x + (float)wv.y * (float)xv.y + + (float)wv.z * (float)xv.z + (float)wv.w * (float)xv.w; + } + } + for (int m = 0; m < M; ++m) { + float r = simd_sum(acc[m]); + if (lane == 0) y[(size_t)m * O + o] = (half)r; + } +)"; + +// Two output columns per simdgroup: each activation load feeds two weight +// rows, halving the redundant x traffic that degrades the one-column variant +// past M ~ 8. Best for M in [5, 16]; the crossover vs the tiled GEMM is past +// 16 (66 GB/s at M=16 vs the GEMM's flat ~56 on M1 Pro). 
+constexpr const char* kSourceTwoCol = R"( + const uint lane = thread_position_in_grid.x; // 0..31 + const uint pair = thread_position_in_grid.y; // output column pair + const int D = w_shape[1]; + const int O = w_shape[0]; + const uint o0 = pair * 2; + const uint o1 = o0 + 1; + const bool has1 = o1 < (uint)O; + + const device half4* w0 = (const device half4*)(w + (size_t)o0 * D); + const device half4* w1 = (const device half4*)(w + (size_t)(has1 ? o1 : o0) * D); + float acc0[M]; + float acc1[M]; + for (int m = 0; m < M; ++m) { acc0[m] = 0.0f; acc1[m] = 0.0f; } + + for (int k = lane; k < D / 4; k += 32) { + half4 wv0 = w0[k]; + half4 wv1 = w1[k]; + for (int m = 0; m < M; ++m) { + half4 xv = ((const device half4*)(x + (size_t)m * D))[k]; + acc0[m] += (float)wv0.x * (float)xv.x + (float)wv0.y * (float)xv.y + + (float)wv0.z * (float)xv.z + (float)wv0.w * (float)xv.w; + acc1[m] += (float)wv1.x * (float)xv.x + (float)wv1.y * (float)xv.y + + (float)wv1.z * (float)xv.z + (float)wv1.w * (float)xv.w; + } + } + for (int m = 0; m < M; ++m) { + float r0 = simd_sum(acc0[m]); + float r1 = simd_sum(acc1[m]); + if (lane == 0) { + y[(size_t)m * O + o0] = (half)r0; + if (has1) y[(size_t)m * O + o1] = (half)r1; + } + } +)"; + +constexpr int kOneColMaxM = 4; +constexpr int kMaxM = 16; + +} // namespace + +bool skinny_matmul_applies(const mx::array& x, const mx::array& w) { + if (x.dtype() != mx::float16 || w.dtype() != mx::float16) return false; + if (w.ndim() != 2 || w.shape()[1] % 128 != 0) return false; + const int nd = x.ndim(); + if (nd == 3 && x.shape()[1] != 1) return false; // decode shape only, never prefill + if (nd != 2 && nd != 3) return false; + const int m = x.shape()[0]; + return m >= 2 && m <= kMaxM && x.shape()[nd - 1] == w.shape()[1]; +} + +mx::array skinny_matmul(const mx::array& x, const mx::array& w) { + static const auto one_col = mx::fast::metal_kernel( + "mlxforge_gemv_multirow", {"x", "w"}, {"y"}, kSourceOneCol); + static const auto two_col = 
mx::fast::metal_kernel( + "mlxforge_gemv_multirow2", {"x", "w"}, {"y"}, kSourceTwoCol); + + const int m = x.shape()[0]; + const int o = w.shape()[0]; + mx::array x2 = x.ndim() == 3 ? mx::reshape(x, {m, x.shape()[2]}) : x; + const bool narrow = m <= kOneColMaxM; + std::vector out = (narrow ? one_col : two_col)( + {x2, w}, {mx::Shape{m, o}}, {mx::float16}, + /*grid=*/{32, narrow ? o : (o + 1) / 2, 1}, /*threadgroup=*/{32, 1, 1}, + /*template_args=*/{{"M", m}}, + /*init_value=*/std::nullopt, /*verbose=*/false, {}); + return x.ndim() == 3 ? mx::reshape(out[0], {m, 1, o}) : out[0]; +} + +} // namespace mlxforge diff --git a/src/model/skinny_matmul.h b/src/model/skinny_matmul.h new file mode 100644 index 0000000..dcb000e --- /dev/null +++ b/src/model/skinny_matmul.h @@ -0,0 +1,37 @@ +// Multi-row GEMV kernels for the batched-decode regime. +// +// MLX's Metal matmul drops from ~161 GB/s (GEMV, M=1) to ~56 GB/s (tiled GEMM) +// the moment M reaches 2, and every M in [2, 64] pays the same tile cost +// (ml-explore/mlx#3661) — exactly the continuous-batching decode shape, where +// each per-token step is weight-bandwidth-bound. These custom +// fast::metal_kernel kernels read each weight row once per simdgroup and keep +// the M activation rows as register accumulators, recovering GEMV-class +// bandwidth: a one-column-per-simdgroup variant for M in [2, 4] (~161 GB/s) +// and a two-column variant for M in [5, 16] (the doubled arithmetic intensity +// halves the redundant activation reads that degrade larger M; ~125 GB/s at +// M=8, still ahead of the tiled GEMM at M=16). Beyond 16 the fallback GEMM is +// the right path. +// +// Accumulation is fp32 but in a different order than mx::matmul, so logits can +// differ at fp16-noise scale — the same class as the decode-vs-recompute gap. +// Gated by EngineConfig::skinny_mm (default on) via DecoderModel::set_skinny_mm +// and token-equality tests against the stock-matmul stream. 
+#pragma once + +#include "mlx/array.h" + +namespace mx = mlx::core; + +namespace mlxforge { + +// True when the kernel path applies to the shapes: x is (B, 1, D) or (B, D) +// fp16 with B in [2, 16], w is a dense fp16 (O, D) weight, and D is a multiple +// of 128 (half4 loads across 32 lanes). Enablement is the caller's flag +// (DecoderModel::skinny_mm_); this checks shapes only. +bool skinny_matmul_applies(const mx::array& x, const mx::array& w); + +// x @ w.T via the multi-row GEMV kernels. Preserves x's leading shape: +// (B, 1, D) -> (B, 1, O), (B, D) -> (B, O). +mx::array skinny_matmul(const mx::array& x, const mx::array& w); + +} // namespace mlxforge diff --git a/src/runtime/engine.cpp b/src/runtime/engine.cpp index fa54330..32f3e39 100644 --- a/src/runtime/engine.cpp +++ b/src/runtime/engine.cpp @@ -186,7 +186,7 @@ Engine::Engine(EngineConfig cfg, Loaded loaded) // cfg_ is initialized above, so the KV-quant validation sees the model. worker_(make_factory(std::move(loaded.dir), loaded.is_gguf), &scheduler_, &tok_, validate_kv_quant(cfg, cfg_), validate_prefix_cache(cfg, cfg_, model_name_), - validate_prefill_chunk(cfg)) { + validate_prefill_chunk(cfg), cfg.skinny_mm) { // Configure the max waiting requests for the batch scheduler. scheduler_.set_max_waiting(cfg.max_waiting); diff --git a/src/runtime/engine.h b/src/runtime/engine.h index e897c96..685d8fb 100644 --- a/src/runtime/engine.h +++ b/src/runtime/engine.h @@ -51,6 +51,12 @@ struct EngineConfig { // benchmarked sweet spot); 0 = monolithic prefill per admission. Negative // values are rejected at construction. int prefill_chunk = 256; + // Multi-row GEMV decode kernels (model/skinny_matmul): dense fp16 matmuls of + // the batched-decode shape (B in [2, 16], L == 1) bypass MLX's tiled GEMM, + // which runs at a fraction of GEMV bandwidth there (ml-explore/mlx#3661). 
+ // On by default; logits may differ from the stock kernel at fp16-noise scale + // (fp32 accumulation in a different order), token-equality gated in tests. + bool skinny_mm = true; }; // Per-call embedding options. The two int fields are tri-state: -1 means "use diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 64726d6..4914d20 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -59,9 +59,10 @@ bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) { } // namespace Worker::Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok, - KVQuantConfig kv_quant, PrefixCacheConfig prefix, int prefill_chunk) + KVQuantConfig kv_quant, PrefixCacheConfig prefix, int prefill_chunk, + bool skinny_mm) : factory_(std::move(factory)), sched_(scheduler), tok_(tok), kv_quant_(kv_quant), - prefix_cfg_(prefix), prefill_chunk_(prefill_chunk) {} + prefix_cfg_(prefix), prefill_chunk_(prefill_chunk), skinny_mm_(skinny_mm) {} Worker::~Worker() { stop(); } @@ -195,6 +196,7 @@ void Worker::stop() { void Worker::run() { log::info("worker: loading model..."); model_ = factory_(); // load the model on this thread so its arrays live here + model_->set_skinny_mm(skinny_mm_); // batched-decode GEMV kernels (engine option) // The prefix pool holds MLX arrays, so it lives (and dies) with this thread. if (prefix_cfg_.enabled) { prefix_ = std::make_unique(prefix_cfg_); diff --git a/src/runtime/worker.h b/src/runtime/worker.h index fabbb2b..0c12437 100644 --- a/src/runtime/worker.h +++ b/src/runtime/worker.h @@ -41,11 +41,14 @@ class Worker { // `kv_quant` selects the decode cache's storage (dense fp16 by default) and // `prefix` the prefix-cache setting; the Engine validates both against the // model before construction. `prefill_chunk` is the interleaved-admission - // chunk size in tokens (0 = monolithic prefill, see below). Defined + // chunk size in tokens (0 = monolithic prefill, see below). 
`skinny_mm` + // routes batched-decode dense matmuls through the multi-row GEMV kernels + // (model/skinny_matmul; applied to the model after it loads). Defined // out-of-line (with the destructor) because the unique_ptr // member needs the complete type for cleanup. Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok = nullptr, - KVQuantConfig kv_quant = {}, PrefixCacheConfig prefix = {}, int prefill_chunk = 256); + KVQuantConfig kv_quant = {}, PrefixCacheConfig prefix = {}, int prefill_chunk = 256, + bool skinny_mm = true); ~Worker(); Worker(const Worker&) = delete; @@ -193,7 +196,8 @@ class Worker { // Interleaved-prefill queue (worker thread only; empty when idle or feature off). std::deque pending_; - int prefill_chunk_; // chunk size in tokens; 0 = monolithic admits + int prefill_chunk_; // chunk size in tokens; 0 = monolithic admits + bool skinny_mm_; // multi-row GEMV decode kernels (set on the model post-load) std::atomic decode_steps_{0}; std::atomic ready_{false}; diff --git a/src/server/config.cpp b/src/server/config.cpp index b50f303..9140ba3 100644 --- a/src/server/config.cpp +++ b/src/server/config.cpp @@ -59,7 +59,7 @@ ServerConfig ServerConfig::from_file(const std::string& path) { static const std::set kKnownKeys = { "model", "host", "port", "max_ctx", "max_waiting", "kv_budget", "kv_bits", "prefix_cache", "kv_block", "kv_pool", "kv_spill_dir", "kv_spill_bytes", - "prefill_chunk"}; + "prefill_chunk", "skinny_mm"}; for (const auto& [key, _] : j.items()) { if (kKnownKeys.find(key) == kKnownKeys.end()) { throw std::runtime_error("config file: unknown key '" + key + "' in '" + path + "'"); @@ -114,6 +114,7 @@ ServerConfig ServerConfig::from_file(const std::string& path) { if (c.prefill_chunk < 0) throw std::runtime_error("config file: 'prefill_chunk' must be >= 0 (0 = monolithic)"); } + if (j.contains("skinny_mm")) c.skinny_mm = require_type(j, "skinny_mm"); return c; } @@ -165,6 +166,7 @@ ServerConfig ServerConfig::parse(const 
std::vector& args) { c.kv_spill_bytes = static_cast( env_long("MLXFORGE_KV_SPILL_BYTES", static_cast(c.kv_spill_bytes))); c.prefill_chunk = static_cast(env_long("MLXFORGE_PREFILL_CHUNK", c.prefill_chunk)); + c.skinny_mm = env_long("MLXFORGE_SKINNY_MM", c.skinny_mm ? 1 : 0) != 0; // Helper: extract value for a flag (accepts "--flag value" or "--flag=value") auto value_of = [&](const std::string& a, size_t& i) -> std::string { @@ -209,6 +211,8 @@ ServerConfig ServerConfig::parse(const std::vector& args) { c.kv_spill_bytes = static_cast(std::stoll(value_of(a, i))); else if (flag == "--prefill-chunk") c.prefill_chunk = std::stoi(value_of(a, i)); + else if (flag == "--skinny-mm") + c.skinny_mm = std::stoi(value_of(a, i)) != 0; else throw std::runtime_error("unknown flag: " + flag); } diff --git a/src/server/config.h b/src/server/config.h index 5fec91a..face2d4 100644 --- a/src/server/config.h +++ b/src/server/config.h @@ -48,6 +48,9 @@ struct ServerConfig { // decode step in between (default on at 256). 0 = monolithic prefill. int prefill_chunk = 256; + // Multi-row GEMV decode kernels for small batched-decode matmuls (default on). + bool skinny_mm = true; + // Parses command line arguments (the model via -m/--model, plus optional flags as --flag value or --flag=value), // layering configuration sources by precedence (lowest to highest): // struct defaults < config file (-c/--config) < environment variables < CLI flags. @@ -55,7 +58,7 @@ struct ServerConfig { // Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, // MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS, MLXFORGE_PREFIX_CACHE, MLXFORGE_KV_BLOCK, // MLXFORGE_KV_POOL, MLXFORGE_KV_SPILL_DIR, MLXFORGE_KV_SPILL_BYTES, - // MLXFORGE_PREFILL_CHUNK. + // MLXFORGE_PREFILL_CHUNK, MLXFORGE_SKINNY_MM. // Throws std::runtime_error if an unknown or malformed flag is encountered. 
static ServerConfig parse(const std::vector& args); @@ -63,7 +66,7 @@ struct ServerConfig { // with struct defaults filling any keys the file omits. Recognized keys // (snake_case): "model", "host", "port", "max_ctx", "max_waiting", "kv_budget", // "kv_bits", "prefix_cache", "kv_block", "kv_pool", "kv_spill_dir", - // "kv_spill_bytes", "prefill_chunk". + // "kv_spill_bytes", "prefill_chunk", "skinny_mm". // Validates before applying: rejects unknown keys, wrong types, and out-of-range // values. Throws std::runtime_error (with the file path / offending key) on any // failure to open, parse, or validate. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 6093201..47843e2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -14,6 +14,7 @@ add_executable(mlxforge_tests model/llama_block_test.cpp model/llama_forward_test.cpp model/batched_decode_test.cpp + model/skinny_matmul_test.cpp cache/kv_cache_test.cpp cache/batch_kv_cache_test.cpp cache/batch_kv_filter_merge_test.cpp diff --git a/tests/capi/capi_test.cpp b/tests/capi/capi_test.cpp index a05b436..c3e529b 100644 --- a/tests/capi/capi_test.cpp +++ b/tests/capi/capi_test.cpp @@ -388,3 +388,40 @@ TEST_CASE("C ABI v8 prefill_chunk: chunked and monolithic engines agree") { CHECK(run_with_chunk(8) == monolithic); // aggressive chunking CHECK(run_with_chunk(0) == monolithic); // the default (256, on) } + +TEST_CASE("C ABI v9 skinny_mm: kernel-on and stock-matmul engines agree") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + const char* prompt = + "Once upon a time in a quiet village by the sea, a young engineer set out " + "to write a Metal kernel that read every weight row exactly once. Describe " + "what it computed first."; + mlxforge_sampling s = {}; + s.max_tokens = 12; + + // Two concurrent submissions put the decode batch at B=2 — the kernel shape. 
+ auto run_with = [&](int skinny) { + char* err = nullptr; + mlxforge_engine_opts2 opts = {}; + opts.struct_size = sizeof(opts); + opts.skinny_mm = skinny; // 0 => default (on); < 0 => stock matmul + mlxforge_engine* eng = mlxforge_engine_create2(model_dir().c_str(), &opts, &err); + REQUIRE_MESSAGE(eng != nullptr, (err ? err : "engine_create2 failed")); + mlxforge_request* a = mlxforge_submit_text(eng, prompt, &s, &err); + REQUIRE_MESSAGE(a != nullptr, (err ? err : "submit failed")); + mlxforge_request* b = mlxforge_submit_text(eng, prompt, &s, &err); + REQUIRE_MESSAGE(b != nullptr, (err ? err : "submit failed")); + const std::string out = drain(a) + "\n---\n" + drain(b); + mlxforge_request_free(a); + mlxforge_request_free(b); + mlxforge_engine_free(eng); + return out; + }; + + const std::string stock = run_with(-1); + CHECK(stock.size() > 5); + CHECK(run_with(1) == stock); // explicit on + CHECK(run_with(0) == stock); // the default (on) +} diff --git a/tests/model/skinny_matmul_test.cpp b/tests/model/skinny_matmul_test.cpp new file mode 100644 index 0000000..182d2f9 --- /dev/null +++ b/tests/model/skinny_matmul_test.cpp @@ -0,0 +1,66 @@ +// The multi-row GEMV decode kernels must agree with mx::matmul at fp16-noise +// tolerance across the whole M range and both kernel variants (one-column for +// M <= 4, two-column for 5..16, including an odd O exercising the tail guard), +// and the shape gate must reject everything outside the batched-decode shape. +// Pure GPU test — no model weights needed. 
+#include + +#include "model/skinny_matmul.h" + +#include "mlx/ops.h" +#include "mlx/random.h" +#include "mlx/transforms.h" + +namespace { + +float max_abs_diff(const mx::array& a, const mx::array& b) { + mx::array d = mx::max(mx::abs(mx::subtract(mx::astype(a, mx::float32), + mx::astype(b, mx::float32)))); + mx::eval(d); + return d.item(); +} + +} // namespace + +TEST_CASE("skinny_matmul matches mx::matmul across M, D, and both variants") { + for (int d : {128, 1024}) { + for (int o : {17, 1536}) { // odd O exercises the two-column tail guard + mx::array w = mx::astype(mx::random::normal({o, d}), mx::float16); + for (int m : {2, 3, 4, 5, 8, 12, 16}) { + CAPTURE(m); + CAPTURE(d); + CAPTURE(o); + mx::array x = mx::astype( + mx::multiply(mx::random::normal({m, d}), mx::array(0.05f)), mx::float16); + REQUIRE(mlxforge::skinny_matmul_applies(x, w)); + mx::array ref = mx::matmul(x, mx::transpose(w)); + mx::array got = mlxforge::skinny_matmul(x, w); + CHECK(got.shape() == ref.shape()); + // fp32 accumulation vs the GEMM's accumulation order: fp16-noise scale. + CHECK(max_abs_diff(got, ref) < 5e-3f); + + // The 3-D decode shape (B, 1, D) round-trips its leading shape. 
+ mx::array x3 = mx::reshape(x, {m, 1, d}); + REQUIRE(mlxforge::skinny_matmul_applies(x3, w)); + mx::array got3 = mlxforge::skinny_matmul(x3, w); + CHECK(got3.shape() == mx::Shape{m, 1, o}); + CHECK(max_abs_diff(mx::reshape(got3, {m, o}), ref) < 5e-3f); + } + } + } +} + +TEST_CASE("skinny_matmul_applies rejects everything outside the decode shape") { + mx::array w = mx::astype(mx::random::normal({64, 1024}), mx::float16); + auto x = [&](mx::Shape s, mx::Dtype t = mx::float16) { + return mx::astype(mx::random::normal(std::move(s)), t); + }; + CHECK_FALSE(mlxforge::skinny_matmul_applies(x({1, 1024}), w)); // M=1: GEMV is faster + CHECK_FALSE(mlxforge::skinny_matmul_applies(x({17, 1024}), w)); // past the GEMM crossover + CHECK(mlxforge::skinny_matmul_applies(x({16, 1024}), w)); + CHECK_FALSE(mlxforge::skinny_matmul_applies(x({4, 2, 1024}), w)); // prefill (L > 1) + CHECK(mlxforge::skinny_matmul_applies(x({4, 1, 1024}), w)); + CHECK_FALSE(mlxforge::skinny_matmul_applies(x({4, 1024}, mx::float32), w)); // dtype + mx::array w_odd = mx::astype(mx::random::normal({64, 1000}), mx::float16); + CHECK_FALSE(mlxforge::skinny_matmul_applies(x({4, 1000}), w_odd)); // D % 128 != 0 +} diff --git a/tests/scheduler/worker_test.cpp b/tests/scheduler/worker_test.cpp index b0a31cb..b431fd2 100644 --- a/tests/scheduler/worker_test.cpp +++ b/tests/scheduler/worker_test.cpp @@ -105,3 +105,57 @@ TEST_CASE("chunked prefill reproduces the reference greedy stream across chunk s worker.stop(); } } + +TEST_CASE("skinny_mm decode kernels reproduce the stock-matmul greedy stream") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + const std::string dir = model_dir(); + mlxforge::ModelConfig cfg = mlxforge::ModelConfig::from_file(dir + "/config.json"); + + std::vector prompt; + for (const char* name : {"prompt_0_ids.npy", "prompt_1_ids.npy", "prompt_2_ids.npy"}) { + std::vector ids = load_token_ids(name); + prompt.insert(prompt.end(), ids.begin(), 
ids.end()); + } + const int kMax = 16; + const int kBatch = 4; // decode at B=4 routes every linear through the kernels + + // The kernels may only change speed, never tokens: the kernel-on batch must match + // the kernel-off batch row for row (both greedy on identical prompts). + std::vector> streams[2]; + for (bool skinny : {false, true}) { + mlxforge::Scheduler sched; + mlxforge::Worker worker( + [dir] { + mlxforge::ModelConfig c = mlxforge::ModelConfig::from_file(dir + "/config.json"); + auto w = mlxforge::load_weights(dir, c); + return std::make_unique(std::move(c), std::move(w)); + }, + &sched, /*tok=*/nullptr, /*kv_quant=*/{}, /*prefix=*/{}, /*prefill_chunk=*/256, skinny); + worker.start(); + + std::vector> reqs; + for (int i = 0; i < kBatch; ++i) { + auto r = std::make_shared(); + r->prompt_ids = prompt; + r->params.temperature = 0.0f; + r->max_tokens = kMax; + r->eos_ids = cfg.eos_token_ids; + REQUIRE(sched.submit(r)); + reqs.push_back(std::move(r)); + } + for (const auto& r : reqs) { + std::vector got; + int tok = 0; + while (r->tokens.pop(tok)) got.push_back(tok); + streams[skinny].push_back(std::move(got)); + } + worker.stop(); + } + for (int i = 0; i < kBatch; ++i) { + CAPTURE(i); + assert_tokens_equal(streams[1][i], streams[0][i]); + } +} diff --git a/tests/server/hardening_test.cpp b/tests/server/hardening_test.cpp index 38ecc80..25906c6 100644 --- a/tests/server/hardening_test.cpp +++ b/tests/server/hardening_test.cpp @@ -61,6 +61,14 @@ TEST_CASE("ServerConfig parses prefill_chunk (flag, file, validation)") { std::runtime_error); } +TEST_CASE("ServerConfig parses skinny_mm (flag, file)") { + CHECK(ServerConfig::parse({"-m", "/m"}).skinny_mm == true); // default on + CHECK(ServerConfig::parse({"-m", "/m", "--skinny-mm", "0"}).skinny_mm == false); + CHECK(ServerConfig::parse({"-m", "/m", "--skinny-mm=1"}).skinny_mm == true); + CHECK(ServerConfig::parse({"-c", write_temp_config(R"({"skinny_mm": false})")}) + .skinny_mm == false); +} + 
TEST_CASE("ServerConfig rejects unknown and positional args") { CHECK_THROWS_AS(ServerConfig::parse({"-m", "/m", "--bogus", "x"}), std::runtime_error); CHECK_THROWS_AS(ServerConfig::parse({"-m"}), std::runtime_error); // missing value