1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -57,6 +57,7 @@ add_library(mlxforge_core STATIC
src/core/model_source.cpp
src/model/decoder_model.cpp
src/model/sdpa.cpp
src/model/skinny_matmul.cpp
src/model/qwen3.cpp
src/model/qwen3_moe.cpp
src/model/qwen3_5.cpp
4 changes: 3 additions & 1 deletion apps/mlxforge.cpp
@@ -80,14 +80,15 @@ void print_help() {
" --kv-spill-bytes <B> spill-dir disk budget in bytes, 0 = unbounded (default 0)\n"
" --prefill-chunk <N> interleaved-prefill chunk size in tokens, 0 = monolithic\n"
" (default 256: decode keeps streaming during prefills)\n"
" --skinny-mm <0|1> multi-row GEMV decode kernels for small batches (default 1)\n"
" -h, --help show this help and exit\n"
"\n"
"The model may be given via -m or the config file's \"model\" key.\n"
"Config precedence (low to high): defaults < config file < env vars < CLI flags.\n"
"Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, "
"MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS, MLXFORGE_PREFIX_CACHE, MLXFORGE_KV_BLOCK, "
"MLXFORGE_KV_POOL, MLXFORGE_KV_SPILL_DIR, MLXFORGE_KV_SPILL_BYTES, "
- "MLXFORGE_PREFILL_CHUNK.");
+ "MLXFORGE_PREFILL_CHUNK, MLXFORGE_SKINNY_MM.");
std::fflush(stdout);
}
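
Review note: the help text states the precedence `defaults < config file < env vars < CLI flags`. A minimal sketch of that layering for a boolean such as `--skinny-mm` follows. The names `parse_01` and `resolve_flag` are hypothetical and not taken from this PR, and the strict "0"/"1" parsing is an assumption about how `MLXFORGE_SKINNY_MM` is read.

```cpp
#include <cassert>
#include <cstring>
#include <optional>

// Hypothetical helper: parse a "<0|1>" value; anything else counts as unset.
// (Assumption: the real server/config parser may accept other spellings.)
std::optional<bool> parse_01(const char* s) {
  if (s == nullptr) return std::nullopt;
  if (std::strcmp(s, "0") == 0) return false;
  if (std::strcmp(s, "1") == 0) return true;
  return std::nullopt;
}

// Hypothetical helper: apply the layers from lowest to highest precedence.
// A layer that is unset leaves the value from the lower layers untouched.
bool resolve_flag(bool dflt, std::optional<bool> file,
                  std::optional<bool> env, std::optional<bool> cli) {
  bool v = dflt;
  if (file) v = *file;
  if (env) v = *env;
  if (cli) v = *cli;
  return v;
}

// Usage: the env var turns the kernels off, then a CLI flag turns them back on.
void precedence_example() {
  assert(resolve_flag(true, std::nullopt, parse_01("0"), std::nullopt) == false);
  assert(resolve_flag(true, std::nullopt, parse_01("0"), parse_01("1")) == true);
}
```

Modelling each layer as `std::optional<bool>` keeps "not specified" distinct from "explicitly off", which is the same distinction the C API's tri-state `int` encodes.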

@@ -146,6 +147,7 @@ int main(int argc, char** argv) {
ec.kv_spill_dir = sc.kv_spill_dir;
ec.kv_spill_bytes = sc.kv_spill_bytes;
ec.prefill_chunk = sc.prefill_chunk;
ec.skinny_mm = sc.skinny_mm;
engine = std::make_unique<mlxforge::Engine>(std::move(ec));
} catch (const std::exception& e) {
mlxforge::log::error("model error: {}", e.what());
6 changes: 6 additions & 0 deletions bindings/node/index.d.ts
@@ -37,6 +37,12 @@ export interface EngineOptions {
* or queued prefills. Default 256 (on); 0 = monolithic prefill (off).
*/
prefillChunk?: number;
/**
* Multi-row GEMV decode kernels: small batched-decode matmuls (2-16 rows)
* bypass MLX's tiled GEMM, roughly doubling per-row decode throughput at
* small batch sizes. Default true.
*/
skinnyMm?: boolean;
}

export interface SamplingOptions {
2 changes: 2 additions & 0 deletions bindings/node/src/addon.cc
@@ -279,6 +279,8 @@ class EngineWrap : public Napi::ObjectWrap<EngineWrap> {
const int chunk = o.Get("prefillChunk").As<Napi::Number>().Int32Value();
opts.prefill_chunk = chunk <= 0 ? -1 : chunk;
}
if (o.Has("skinnyMm") && o.Get("skinnyMm").IsBoolean())
opts.skinny_mm = o.Get("skinnyMm").As<Napi::Boolean>().Value() ? 1 : -1;
}

char* err = nullptr;
1 change: 1 addition & 0 deletions doc/applications.md
@@ -64,6 +64,7 @@ with environment-variable fallbacks (`server/config`):
| `--kv-spill-dir` | `MLXFORGE_KV_SPILL_DIR` | off | SSD spill directory: RAM-evicted prefix blocks persist here and survive restarts. |
| `--kv-spill-bytes` | `MLXFORGE_KV_SPILL_BYTES` | `0` (unbounded) | Disk budget for the spill directory. |
| `--prefill-chunk` | `MLXFORGE_PREFILL_CHUNK` | `256` (on) | Chunked-prefill interleaving: admissions prefill this many tokens per engine step with a decode step in between, so in-flight requests keep streaming during long or queued prefills (+25–35% batched throughput, up to 60% lower TTFT under load). `0` = monolithic prefill per admission. |
| `--skinny-mm` | `MLXFORGE_SKINNY_MM` | `1` (on) | Multi-row GEMV decode kernels: dense fp16 matmuls of the batched-decode shape (2–16 rows) bypass MLX's tiled GEMM, which runs at a fraction of GEMV bandwidth there ([mlx#3661](https://github.com/ml-explore/mlx/issues/3661)) — roughly 2× per-row decode throughput at small batch sizes. |

### Logging

14 changes: 14 additions & 0 deletions doc/architecture.md
@@ -161,6 +161,20 @@ In `Worker::decode_step()`:
proof that one eval covers the whole batch.
5. Read the chosen ids back to the host, push each row's token, mark finished rows.

**Multi-row GEMV kernels** (`skinny_mm`, default on): a decode step is
weight-bandwidth-bound, and MLX's tiled GEMM drops to ~1/3 of GEMV bandwidth
the moment the batch reaches 2 rows (ml-explore/mlx#3661) — historically a
~2.6× per-row decode cliff between B=1 and B=2. `model/skinny_matmul` provides
custom `fast::metal_kernel` kernels for the B∈[2,16] dense-fp16 decode shape:
each simdgroup reads a weight row once and keeps one fp32 accumulator per batch
row in registers (a one-column variant for B≤4 at ~GEMV bandwidth, a
two-column variant for 5–16). Past B=16 the tiled GEMM wins and `linear()`
falls back. Accumulation is fp32 in a different order than `mx::matmul`, so
logits differ at fp16-noise scale; the gate is row-for-row token equality of a
kernel-on batch against the stock-matmul batch
(`tests/scheduler/worker_test.cpp`), plus a pure-kernel `allclose` grid
(`tests/model/skinny_matmul_test.cpp`).

### Batch-size bucketing

If the active batch shape changed every step, MLX would re-trace/re-compile the
3 changes: 3 additions & 0 deletions src/capi/mlxforge.cpp
@@ -176,6 +176,9 @@ mlxforge_engine* mlxforge_engine_create2(const char* model_spec,
* (256, on); negative explicitly disables (monolithic prefill). */
if (covered(&opts->prefill_chunk + 1) && opts->prefill_chunk != 0)
cfg.prefill_chunk = opts->prefill_chunk < 0 ? 0 : opts->prefill_chunk;
/* v9: multi-row GEMV decode kernels. Zero-init keeps the default (on). */
if (covered(&opts->skinny_mm + 1) && opts->skinny_mm != 0)
cfg.skinny_mm = opts->skinny_mm > 0;

auto handle = std::make_unique<mlxforge_engine>();
handle->model_name = model_spec;
16 changes: 14 additions & 2 deletions src/capi/mlxforge.h
@@ -48,8 +48,10 @@ extern "C" {
* kv_pool_bytes, kv_spill_dir, kv_spill_bytes) — appended, struct_size-
* gated; no new symbols.
* v8: mlxforge_engine_opts2.prefill_chunk (chunked-prefill interleaving,
* default-on) — appended, struct_size-gated; no new symbols.
* v9: mlxforge_engine_opts2.skinny_mm (multi-row GEMV decode kernels,
* default-on) — appended, struct_size-gated; no new symbols. */
- #define MLXFORGE_ABI_VERSION 8
+ #define MLXFORGE_ABI_VERSION 9

typedef struct mlxforge_engine mlxforge_engine;
typedef struct mlxforge_request mlxforge_request;
@@ -143,7 +145,14 @@ mlxforge_engine* mlxforge_engine_create(const char* model_spec,
* this many tokens per worker iteration with a decode step in between, so
* in-flight requests keep streaming during long or queued prefills. On by
* default (256). 0 keeps the default; < 0 disables it (monolithic prefill
- * per admission, the pre-v8 behavior). */
+ * per admission, the pre-v8 behavior).
*
* skinny_mm (v9+) toggles the multi-row GEMV decode kernels: dense fp16
* matmuls of the batched-decode shape (2-16 rows, one token each) bypass
* MLX's tiled GEMM, which runs at a fraction of GEMV bandwidth there
* (ml-explore/mlx#3661) — roughly 2x per-row decode throughput at small
* batch sizes. On by default. 0 keeps the default; < 0 disables (stock
* matmul); > 0 enables. */
typedef struct {
size_t struct_size; /* caller sets sizeof(mlxforge_engine_opts2) */
int max_waiting; /* max queued requests; <= 0 => default (256) */
@@ -158,6 +167,9 @@ typedef struct {
/* ---- v8 ---- */
int prefill_chunk; /* tokens per interleaved prefill chunk; 0 => default
(256); < 0 => monolithic prefill (off) */
/* ---- v9 ---- */
int skinny_mm; /* multi-row GEMV decode kernels; 0 => default (on);
< 0 => off (stock matmul); > 0 => on */
} mlxforge_engine_opts2;
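
Review note: the `struct_size` gating and the tri-state `skinny_mm` convention (0 keeps the default, < 0 is off, > 0 is on) can be sketched as below. `opts_sketch` is a hypothetical reduced stand-in for `mlxforge_engine_opts2`, and this two-argument `covered` is an assumption about what the helper in `src/capi/mlxforge.cpp` checks, not a copy of it.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical, reduced stand-in for the options struct (field order matters:
// newer fields are appended so older callers simply pass a smaller struct_size).
struct opts_sketch {
  std::size_t struct_size;  // caller sets sizeof(its own view of the struct)
  int max_waiting;
  int prefill_chunk;        // v8
  int skinny_mm;            // v9
};

// True when the caller's struct is large enough to contain the field that
// ends at `field_end` (assumed meaning of the PR's covered() helper).
bool covered(const opts_sketch* o, const void* field_end) {
  const char* base = reinterpret_cast<const char*>(o);
  return static_cast<const char*>(field_end) <= base + o->struct_size;
}

// Tri-state decode: 0 (or a pre-v9 caller) keeps the default, otherwise the
// sign selects on/off.
bool resolve_skinny_mm(const opts_sketch* o, bool default_on = true) {
  if (o == nullptr) return default_on;
  if (!covered(o, &o->skinny_mm + 1)) return default_on;
  if (o->skinny_mm == 0) return default_on;
  return o->skinny_mm > 0;
}

// Usage: a zero-initialized v9 struct and an old, shorter struct both get the default.
void abi_example() {
  opts_sketch o{};
  o.struct_size = sizeof(opts_sketch);
  assert(resolve_skinny_mm(&o));
  o.struct_size = offsetof(opts_sketch, skinny_mm);  // caller built against v8
  assert(resolve_skinny_mm(&o));
}
```

Reserving 0 for "keep the default" is what makes zero-initialized structs safe for a default-on feature. The Node binding follows the same convention by mapping `false` to -1 rather than 0.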

/* Create an engine with extended options (v6+). Identical contract to
8 changes: 7 additions & 1 deletion src/model/decoder_model.cpp
@@ -8,6 +8,7 @@

#include "core/logging.h"
#include "model/sdpa.h"
#include "model/skinny_matmul.h"

#include "mlx/fast.h"
#include "mlx/ops.h"
@@ -121,7 +122,12 @@ mx::array DecoderModel::linear(const mx::array& x, const std::string& weight_key
w_.at(base + ".biases"), /*transpose=*/true, qp.group_size,
qp.bits);
}
- return mx::matmul(x, mx::transpose(w_.at(weight_key))); // weight is (out, in)
+ const mx::array& w = w_.at(weight_key);
+ // The batched-decode shape (B in [2, 16], L == 1) takes the multi-row GEMV
+ // kernels when enabled — MLX's tiled GEMM runs at a fraction of GEMV
+ // bandwidth there (mlx#3661). See set_skinny_mm().
+ if (skinny_mm_ && skinny_matmul_applies(x, w)) return skinny_matmul(x, w);
+ return mx::matmul(x, mx::transpose(w)); // weight is (out, in)
}

mx::array DecoderModel::embed(const mx::array& tokens) const {
8 changes: 8 additions & 0 deletions src/model/decoder_model.h
@@ -99,6 +99,13 @@ class DecoderModel {
// 0 where a key is causally valid and not left-padding, -inf otherwise.
mx::array batch_mask(int prev_idx, int n_query, const mx::array& left_padding) const;

// Enable the multi-row GEMV decode kernels (model/skinny_matmul): linear()
// routes dense fp16 matmuls of the batched-decode shape (B in [2, 16],
// L == 1) through them instead of MLX's tiled GEMM. Off by default so raw
// models (single-stream tools, golden references) keep stock numerics; the
// Worker turns it on per EngineConfig::skinny_mm after loading.
void set_skinny_mm(bool on) { skinny_mm_ = on; }

protected:
// Abstract base: construct a concrete subclass (or via create_model()).
DecoderModel(ModelConfig config, Weights weights);
@@ -131,6 +138,7 @@ class DecoderModel {
ModelConfig cfg_;
Weights w_;
mx::array rope_freqs_;
bool skinny_mm_ = false; // see set_skinny_mm()
};

} // namespace mlxforge
114 changes: 114 additions & 0 deletions src/model/skinny_matmul.cpp
@@ -0,0 +1,114 @@
#include "model/skinny_matmul.h"

#include <optional>
#include <vector>

#include "mlx/fast.h"
#include "mlx/ops.h"

namespace mlxforge {

namespace {

// One simdgroup per output column o: 32 lanes split D with half4 loads, so the
// weight row is read once (coalesced) for all M activation rows; per-lane fp32
// accumulators reduce via simd_sum. M is a compile-time template arg so the
// accumulators stay in registers. Best for M in [2, 4].
constexpr const char* kSourceOneCol = R"(
const uint lane = thread_position_in_grid.x; // 0..31
const uint o = thread_position_in_grid.y; // output column (w row)
const int D = w_shape[1];
const int O = w_shape[0];

const device half4* wrow = (const device half4*)(w + (size_t)o * D);
float acc[M];
for (int m = 0; m < M; ++m) acc[m] = 0.0f;

for (int k = lane; k < D / 4; k += 32) {
half4 wv = wrow[k];
for (int m = 0; m < M; ++m) {
half4 xv = ((const device half4*)(x + (size_t)m * D))[k];
acc[m] += (float)wv.x * (float)xv.x + (float)wv.y * (float)xv.y +
(float)wv.z * (float)xv.z + (float)wv.w * (float)xv.w;
}
}
for (int m = 0; m < M; ++m) {
float r = simd_sum(acc[m]);
if (lane == 0) y[(size_t)m * O + o] = (half)r;
}
)";

// Two output columns per simdgroup: each activation load feeds two weight
// rows, halving the redundant x traffic that degrades the one-column variant
// past M ~ 8. Best for M in [5, 16]; the crossover vs the tiled GEMM is past
// 16 (66 GB/s at M=16 vs the GEMM's flat ~56 on M1 Pro).
constexpr const char* kSourceTwoCol = R"(
const uint lane = thread_position_in_grid.x; // 0..31
const uint pair = thread_position_in_grid.y; // output column pair
const int D = w_shape[1];
const int O = w_shape[0];
const uint o0 = pair * 2;
const uint o1 = o0 + 1;
const bool has1 = o1 < (uint)O;

const device half4* w0 = (const device half4*)(w + (size_t)o0 * D);
const device half4* w1 = (const device half4*)(w + (size_t)(has1 ? o1 : o0) * D);
float acc0[M];
float acc1[M];
for (int m = 0; m < M; ++m) { acc0[m] = 0.0f; acc1[m] = 0.0f; }

for (int k = lane; k < D / 4; k += 32) {
half4 wv0 = w0[k];
half4 wv1 = w1[k];
for (int m = 0; m < M; ++m) {
half4 xv = ((const device half4*)(x + (size_t)m * D))[k];
acc0[m] += (float)wv0.x * (float)xv.x + (float)wv0.y * (float)xv.y +
(float)wv0.z * (float)xv.z + (float)wv0.w * (float)xv.w;
acc1[m] += (float)wv1.x * (float)xv.x + (float)wv1.y * (float)xv.y +
(float)wv1.z * (float)xv.z + (float)wv1.w * (float)xv.w;
}
}
for (int m = 0; m < M; ++m) {
float r0 = simd_sum(acc0[m]);
float r1 = simd_sum(acc1[m]);
if (lane == 0) {
y[(size_t)m * O + o0] = (half)r0;
if (has1) y[(size_t)m * O + o1] = (half)r1;
}
}
)";

constexpr int kOneColMaxM = 4;
constexpr int kMaxM = 16;

} // namespace

bool skinny_matmul_applies(const mx::array& x, const mx::array& w) {
if (x.dtype() != mx::float16 || w.dtype() != mx::float16) return false;
if (w.ndim() != 2 || w.shape()[1] % 128 != 0) return false;
const int nd = x.ndim();
if (nd == 3 && x.shape()[1] != 1) return false; // decode shape only, never prefill
if (nd != 2 && nd != 3) return false;
const int m = x.shape()[0];
return m >= 2 && m <= kMaxM && x.shape()[nd - 1] == w.shape()[1];
}

mx::array skinny_matmul(const mx::array& x, const mx::array& w) {
static const auto one_col = mx::fast::metal_kernel(
"mlxforge_gemv_multirow", {"x", "w"}, {"y"}, kSourceOneCol);
static const auto two_col = mx::fast::metal_kernel(
"mlxforge_gemv_multirow2", {"x", "w"}, {"y"}, kSourceTwoCol);

const int m = x.shape()[0];
const int o = w.shape()[0];
mx::array x2 = x.ndim() == 3 ? mx::reshape(x, {m, x.shape()[2]}) : x;
const bool narrow = m <= kOneColMaxM;
std::vector<mx::array> out = (narrow ? one_col : two_col)(
{x2, w}, {mx::Shape{m, o}}, {mx::float16},
/*grid=*/{32, narrow ? o : (o + 1) / 2, 1}, /*threadgroup=*/{32, 1, 1},
/*template_args=*/{{"M", m}},
/*init_value=*/std::nullopt, /*verbose=*/false, {});
return x.ndim() == 3 ? mx::reshape(out[0], {m, 1, o}) : out[0];
}

} // namespace mlxforge
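
Review note: for readers who do not work in Metal, the dataflow of the one-column kernel can be mirrored on the CPU as below. This is a hypothetical reference, not code from the PR. It uses `float` in place of `half` and a sequential inner loop in place of the 32-lane split plus `simd_sum`, so its accumulation order differs from both the kernel and `mx::matmul`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference: y = x @ w^T with x (M, D), w (O, D), y (M, O),
// all row-major. Each weight element is loaded once and reused for all M
// activation rows, with one accumulator per row.
std::vector<float> multirow_gemv_ref(const std::vector<float>& x,
                                     const std::vector<float>& w,
                                     std::size_t M, std::size_t D, std::size_t O) {
  assert(x.size() == M * D && w.size() == O * D);
  std::vector<float> y(M * O, 0.0f);
  std::vector<float> acc(M);  // stands in for the per-lane register accumulators
  for (std::size_t o = 0; o < O; ++o) {  // one simdgroup per output column
    const float* wrow = w.data() + o * D;
    for (std::size_t m = 0; m < M; ++m) acc[m] = 0.0f;
    for (std::size_t k = 0; k < D; ++k) {
      const float wv = wrow[k];          // weight read once ...
      for (std::size_t m = 0; m < M; ++m) {
        acc[m] += wv * x[m * D + k];     // ... and reused for every batch row
      }
    }
    for (std::size_t m = 0; m < M; ++m) y[m * O + o] = acc[m];
  }
  return y;
}
```

The loop order is the point of the sketch: the weight matrix, which dominates memory traffic in decode, is streamed exactly once per step regardless of M, while the much smaller activation block is re-read per output column.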
37 changes: 37 additions & 0 deletions src/model/skinny_matmul.h
@@ -0,0 +1,37 @@
// Multi-row GEMV kernels for the batched-decode regime.
//
// MLX's Metal matmul drops from ~161 GB/s (GEMV, M=1) to ~56 GB/s (tiled GEMM)
// the moment M reaches 2, and every M in [2, 64] pays the same tile cost
// (ml-explore/mlx#3661) — exactly the continuous-batching decode shape, where
// each per-token step is weight-bandwidth-bound. These custom
// fast::metal_kernel kernels read each weight row once per simdgroup and keep
// one fp32 accumulator per activation row in registers, recovering GEMV-class
// bandwidth: a one-column-per-simdgroup variant for M in [2, 4] (~161 GB/s)
// and a two-column variant for M in [5, 16] (the doubled arithmetic intensity
// halves the redundant activation reads that degrade larger M; ~125 GB/s at
// M=8, still ahead of the tiled GEMM at M=16). Beyond 16 the fallback GEMM is
// the right path.
//
// Accumulation is fp32 but in a different order than mx::matmul, so logits can
// differ at fp16-noise scale — the same class as the decode-vs-recompute gap.
// Gated by EngineConfig::skinny_mm (default on) via DecoderModel::set_skinny_mm
// and token-equality tests against the stock-matmul stream.
#pragma once

#include "mlx/array.h"

namespace mx = mlx::core;

namespace mlxforge {

// True when the kernel path applies to the shapes: x is (B, 1, D) or (B, D)
// fp16 with B in [2, 16], w is a dense fp16 (O, D) weight, and D is a multiple
// of 128 (half4 loads across 32 lanes). Enablement is the caller's flag
// (DecoderModel::skinny_mm_); this checks shapes only.
bool skinny_matmul_applies(const mx::array& x, const mx::array& w);

// x @ w.T via the multi-row GEMV kernels. Preserves x's leading shape:
// (B, 1, D) -> (B, 1, O), (B, D) -> (B, O).
mx::array skinny_matmul(const mx::array& x, const mx::array& w);

} // namespace mlxforge
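
Review note: the shape rules and the kernel/grid choice in this PR can be summarized on plain shape vectors as below. `SkinnyKernel`, `SkinnyDispatch` and `pick_skinny_kernel` are hypothetical names for illustration. The fp16 dtype checks are omitted, and the thresholds (4 and 16) are copied from `kOneColMaxM` and `kMaxM`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

enum class SkinnyKernel { None, OneCol, TwoCol };

struct SkinnyDispatch {
  SkinnyKernel kernel;
  int grid_x;  // threads along x: one 32-lane simdgroup
  int grid_y;  // simdgroups along y: output columns, or column pairs
};

// Hypothetical sketch of skinny_matmul_applies() plus the variant choice in
// skinny_matmul(), with dtype checks left out.
SkinnyDispatch pick_skinny_kernel(const std::vector<int>& x_shape,
                                  const std::vector<int>& w_shape) {
  const SkinnyDispatch none{SkinnyKernel::None, 0, 0};
  if (w_shape.size() != 2 || w_shape[1] % 128 != 0) return none;  // half4 x 32 lanes
  const std::size_t nd = x_shape.size();
  if (nd != 2 && nd != 3) return none;
  if (nd == 3 && x_shape[1] != 1) return none;  // L > 1 is prefill, not decode
  const int m = x_shape[0];
  if (m < 2 || m > 16 || x_shape[nd - 1] != w_shape[1]) return none;
  const int o = w_shape[0];
  if (m <= 4) return {SkinnyKernel::OneCol, 32, o};
  return {SkinnyKernel::TwoCol, 32, (o + 1) / 2};  // odd O: last pair is half-used
}

// Usage: a batch of 4 decode rows takes the one-column kernel, 8 rows the two-column one.
void dispatch_example() {
  assert(pick_skinny_kernel({4, 1, 4096}, {11008, 4096}).kernel == SkinnyKernel::OneCol);
  assert(pick_skinny_kernel({8, 1, 4096}, {11008, 4096}).kernel == SkinnyKernel::TwoCol);
}
```

Any shape that returns `None` here corresponds to the case where `linear()` falls back to `mx::matmul`, including B=1 (already on the GEMV path) and B > 16.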
2 changes: 1 addition & 1 deletion src/runtime/engine.cpp
@@ -186,7 +186,7 @@ Engine::Engine(EngineConfig cfg, Loaded loaded)
// cfg_ is initialized above, so the KV-quant validation sees the model.
worker_(make_factory(std::move(loaded.dir), loaded.is_gguf), &scheduler_, &tok_,
validate_kv_quant(cfg, cfg_), validate_prefix_cache(cfg, cfg_, model_name_),
- validate_prefill_chunk(cfg)) {
+ validate_prefill_chunk(cfg), cfg.skinny_mm) {
// Configure the max waiting requests for the batch scheduler.
scheduler_.set_max_waiting(cfg.max_waiting);

6 changes: 6 additions & 0 deletions src/runtime/engine.h
@@ -51,6 +51,12 @@ struct EngineConfig {
// benchmarked sweet spot); 0 = monolithic prefill per admission. Negative
// values are rejected at construction.
int prefill_chunk = 256;
// Multi-row GEMV decode kernels (model/skinny_matmul): dense fp16 matmuls of
// the batched-decode shape (B in [2, 16], L == 1) bypass MLX's tiled GEMM,
// which runs at a fraction of GEMV bandwidth there (ml-explore/mlx#3661).
// On by default; logits may differ from the stock kernel at fp16-noise scale
// (fp32 accumulation in a different order), token-equality gated in tests.
bool skinny_mm = true;
};

// Per-call embedding options. The two int fields are tri-state: -1 means "use
6 changes: 4 additions & 2 deletions src/runtime/worker.cpp
@@ -59,9 +59,10 @@ bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) {
} // namespace

Worker::Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok,
- KVQuantConfig kv_quant, PrefixCacheConfig prefix, int prefill_chunk)
+ KVQuantConfig kv_quant, PrefixCacheConfig prefix, int prefill_chunk,
+ bool skinny_mm)
: factory_(std::move(factory)), sched_(scheduler), tok_(tok), kv_quant_(kv_quant),
- prefix_cfg_(prefix), prefill_chunk_(prefill_chunk) {}
+ prefix_cfg_(prefix), prefill_chunk_(prefill_chunk), skinny_mm_(skinny_mm) {}

Worker::~Worker() { stop(); }

@@ -195,6 +196,7 @@ void Worker::stop() {
void Worker::run() {
log::info("worker: loading model...");
model_ = factory_(); // load the model on this thread so its arrays live here
model_->set_skinny_mm(skinny_mm_); // batched-decode GEMV kernels (engine option)
// The prefix pool holds MLX arrays, so it lives (and dies) with this thread.
if (prefix_cfg_.enabled) {
prefix_ = std::make_unique<PrefixCache>(prefix_cfg_);
10 changes: 7 additions & 3 deletions src/runtime/worker.h
@@ -41,11 +41,14 @@ class Worker {
// `kv_quant` selects the decode cache's storage (dense fp16 by default) and
// `prefix` the prefix-cache setting; the Engine validates both against the
// model before construction. `prefill_chunk` is the interleaved-admission
- // chunk size in tokens (0 = monolithic prefill, see below). Defined
+ // chunk size in tokens (0 = monolithic prefill, see below). `skinny_mm`
+ // routes batched-decode dense matmuls through the multi-row GEMV kernels
+ // (model/skinny_matmul; applied to the model after it loads). Defined
// out-of-line (with the destructor) because the unique_ptr<VitEncoder>
// member needs the complete type for cleanup.
Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok = nullptr,
- KVQuantConfig kv_quant = {}, PrefixCacheConfig prefix = {}, int prefill_chunk = 256);
+ KVQuantConfig kv_quant = {}, PrefixCacheConfig prefix = {}, int prefill_chunk = 256,
+ bool skinny_mm = true);
~Worker();

Worker(const Worker&) = delete;
@@ -193,7 +196,8 @@ class Worker {

// Interleaved-prefill queue (worker thread only; empty when idle or feature off).
std::deque<PendingPrefill> pending_;
- int prefill_chunk_; // chunk size in tokens; 0 = monolithic admits
+ int prefill_chunk_; // chunk size in tokens; 0 = monolithic admits
+ bool skinny_mm_; // multi-row GEMV decode kernels (set on the model post-load)

std::atomic<long> decode_steps_{0};
std::atomic<bool> ready_{false};