From 8d2fea917c90aed87cc97b34703c87023b54bceb Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Wed, 10 Jun 2026 23:57:19 +0100 Subject: [PATCH] Add KV-cache quantization (kv_bits 8|4), golden-gated against mlx-lm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Engine-wide opt-in setting (default fp16): the KV cache is stored as mx::quantize triplets (packed words + per-group fp16 scales/biases, group 64), quantized at write time, in both the single-stream KVCache and the continuous-batching BatchKVCache — including the active decode batch, which no other MLX server quantizes. ~1.9x (8-bit, near-lossless) to ~3.6x (4-bit) less cache memory. - cache/kv_quant: KVQuantConfig + triplet types + the shared block-grow slice_update writer both caches use (buffer strides affect kernel accumulation order, so the storage strategy mirrors mlx-lm exactly). - model/sdpa: quantized_sdpa ports mlx_lm/models/base.py op-for-op (quantized_matmul scores/output, GQA reshape, precise softmax; MLX has no fused quantized SDPA kernel) plus the sdpa_with_cache dispatch seam all model attention call sites now use. The batched additive mask is reshaped (B,1,N,T)->(B,1,1,N,T) under GQA and OVERRIDES masked columns with finfo(fp16).min — a fully-masked left-pad row makes NaN that +(-inf) cannot cancel. - Plumbing: EngineConfig.kv_bits/kv_group_size with hard validation (vision and hybrid Qwen3.5 models are rejected — no silent fp16 fallback), server --kv-bits/MLXFORGE_KV_BITS/"kv_bits", CLI --kv-bits, KVBudget quantized accounting. - C ABI v6: mlxforge_engine_create2 + mlxforge_engine_opts2 (struct_size so future fields append without a create3); Node bindings kvBits/ kvGroupSize. - Gating: teacher-forced margin-gated greedy walks vs mlx-lm QuantizedKVCache streams (Llama + Qwen3, 8/4-bit). Exact-stream gating is unsound here: quantized matmuls are fusion-context-sensitive (~1 logit shift between lazy and materialized inputs — mlx-lm disagrees with itself across graph contexts), so tokens are asserted where the reference top-2 margin clears that noise. Batched-vs-single-stream coherence is gated exactly, plus model-free triplet-surgery and SDPA-vs-dequantized unit tests. Co-Authored-By: Claude Fable 5 --- CLAUDE.md | 15 ++ CMakeLists.txt | 2 + README.md | 5 +- apps/mlxforge.cpp | 14 +- apps/mlxforge_cli.cpp | 27 ++- bindings/node/index.d.ts | 9 + bindings/node/src/addon.cc | 9 +- cmake/abi-baseline.txt | 1 + doc/applications.md | 1 + doc/architecture.md | 37 ++- doc/embedding.md | 4 + reference/dump_ref.py | 33 +++ reference/fixtures/greedy_gaps_kvq4.npy | Bin 0 -> 208 bytes reference/fixtures/greedy_gaps_kvq8.npy | Bin 0 -> 208 bytes reference/fixtures/greedy_tokens_kvq4.npy | Bin 0 -> 208 bytes reference/fixtures/greedy_tokens_kvq8.npy | Bin 0 -> 208 bytes reference/fixtures/manifest.json | 24 ++ reference/fixtures_qwen3/greedy_gaps_kvq4.npy | Bin 0 -> 208 bytes reference/fixtures_qwen3/greedy_gaps_kvq8.npy | Bin 0 -> 208 bytes .../fixtures_qwen3/greedy_tokens_kvq4.npy | Bin 0 -> 208 bytes .../fixtures_qwen3/greedy_tokens_kvq8.npy | Bin 0 -> 208 bytes reference/fixtures_qwen3/manifest.json | 24 ++ src/cache/batch_kv_cache.cpp | 160 ++++++------- src/cache/batch_kv_cache.h | 36 ++- src/cache/kv_budget.cpp | 14 +- src/cache/kv_budget.h | 7 +- src/cache/kv_cache.cpp | 34 ++- src/cache/kv_cache.h | 41 +++- src/cache/kv_quant.cpp | 58 +++++ src/cache/kv_quant.h | 67 ++++++ src/capi/mlxforge.cpp | 35 +++ src/capi/mlxforge.h | 30 ++- src/model/decoder_model.cpp | 16 +- src/model/qwen3_5.cpp | 15 +- src/model/sdpa.cpp | 97 ++++++++ src/model/sdpa.h | 45 ++++ src/runtime/batching.cpp | 4 +- src/runtime/batching.h | 6 +- src/runtime/engine.cpp | 29 ++- src/runtime/engine.h | 8 + src/runtime/single_stream.cpp | 5 +- src/runtime/single_stream.h | 6 +- src/runtime/worker.cpp | 7 +- src/runtime/worker.h | 10 +- src/server/config.cpp | 12 +- src/server/config.h | 10 +- tests/CMakeLists.txt | 3 + tests/cache/kv_budget_test.cpp | 10 + tests/cache/kv_cache_quantized_test.cpp | 224 ++++++++++++++++++ tests/cache/kv_quant_cache_test.cpp | 164 +++++++++++++ tests/model/quantized_sdpa_test.cpp | 107 +++++++++ 51 files changed, 1300 insertions(+), 165 deletions(-) create mode 100644 reference/fixtures/greedy_gaps_kvq4.npy create mode 100644 reference/fixtures/greedy_gaps_kvq8.npy create mode 100644 reference/fixtures/greedy_tokens_kvq4.npy create mode 100644 reference/fixtures/greedy_tokens_kvq8.npy create mode 100644 reference/fixtures_qwen3/greedy_gaps_kvq4.npy create mode 100644 reference/fixtures_qwen3/greedy_gaps_kvq8.npy create mode 100644 reference/fixtures_qwen3/greedy_tokens_kvq4.npy create mode 100644 reference/fixtures_qwen3/greedy_tokens_kvq8.npy create mode 100644 src/cache/kv_quant.cpp create mode 100644 src/cache/kv_quant.h create mode 100644 src/model/sdpa.cpp create mode 100644 src/model/sdpa.h create mode 100644 tests/cache/kv_cache_quantized_test.cpp create mode 100644 tests/cache/kv_quant_cache_test.cpp create mode 100644 tests/model/quantized_sdpa_test.cpp diff --git a/CLAUDE.md b/CLAUDE.md index efc16d1..a8d7c44 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -149,6 +149,21 @@ reference/.venv/bin/python reference/dump_ref.py math. - **Decode-with-cache vs full-recompute logits differ by fp16 accumulation order** — compare argmax / exact tokens, not raw logits at tight tolerance. +- **Quantized KV (kv_bits 8|4) mirrors mlx-lm's QuantizedKVCache** — triplet + storage quantized at write time (`cache/kv_quant`), attention via the + hand-rolled `quantized_sdpa` (`model/sdpa`, a port of mlx_lm base.py; MLX has + no fused quantized SDPA kernel). Three traps: (1) the batched additive mask + must be reshaped `(B,1,N,T)→(B,1,1,N,T)` under GQA, and masked columns must be + **overridden** with `finfo(fp16).min`, never added — a fully-masked left-pad + row makes NaN that `+(-inf)` cannot cancel; (2) quantized matmuls are + **fusion-context-sensitive** (~1 logit shift between lazy and materialized + inputs — mlx-lm disagrees with itself across graph contexts), so the golden + gates are teacher-forced and margin-gated (`greedy_gaps_kvq*.npy`), never raw + exact-stream asserts; (3) both caches deliberately share the block-grow + + `slice_update` storage writer (`update_kv_components`) — buffer strides + affect kernel accumulation order. Engine-wide setting, default off; + vision/hybrid models are rejected at engine creation (no silent fp16 + fallback). - **Qwen3-VL interleaved M-RoPE can't use `fast::rope`** (it takes a 1D offset, not 3D `(t,h,w)` positions). `Qwen3VLModel` hand-rolls a half-split rotation with a per-frequency t/h/w selector; text tokens have `t==h==w` so it reduces to diff --git a/CMakeLists.txt b/CMakeLists.txt index 40f0da4..1821229 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(mlxforge_core STATIC src/core/gguf.cpp src/core/model_source.cpp src/model/decoder_model.cpp + src/model/sdpa.cpp src/model/qwen3.cpp src/model/qwen3_moe.cpp src/model/qwen3_5.cpp @@ -66,6 +67,7 @@ add_library(mlxforge_core STATIC src/model/model_factory.cpp src/cache/kv_cache.cpp src/cache/batch_kv_cache.cpp + src/cache/kv_quant.cpp src/cache/kv_budget.cpp src/sample/sampler.cpp src/sample/json_grammar.cpp diff --git a/README.md b/README.md index 194972c..9e504ac 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,10 @@ the C-ABI / Node quickstart. caller of `eval`/`async_eval`; exactly **one `async_eval` per decode step** over the whole batch, with batch-size bucketing. - **KV cache** — single-sequence and batched (`BatchKVCache`), left-padded, grown in - 256-token blocks, with `filter` (eviction) / `merge` (admission). + 256-token blocks, with `filter` (eviction) / `merge` (admission). Optional + **KV-cache quantization** (`--kv-bits 8|4`, default fp16): mlx-lm-matching quantized + storage + attention for ~1.9×/~3.6× less cache memory — including the active + continuous-decode batch, which no other MLX server quantizes. - **Sampling as graph ops** — greedy, temperature, top-k, top-p (no host readback). - **Embeddings** — `engine.embed` runs the decoder to its final hidden states, pools (mean or last-token) and L2-normalizes. **Qwen3-Embedding** is first-class. Exposed diff --git a/apps/mlxforge.cpp b/apps/mlxforge.cpp index 85d0995..5d8378a 100644 --- a/apps/mlxforge.cpp +++ b/apps/mlxforge.cpp @@ -72,12 +72,13 @@ void print_help() { " --max-ctx max prompt length in tokens (default 8192)\n" " --max-waiting max queued requests (default 256)\n" " --kv-budget KV cache budget in bytes, 0 = unbounded (default 0)\n" + " --kv-bits KV cache quantization: 0 = fp16, 8 or 4 (default 0)\n" " -h, --help show this help and exit\n" "\n" "The model may be given via -m or the config file's \"model\" key.\n" "Config precedence (low to high): defaults < config file < env vars < CLI flags.\n" "Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, " - "MLXFORGE_KV_BUDGET."); + "MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS."); std::fflush(stdout); } @@ -126,8 +127,11 @@ int main(int argc, char** argv) { // engine boundary; the server below is just one consumer of it. std::unique_ptr engine; try { - engine = std::make_unique( - mlxforge::EngineConfig{sc.model_dir, sc.max_waiting}); + mlxforge::EngineConfig ec; + ec.model_spec = sc.model_dir; + ec.max_waiting = sc.max_waiting; + ec.kv_bits = sc.kv_bits; + engine = std::make_unique(std::move(ec)); } catch (const std::exception& e) { mlxforge::log::error("model error: {}", e.what()); return 2; @@ -163,8 +167,8 @@ int main(int argc, char** argv) { std::signal(SIGTERM, on_signal); // Info log: server has started, print bind details and config bounds. - mlxforge::log::info("mlxforge serving on http://{}:{} (max_ctx={} max_waiting={})", - sc.host, sc.port, sc.max_ctx, sc.max_waiting); + mlxforge::log::info("mlxforge serving on http://{}:{} (max_ctx={} max_waiting={} kv_bits={})", + sc.host, sc.port, sc.max_ctx, sc.max_waiting, sc.kv_bits); // Run the server's request loop (blocks until stop() is called). server.listen(sc.host, sc.port); diff --git a/apps/mlxforge_cli.cpp b/apps/mlxforge_cli.cpp index 87694d4..9db2a2c 100644 --- a/apps/mlxforge_cli.cpp +++ b/apps/mlxforge_cli.cpp @@ -6,11 +6,12 @@ // mlxforge-cli dump-weights // - Loads a model's weights from the supplied directory, prints key/shape/dtype for each tensor, // asserts that all tensors are fp16, and reports the peak resident memory used. -// mlxforge-cli generate [max_tokens] [--logprobs [N]] +// mlxforge-cli generate [max_tokens] [--logprobs [N]] [--kv-bits N] // - Runs greedy single-stream generation: pre-fills the prompt (as raw text using the chat template // or as a pre-tokenized .npy of ids), then streams the detokenized text to stdout until EOS or // max_tokens. With --logprobs [N], each emitted token's log-prob (and its N most-likely // alternatives) is printed to stderr after generation; stdout stays the generated text. +// --kv-bits 8|4 stores the KV cache quantized (the manual harness for the quantized path). // mlxforge-cli bench [max_tokens] [runs] // - Repeatable throughput benchmark over a fixed prompt: one discarded warmup run, then `runs` // timed runs (defaults: max_tokens=128, runs=3) reporting time-to-first-token and decode tok/s. @@ -181,7 +182,7 @@ std::string show_token(const std::string& s) { // `top_logprobs` mirrors the engine knob: -1 = off; 0 = each token's own log-prob; // N > 0 = also its N most-likely alternatives (printed to stderr after generation). int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens, - int top_logprobs = -1) { + int top_logprobs = -1, int kv_bits = 0) { // Resolve and load the model (GGUF file or safetensors dir; downloads if needed) LoadedModel lm = load_for_inference(spec); mlxforge::DecoderModel& model = *lm.model; @@ -212,7 +213,7 @@ int run_generate(const std::string& spec, const std::string& prompt_arg, int max std::string piece = detok.add(id); std::fwrite(piece.data(), 1, piece.size(), stdout); std::fflush(stdout); - }, top_logprobs); + }, top_logprobs, mlxforge::KVQuantConfig{kv_bits, 64}); // Output any final detokenized tail remaining in the streaming detokenizer std::string tail = detok.finish(); @@ -370,13 +371,15 @@ int main(int argc, char** argv) { if (argc < 4) { std::fprintf(stderr, "usage: mlxforge-cli generate [max_tokens] " - "[--logprobs [N]]\n"); + "[--logprobs [N]] [--kv-bits N]\n"); return 2; } - // Positional [max_tokens] (default 64) plus an optional --logprobs [N] flag (N - // alternatives, default 0 = the chosen token's own log-prob only). + // Positional [max_tokens] (default 64), an optional --logprobs [N] flag (N + // alternatives, default 0 = the chosen token's own log-prob only), and an + // optional --kv-bits N (0 = fp16 cache, 8 or 4 = quantized). int max_tokens = 64; int top_logprobs = -1; + int kv_bits = 0; for (int i = 4; i < argc; ++i) { const std::string a = argv[i]; if (a == "--logprobs") { @@ -384,11 +387,21 @@ int main(int argc, char** argv) { top_logprobs = std::stoi(argv[++i]); else top_logprobs = 0; + } else if (a == "--kv-bits") { + if (i + 1 >= argc) { + std::fprintf(stderr, "error: --kv-bits needs a value (0, 4, or 8)\n"); + return 2; + } + kv_bits = std::stoi(argv[++i]); + if (kv_bits != 0 && kv_bits != 4 && kv_bits != 8) { + std::fprintf(stderr, "error: --kv-bits must be 0, 4, or 8\n"); + return 2; + } } else { max_tokens = std::stoi(a); } } - return run_generate(argv[2], argv[3], max_tokens, top_logprobs); + return run_generate(argv[2], argv[3], max_tokens, top_logprobs, kv_bits); } if (cmd == "image") { // Vision-language generation: describe / answer about an image. diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts index f35567d..7631eea 100644 --- a/bindings/node/index.d.ts +++ b/bindings/node/index.d.ts @@ -3,6 +3,15 @@ export interface EngineOptions { /** Max queued requests before submit is rejected (default 256). */ maxWaiting?: number; + /** + * KV-cache quantization bits (engine-wide): 0 (default) keeps the cache + * dense fp16; 8 is near-lossless at ~1.9x less cache memory; 4 is ~3.6x. + * Unsupported models (vision-language, hybrid Qwen3.5) fail engine creation + * rather than silently falling back to fp16. + */ + kvBits?: 0 | 4 | 8; + /** Quantization group size (default 64; must divide the model's head_dim). */ + kvGroupSize?: number; } export interface SamplingOptions { diff --git a/bindings/node/src/addon.cc b/bindings/node/src/addon.cc index f127e67..af3a3a2 100644 --- a/bindings/node/src/addon.cc +++ b/bindings/node/src/addon.cc @@ -251,15 +251,20 @@ class EngineWrap : public Napi::ObjectWrap { throw Napi::TypeError::New(env, "new Engine(spec, opts?) requires a model spec string"); std::string spec = info[0].As().Utf8Value(); - mlxforge_engine_opts opts = {}; + mlxforge_engine_opts2 opts = {}; + opts.struct_size = sizeof(opts); if (info.Length() >= 2 && info[1].IsObject()) { Napi::Object o = info[1].As(); if (o.Has("maxWaiting") && o.Get("maxWaiting").IsNumber()) opts.max_waiting = o.Get("maxWaiting").As().Int32Value(); + if (o.Has("kvBits") && o.Get("kvBits").IsNumber()) + opts.kv_bits = o.Get("kvBits").As().Int32Value(); + if (o.Has("kvGroupSize") && o.Get("kvGroupSize").IsNumber()) + opts.kv_group_size = o.Get("kvGroupSize").As().Int32Value(); } char* err = nullptr; - eng_ = mlxforge_engine_create(spec.c_str(), &opts, &err); + eng_ = mlxforge_engine_create2(spec.c_str(), &opts, &err); if (!eng_) { std::string msg = err ? err : "failed to create engine"; mlxforge_string_free(err); diff --git a/cmake/abi-baseline.txt b/cmake/abi-baseline.txt index cc2cf02..d65075a 100644 --- a/cmake/abi-baseline.txt +++ b/cmake/abi-baseline.txt @@ -2,6 +2,7 @@ mlxforge_abi_version mlxforge_embed mlxforge_embed_ex mlxforge_engine_create +mlxforge_engine_create2 mlxforge_engine_free mlxforge_engine_model_name mlxforge_engine_ready diff --git a/doc/applications.md b/doc/applications.md index 657b232..9c43c00 100644 --- a/doc/applications.md +++ b/doc/applications.md @@ -57,6 +57,7 @@ with environment-variable fallbacks (`server/config`): | `--max-ctx` | `MLXFORGE_MAX_CTX` | `8192` | Reject prompts longer than this → `400`. | | `--max-waiting` | `MLXFORGE_MAX_WAITING` | `256` | Bounded waiting queue → `429` on overflow. | | `--kv-budget` | `MLXFORGE_KV_BUDGET` | `0` (unbounded) | KV-memory admission budget in bytes. | +| `--kv-bits` | `MLXFORGE_KV_BITS` | `0` (fp16) | KV-cache quantization: `8` (~1.9× less cache memory, near-lossless) or `4` (~3.6×). Unsupported models (vision, hybrid Qwen3.5) fail startup rather than silently falling back. | ### Logging diff --git a/doc/architecture.md b/doc/architecture.md index e5f81e5..bb0aa6e 100644 --- a/doc/architecture.md +++ b/doc/architecture.md @@ -171,6 +171,39 @@ of `B` sequences each reaching `max_len + max_new` tokens is refused/queued if i would exceed the configured `--kv-budget`. Combined with the bounded waiting queue (which returns `429` on overflow), this is the real OOM defence. +## KV-cache quantization + +`--kv-bits 8|4` (engine option `kv_bits`; default 0 = dense fp16) stores the KV +cache quantized, cutting its memory ~1.9× (8-bit, near-lossless) or ~3.6× +(4-bit). The port mirrors mlx-lm exactly: + +- **Storage** (`cache/kv_quant`): each cached K/V tensor is the `mx::quantize` + triplet — packed uint32 words plus per-group fp16 scales and biases + (group size 64) — quantized **at write time**, per position, so prefill + chunking cannot change stored values. Both `KVCache` and `BatchKVCache` hold + per-layer component vectors (1 array dense, 3 quantized); all batch surgery + (`filter`/`merge`/`pad_dummies`, block growth) runs per component unchanged. +- **Attention** (`model/sdpa`): MLX has no fused quantized SDPA kernel, so + `quantized_sdpa` ports `mlx_lm/models/base.py` op-for-op — `quantized_matmul` + for the scores and the output, GQA via a `(B, n_kv, n_rep, L, D)` reshape, + precise softmax. `sdpa_with_cache` is the dispatch seam every model attention + call site uses (dense fast-kernel vs quantized path, by the cache's config). +- **Setting scope**: engine-wide, never per-request — the batched cache's + storage is physically shared across rows. Unsupported setups (vision-language + and hybrid Qwen3.5 models, which have no quantized golden reference yet; + group sizes that don't divide `head_dim`) **fail engine creation**; there is + no silent fp16 fallback. +- **Gating**: teacher-forced greedy walks against mlx-lm `QuantizedKVCache` + streams (Llama + Qwen3, 8- and 4-bit), asserting token equality at every step + whose reference top-2 margin clears the fusion-context noise (quantized + matmuls shift ~1 logit between lazy and materialized inputs, so bit-exact + cross-implementation gating is unsound); plus an exact batched-vs-single- + stream coherence gate. + +The per-token budget figure adjusts accordingly: a K-or-V head row is +`head_dim × bits/8` packed bytes plus a fp16 scale and bias per group (D=64/g=64: +68 B at 8-bit, 36 B at 4-bit, vs 128 B fp16). + ## Module map Source lives under `src/`, grouped by responsibility. Tests mirror the module @@ -186,7 +219,9 @@ path under `tests/`. | `model/` | The transformer: `DecoderModel` base (embedding, RMSNorm, RoPE, GQA SDPA, SwiGLU, LM head; fp16 and quantized paths; single-stream and batched forward) with `LlamaModel`/`Qwen3Model`/`Qwen3MoeModel` subclasses and a `create_model` factory. | | `cache/kv_cache` | Single-sequence KV cache (the simplest prefill/decode split). | | `cache/batch_kv_cache` | Batched, left-padded KV cache: `update_and_fetch`, `filter` (evict), `merge` (admit), `pad_dummies` (bucketing). | -| `cache/kv_budget` | KV memory projection / admission gate. | +| `cache/kv_quant` | Quantized-KV shared types (`KVQuantConfig`, triplets) + the block-grow component writer both caches use. | +| `cache/kv_budget` | KV memory projection / admission gate (fp16 and quantized accounting). | +| `model/sdpa` | Cache-aware SDPA dispatch: dense fast kernel vs the hand-rolled quantized path (mlx-lm port). | | `sample/sampler` | greedy / temperature / top-k / top-p, all as MLX graph ops. | | `scheduler/request` | The `Request` struct and the bounded, blocking `TokenQueue`. | | `scheduler/scheduler` | The waiting queue + worker handoff (mutex + condition variable). | diff --git a/doc/embedding.md b/doc/embedding.md index ef140da..fca9f3d 100644 --- a/doc/embedding.md +++ b/doc/embedding.md @@ -91,6 +91,10 @@ typedef struct mlxforge_request mlxforge_request; // Create one engine; it owns the GPU worker thread and the batching scheduler. mlxforge_engine* eng = mlxforge_engine_create("mlx-community/Llama-3.2-1B-Instruct-4bit", /*opts=*/NULL, &err); +// Or with extended options (ABI v6+): a quantized KV cache cuts the dominant +// growing allocation ~1.9x (8-bit, near-lossless) or ~3.6x (4-bit). +// mlxforge_engine_opts2 opts = { .struct_size = sizeof(opts), .kv_bits = 8 }; +// eng = mlxforge_engine_create2(spec, &opts, &err); while (!mlxforge_engine_ready(eng)) { /* model still loading on the worker thread */ } // Submit many requests concurrently — they share one batched engine. This is the diff --git a/reference/dump_ref.py b/reference/dump_ref.py index 287c37e..543f9b6 100644 --- a/reference/dump_ref.py +++ b/reference/dump_ref.py @@ -192,6 +192,33 @@ def _dump_greedy(model, save, prompt_ids): save("greedy_tokens", np.array(greedy, dtype=np.int32)) +def _dump_greedy_quantized(model, save, prompt_ids, bits, group_size=64): + """Greedy continuation with mlx-lm's QuantizedKVCache (prefill + cached + decode). The oracle for the C++ quantized-KV path, which ports the same + triplet storage and quantized_matmul SDPA. + + Besides the tokens, the per-step top-2 logit margin is dumped: 4-bit + quantized matmuls at head_dim 128 are fusion-context-sensitive (the same + mlx-lm function on the same values shifts by ~1 logit depending on whether + its inputs are materialized or lazy), so the C++ gate for such combinations + is teacher-forced and asserts token equality only where the margin exceeds + that noise — exact gating would be asserting kernel fusion, not math.""" + from mlx_lm.models.cache import QuantizedKVCache + + cache = [QuantizedKVCache(group_size=group_size, bits=bits) for _ in model.layers] + logits = model(mx.array(prompt_ids, dtype=mx.int32)[None], cache=cache)[:, -1, :] + greedy, gaps = [], [] + for i in range(GREEDY_MAX_NEW): + top2 = mx.sort(logits.astype(mx.float32), axis=-1)[0, -2:] + gaps.append(float(top2[1] - top2[0])) + nxt = int(mx.argmax(logits, axis=-1).item()) + greedy.append(nxt) + if i + 1 < GREEDY_MAX_NEW: + logits = model(mx.array([[nxt]], dtype=mx.int32), cache=cache)[:, -1, :] + save(f"greedy_tokens_kvq{bits}", np.array(greedy, dtype=np.int32)) + save(f"greedy_gaps_kvq{bits}", np.array(gaps, dtype=np.float32)) + + def _write_manifest(fixtures_dir, manifest, tok): manifest["eos_token_ids"] = sorted(int(x) for x in tok.eos_token_ids) with open(os.path.join(fixtures_dir, "manifest.json"), "w") as f: @@ -638,6 +665,12 @@ def save(name, arr): # --- Greedy token stream (full-recompute reference; gates the greedy decode stream) --- _dump_greedy(model, save, prompt_id_lists[0]) + + # Quantized-KV greedy streams (8- and 4-bit): gate the C++ quantized cache + + # quantized SDPA end to end against mlx-lm's QuantizedKVCache. + for bits in (8, 4): + _dump_greedy_quantized(model, save, prompt_id_lists[0], bits) + _write_manifest(FIXTURES_DIR, manifest, tok) diff --git a/reference/fixtures/greedy_gaps_kvq4.npy b/reference/fixtures/greedy_gaps_kvq4.npy new file mode 100644 index 0000000000000000000000000000000000000000..c5ce7a855bfe34d73dc8f68919a45d663c74339f GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<&zIR^#?1$zbt0U%BQ;x-_^21us>X(b?T0g2l&FlYeT eD}d}CAkF}?)q!{eP&@|6KLo`8>=_ykH~;{&VJU0? literal 0 HcmV?d00001 diff --git a/reference/fixtures/greedy_gaps_kvq8.npy b/reference/fixtures/greedy_gaps_kvq8.npy new file mode 100644 index 0000000000000000000000000000000000000000..7897587a08a8688731cdeacd760b85b75e63a74e GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<&zO$P=B273mE3w8_)JoXF;42}#8B906VcYxxjfZ`{B k;zdBV4iH}g@(X~t!Ipu63#iT>DE0-&X8~emAbYa|00$*0JOBUy literal 0 HcmV?d00001 diff --git a/reference/fixtures/greedy_tokens_kvq4.npy b/reference/fixtures/greedy_tokens_kvq4.npy new file mode 100644 index 0000000000000000000000000000000000000000..c843717afe0cb75c306ce934a30bdc604494381a GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC%^qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411>)U1_oXR1_mc428Ih#3=Gzc3=DB<3=CyJ8YTy#L40u_ LMpgsThb{*IH=HA) literal 0 HcmV?d00001 diff --git a/reference/fixtures/greedy_tokens_kvq8.npy b/reference/fixtures/greedy_tokens_kvq8.npy new file mode 100644 index 0000000000000000000000000000000000000000..0cea3926c5bed748adfc2f4e062c60b7911adc4b GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlWC%^qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411>)U1_oXR1_mc428OGQ3=9c83=FzJdVK@~Lm7~6mt|lu Z0Mf|nPN*?396QLs5akF_qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<)J9d-;1CxF<&o`GQ-5DNhD8#{&uaR&wlIUrsEl=B4A i;XrH)#KjH_4WSMU4$VMzA&_o&U^vhRWX}iUWexySeJkJq literal 0 HcmV?d00001 diff --git a/reference/fixtures_qwen3/greedy_gaps_kvq8.npy b/reference/fixtures_qwen3/greedy_gaps_kvq8.npy new file mode 100644 index 0000000000000000000000000000000000000000..523ebe07bb212d3b323dcbbffc920ea89eaa4077 GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$720EHL3bhL411<)JU@;0{}b-BcuQT literal 0 HcmV?d00001 diff --git a/reference/fixtures_qwen3/manifest.json b/reference/fixtures_qwen3/manifest.json index bafa8ae..4e0a014 100644 --- a/reference/fixtures_qwen3/manifest.json +++ b/reference/fixtures_qwen3/manifest.json @@ -122,6 +122,30 @@ 20 ], "dtype": "int32" + }, + "greedy_tokens_kvq8": { + "shape": [ + 20 + ], + "dtype": "int32" + }, + "greedy_gaps_kvq8": { + "shape": [ + 20 + ], + "dtype": "float32" + }, + "greedy_tokens_kvq4": { + "shape": [ + 20 + ], + "dtype": "int32" + }, + "greedy_gaps_kvq4": { + "shape": [ + 20 + ], + "dtype": "float32" } }, "eos_token_ids": [ diff --git a/src/cache/batch_kv_cache.cpp b/src/cache/batch_kv_cache.cpp index 7561a9d..dd420de 100644 --- a/src/cache/batch_kv_cache.cpp +++ b/src/cache/batch_kv_cache.cpp @@ -1,6 +1,7 @@ #include "cache/batch_kv_cache.h" #include +#include #include #include @@ -16,20 +17,15 @@ mx::array neg(const std::vector& v) { for (size_t i = 0; i < v.size(); ++i) out[i] = -v[i]; return mx::array(out.data(), {static_cast(out.size())}, mx::int32); } - -// Slice [start, stop) along the sequence axis (2), keeping axes 0/1/3 full. -mx::array slice_seq(const mx::array& a, int start, int stop) { - const auto& s = a.shape(); - return mx::slice(a, {0, 0, start, 0}, {s[0], s[1], stop, s[3]}); -} } // namespace -BatchKVCache::BatchKVCache(int n_layers, const std::vector& left_padding) +BatchKVCache::BatchKVCache(int n_layers, const std::vector& left_padding, KVQuantConfig qcfg) : batch_(static_cast(left_padding.size())), keys_(n_layers), values_(n_layers), conv_state_(n_layers), recur_state_(n_layers), + qcfg_(qcfg), offset_(neg(left_padding)), left_padding_(mx::array(left_padding.data(), {static_cast(left_padding.size())}, mx::int32)) {} @@ -39,8 +35,8 @@ BatchKVCache BatchKVCache::from_single_sequence( const int n_layers = static_cast(kv_per_layer.size()); BatchKVCache c(n_layers, std::vector{0}); // batch 1, no left padding for (int l = 0; l < n_layers; ++l) { - c.keys_[l] = std::move(kv_per_layer[l].first); - c.values_[l] = std::move(kv_per_layer[l].second); + c.keys_[l] = {std::move(kv_per_layer[l].first)}; + c.values_[l] = {std::move(kv_per_layer[l].second)}; } c.idx_ = seq; // physical sequence length (drives the attention mask) int off = decode_offset; @@ -50,49 +46,41 @@ BatchKVCache BatchKVCache::from_single_sequence( } int BatchKVCache::s_cap() const { - return keys_[0].has_value() ? keys_[0]->shape()[2] : 0; + return keys_[0].empty() ? 0 : keys_[0][0].shape()[2]; } std::pair BatchKVCache::update_and_fetch(int layer, const mx::array& k, const mx::array& v) { - const int prev = idx_; - const int L = k.shape()[2]; - const int end = prev + L; - const int B = k.shape()[0]; - const int H = k.shape()[1]; - const int Dk = k.shape()[3]; - const int Dv = v.shape()[3]; - - const int cap = keys_[layer].has_value() ? keys_[layer]->shape()[2] : 0; - if (!keys_[layer].has_value() || end > cap) { - const int n_steps = (kStep + L - 1) / kStep; - const int add = n_steps * kStep; - mx::array new_k = mx::zeros({B, H, add, Dk}, k.dtype()); - mx::array new_v = mx::zeros({B, H, add, Dv}, v.dtype()); - if (keys_[layer].has_value()) { - mx::array kk = *keys_[layer]; - mx::array vv = *values_[layer]; - // Drop any unused tail of the last block before growing. - if (prev % kStep != 0) { - kk = mx::slice(kk, {0, 0, 0, 0}, {B, H, prev, Dk}); - vv = mx::slice(vv, {0, 0, 0, 0}, {B, H, prev, Dv}); - } - keys_[layer] = mx::concatenate({kk, new_k}, /*axis=*/2); - values_[layer] = mx::concatenate({vv, new_v}, /*axis=*/2); - } else { - keys_[layer] = new_k; - values_[layer] = new_v; - } - } + if (quantized()) + throw std::logic_error("BatchKVCache::update_and_fetch: cache is quantized"); + std::vector ks = update_kv_components(keys_[layer], {k}, idx_, kStep); + std::vector vs = update_kv_components(values_[layer], {v}, idx_, kStep); + return {ks[0], vs[0]}; +} - keys_[layer] = mx::slice_update(*keys_[layer], k, {0, 0, prev, 0}, {B, H, end, Dk}); - values_[layer] = mx::slice_update(*values_[layer], v, {0, 0, prev, 0}, {B, H, end, Dv}); - return {mx::slice(*keys_[layer], {0, 0, 0, 0}, {B, H, end, Dk}), - mx::slice(*values_[layer], {0, 0, 0, 0}, {B, H, end, Dv})}; +QuantizedKVSlice BatchKVCache::update_and_fetch_quantized(int layer, const mx::array& k, + const mx::array& v) { + if (!quantized()) + throw std::logic_error("BatchKVCache::update_and_fetch_quantized: cache is dense"); + QuantizedKV ks = quantize_and_update(keys_[layer], k, qcfg_.group_size, qcfg_.bits, idx_, kStep); + QuantizedKV vs = + quantize_and_update(values_[layer], v, qcfg_.group_size, qcfg_.bits, idx_, kStep); + return {ks, vs}; } std::pair BatchKVCache::fetch(int layer) const { - return {slice_seq(*keys_[layer], 0, idx_), slice_seq(*values_[layer], 0, idx_)}; + if (quantized()) throw std::logic_error("BatchKVCache::fetch: cache is quantized"); + return {slice_seq(keys_[layer][0], 0, idx_), slice_seq(values_[layer][0], 0, idx_)}; +} + +std::pair BatchKVCache::fetch_dequantized(int layer) const { + if (!quantized()) + throw std::logic_error("BatchKVCache::fetch_dequantized: cache is dense"); + auto deq = [&](const std::vector& t) { + return mx::dequantize(slice_seq(t[0], 0, idx_), slice_seq(t[1], 0, idx_), + slice_seq(t[2], 0, idx_), qcfg_.group_size, qcfg_.bits); + }; + return {deq(keys_[layer]), deq(values_[layer])}; } void BatchKVCache::advance(int n_tokens) { @@ -103,14 +91,16 @@ void BatchKVCache::advance(int n_tokens) { void BatchKVCache::pad_dummies(int extra) { if (extra <= 0) return; + auto pad_rows = [&](std::vector& comps) { + for (auto& c : comps) { + mx::array d = mx::zeros({extra, c.shape()[1], c.shape()[2], c.shape()[3]}, c.dtype()); + c = mx::concatenate({c, d}, /*axis=*/0); + } + }; for (size_t l = 0; l < keys_.size(); ++l) { - if (!keys_[l].has_value()) continue; - const auto& k = *keys_[l]; - const auto& v = *values_[l]; - mx::array dk = mx::zeros({extra, k.shape()[1], k.shape()[2], k.shape()[3]}, k.dtype()); - mx::array dv = mx::zeros({extra, v.shape()[1], v.shape()[2], v.shape()[3]}, v.dtype()); - keys_[l] = mx::concatenate({k, dk}, /*axis=*/0); - values_[l] = mx::concatenate({v, dv}, /*axis=*/0); + if (keys_[l].empty()) continue; + pad_rows(keys_[l]); + pad_rows(values_[l]); } // Linear-attention layers: append `extra` zero state rows (fixed-size axes). for (size_t l = 0; l < conv_state_.size(); ++l) { @@ -133,9 +123,9 @@ void BatchKVCache::pad_dummies(int extra) { void BatchKVCache::eval_state() { std::vector state = {offset_, left_padding_}; for (auto& k : keys_) - if (k.has_value()) state.push_back(*k); + for (auto& c : k) state.push_back(c); for (auto& v : values_) - if (v.has_value()) state.push_back(*v); + for (auto& c : v) state.push_back(c); for (auto& c : conv_state_) if (c.has_value()) state.push_back(*c); for (auto& r : recur_state_) @@ -154,10 +144,8 @@ int scalar_int(const mx::array& a) { void BatchKVCache::filter(const std::vector& keep) { mx::array idxs(keep.data(), {static_cast(keep.size())}, mx::int32); for (int l = 0; l < static_cast(keys_.size()); ++l) { - if (keys_[l].has_value()) { - keys_[l] = mx::take(*keys_[l], idxs, /*axis=*/0); - values_[l] = mx::take(*values_[l], idxs, /*axis=*/0); - } + for (auto& c : keys_[l]) c = mx::take(c, idxs, /*axis=*/0); + for (auto& c : values_[l]) c = mx::take(c, idxs, /*axis=*/0); // Linear state is fixed-size; eviction is a plain batch-axis gather. if (conv_state_[l].has_value()) { conv_state_[l] = mx::take(*conv_state_[l], idxs, /*axis=*/0); @@ -172,10 +160,8 @@ void BatchKVCache::filter(const std::vector& keep) { const int min_left_pad = scalar_int(mx::min(left_padding_, /*keepdims=*/false)); if (min_left_pad > 0) { for (int l = 0; l < static_cast(keys_.size()); ++l) { - if (keys_[l].has_value()) { - keys_[l] = slice_seq(*keys_[l], min_left_pad, keys_[l]->shape()[2]); - values_[l] = slice_seq(*values_[l], min_left_pad, values_[l]->shape()[2]); - } + for (auto& c : keys_[l]) c = slice_seq(c, min_left_pad, c.shape()[2]); + for (auto& c : values_[l]) c = slice_seq(c, min_left_pad, c.shape()[2]); } idx_ -= min_left_pad; left_padding_ = mx::subtract(left_padding_, mx::array(min_left_pad, mx::int32)); @@ -184,6 +170,8 @@ void BatchKVCache::filter(const std::vector& keep) { } void BatchKVCache::merge(BatchKVCache& other) { + if (qcfg_ != other.qcfg_) + throw std::logic_error("BatchKVCache::merge: KV quantization configs differ"); if (other.batch_ == 0) return; if (batch_ == 0) { keys_ = other.keys_; @@ -203,31 +191,39 @@ void BatchKVCache::merge(BatchKVCache& other) { const int max_size = std::max(l1, l2); // Pad one cache's layer `l` so it is right-justified at max_idx and sized - // max_size on the sequence axis. Returns the padded K and V. - auto pad_layer = [&](BatchKVCache& c, int l) -> std::pair { - mx::array k = *c.keys_[l]; - mx::array v = *c.values_[l]; - const int len = k.shape()[2]; - const int left = max_idx - c.idx_; - int right = max_size - len - left; - if (right < 0) { // trim the unused tail - k = slice_seq(k, 0, len + right); - v = slice_seq(v, 0, len + right); - right = 0; - } - if (left != 0 || right != 0) { - std::vector> pw = {{0, 0}, {0, 0}, {left, right}, {0, 0}}; - k = mx::pad(k, pw); - v = mx::pad(v, pw); + // max_size on the sequence axis. Returns the padded components. + auto pad_layer = [&](BatchKVCache& c, std::vector& comps) { + std::vector out; + out.reserve(comps.size()); + for (const auto& a : comps) { + mx::array x = a; + const int len = x.shape()[2]; + const int left = max_idx - c.idx_; + int right = max_size - len - left; + if (right < 0) { // trim the unused tail + x = slice_seq(x, 0, len + right); + right = 0; + } + if (left != 0 || right != 0) { + std::vector> pw = {{0, 0}, {0, 0}, {left, right}, {0, 0}}; + x = mx::pad(x, pw); + } + out.push_back(std::move(x)); } - return {k, v}; + return out; }; for (int l = 0; l < static_cast(keys_.size()); ++l) { - auto a = pad_layer(*this, l); - auto b = pad_layer(other, l); - keys_[l] = mx::concatenate({a.first, b.first}, /*axis=*/0); - values_[l] = mx::concatenate({a.second, b.second}, /*axis=*/0); + std::vector ka = pad_layer(*this, keys_[l]); + std::vector kb = pad_layer(other, other.keys_[l]); + std::vector va = pad_layer(*this, values_[l]); + std::vector vb = pad_layer(other, other.values_[l]); + keys_[l].clear(); + values_[l].clear(); + for (size_t i = 0; i < ka.size(); ++i) { + keys_[l].push_back(mx::concatenate({ka[i], kb[i]}, /*axis=*/0)); + values_[l].push_back(mx::concatenate({va[i], vb[i]}, /*axis=*/0)); + } // Linear state: fixed-size, so admission is a plain batch-axis concatenate // (no sequence-length right-justification — the recurrent state already // summarizes each row's history regardless of length). diff --git a/src/cache/batch_kv_cache.h b/src/cache/batch_kv_cache.h index 619f3ba..3a0195a 100644 --- a/src/cache/batch_kv_cache.h +++ b/src/cache/batch_kv_cache.h @@ -12,12 +12,20 @@ // update_and_fetch(layer, k, v) writes one layer's slice; advance(n) bumps the // shared bookkeeping once per token sweep. filter()/merge() do the batch-axis // surgery (eviction/admission) the scheduler needs. +// +// With a KVQuantConfig (kv_bits > 0) each layer's K/V is instead the +// mx::quantize triplet (cache/kv_quant.h), quantized at write time exactly like +// mlx_lm's QuantizedKVCache. Every axis-0/2 operation (grow, write, filter, +// merge, pad) is last-axis-agnostic, so the same surgery runs per component; +// zero-filled pad regions dequantize to exactly 0 and are masked anyway. #pragma once #include #include #include +#include "cache/kv_quant.h" + #include "mlx/array.h" namespace mlxforge { @@ -29,7 +37,7 @@ class BatchKVCache { static constexpr int kStep = 256; // One cache for all layers; `left_padding[i]` is the pad count of batch row i. - BatchKVCache(int n_layers, const std::vector& left_padding); + BatchKVCache(int n_layers, const std::vector& left_padding, KVQuantConfig qcfg = {}); // Build a batch-1 cache from one already-prefilled sequence's per-layer K/V // (each (1, n_kv_heads, seq, head_dim)). `decode_offset` seeds the per-row RoPE @@ -40,6 +48,7 @@ class BatchKVCache { // decode (whose mask works in physical-slot space while RoPE uses offset) is // numerically correct unchanged. left_padding is 0 (a fresh single-row prefill // has no padding). Used to admit a vision prompt into the decode pool (merge()). + // Always dense (the vision path is rejected when KV quantization is on). static BatchKVCache from_single_sequence( std::vector> kv_per_layer, int seq, int decode_offset); @@ -50,12 +59,21 @@ class BatchKVCache { const mx::array& offset() const { return offset_; } // (B,) int32 const mx::array& left_padding() const { return left_padding_; } // (B,) int32 + bool quantized() const { return qcfg_.enabled(); } + const KVQuantConfig& quant_config() const { return qcfg_; } + // Append layer `layer`'s K/V (each (B, n_kv_heads, L, head_dim)) at the // current write position, growing capacity in `step` blocks if needed, and - // return the populated slice [..., :idx+L, :] to attend over. + // return the populated slice [..., :idx+L, :] to attend over. Dense caches + // only (throws on a quantized cache). std::pair update_and_fetch(int layer, const mx::array& k, const mx::array& v); + // Quantized counterpart: quantize the incoming K/V (mx::quantize, per + // position), write each triplet component at the current position with the + // same block growth, and return the populated triplets for quantized_sdpa. + QuantizedKVSlice update_and_fetch_quantized(int layer, const mx::array& k, const mx::array& v); + // Advance the shared offset/idx by `n_tokens` (call once after all layers). void advance(int n_tokens); @@ -64,8 +82,13 @@ class BatchKVCache { void eval_state(); // Populated K/V slice [..., :idx, :] for a layer (for inspection/tests). + // Dense caches only. std::pair fetch(int layer) const; + // Populated K/V slice of a quantized layer, dequantized back to fp16 (for + // inspection/tests — the model attends over the triplets directly). + std::pair fetch_dequantized(int layer) const; + // Eviction: keep only the given batch rows (take on axis 0) across every // layer's K/V plus offset/left_padding, then shift off any common left // padding. `keep` indexes the current batch rows. @@ -73,7 +96,8 @@ class BatchKVCache { // Admission: pad this cache and `other` to a common S_cap (right-justified by // write index) and concatenate on the batch axis. Used to admit a freshly - // prefilled batch into the decode cache. + // prefilled batch into the decode cache. Both caches must share a + // KVQuantConfig. void merge(BatchKVCache& other); // Append `extra` masked dummy rows on the batch axis to reach a decode @@ -99,10 +123,12 @@ class BatchKVCache { private: int batch_; int idx_ = 0; - std::vector> keys_; - std::vector> values_; + // Per layer: empty until written; 1 component dense, 3 quantized. + std::vector> keys_; + std::vector> values_; std::vector> conv_state_; std::vector> recur_state_; + KVQuantConfig qcfg_; mx::array offset_; // (B,) mx::array left_padding_; // (B,) }; diff --git a/src/cache/kv_budget.cpp b/src/cache/kv_budget.cpp index 905d524..bdde82e 100644 --- a/src/cache/kv_budget.cpp +++ b/src/cache/kv_budget.cpp @@ -37,11 +37,19 @@ std::size_t compute_linear_state_bytes(const ModelConfig& cfg) { const std::size_t recur_bytes = value_dim * cfg.linear_key_head_dim * kFp32Bytes; return linear_layers * (conv_bytes + recur_bytes); } +// One K-or-V row of head_dim values: fp16 when dense; packed bits plus the +// per-group fp16 scale and bias when quantized (e.g. D=64/g=64: 128 B fp16, +// 68 B at 8-bit, 36 B at 4-bit). +std::size_t per_head_row_bytes(const ModelConfig& cfg, KVQuantConfig q) { + const std::size_t d = static_cast(cfg.head_dim); + if (!q.enabled()) return d * kFp16Bytes; + return d * q.bits / 8 + (d / q.group_size) * 2 /*scale+bias*/ * kFp16Bytes; +} } // namespace -KVBudget::KVBudget(const ModelConfig& cfg, std::size_t budget_bytes) - : bytes_per_token_(2 /*K and V*/ * kv_layer_count(cfg) * cfg.n_kv_heads * cfg.head_dim * - kFp16Bytes), +KVBudget::KVBudget(const ModelConfig& cfg, std::size_t budget_bytes, KVQuantConfig kv_quant) + : bytes_per_token_(2 /*K and V*/ * kv_layer_count(cfg) * cfg.n_kv_heads * + per_head_row_bytes(cfg, kv_quant)), linear_state_bytes_(compute_linear_state_bytes(cfg)), budget_bytes_(budget_bytes) {} diff --git a/src/cache/kv_budget.h b/src/cache/kv_budget.h index cda6f03..a584076 100644 --- a/src/cache/kv_budget.h +++ b/src/cache/kv_budget.h @@ -9,14 +9,17 @@ #include +#include "cache/kv_quant.h" #include "core/config.h" namespace mlxforge { class KVBudget { public: - // budget_bytes == 0 means "unbounded" (admission always allowed). - KVBudget(const ModelConfig& cfg, std::size_t budget_bytes); + // budget_bytes == 0 means "unbounded" (admission always allowed). With a + // KVQuantConfig the per-token figure accounts for the quantized triplet + // storage (packed words + per-group fp16 scales and biases) instead of fp16. + KVBudget(const ModelConfig& cfg, std::size_t budget_bytes, KVQuantConfig kv_quant = {}); // KV bytes consumed by one token across all layers (the 32 KiB/token figure). std::size_t bytes_per_token() const { return bytes_per_token_; } diff --git a/src/cache/kv_cache.cpp b/src/cache/kv_cache.cpp index 64dc606..eb38a5a 100644 --- a/src/cache/kv_cache.cpp +++ b/src/cache/kv_cache.cpp @@ -1,20 +1,42 @@ #include "cache/kv_cache.h" +#include + #include "mlx/ops.h" namespace mlxforge { +std::pair KVCache::fetch(int layer) const { + if (quantized()) throw std::logic_error("KVCache::fetch: cache is quantized"); + return {keys_[layer][0], values_[layer][0]}; +} + std::pair KVCache::update_and_fetch(int layer, const mx::array& k, const mx::array& v) { - if (!keys_[layer]) { - keys_[layer] = k; - values_[layer] = v; + if (quantized()) throw std::logic_error("KVCache::update_and_fetch: cache is quantized"); + if (keys_[layer].empty()) { + keys_[layer] = {k}; + values_[layer] = {v}; } else { // Sequence axis is 2: (1, n_kv_heads, S, head_dim). - keys_[layer] = mx::concatenate({*keys_[layer], k}, /*axis=*/2); - values_[layer] = mx::concatenate({*values_[layer], v}, /*axis=*/2); + keys_[layer][0] = mx::concatenate({keys_[layer][0], k}, /*axis=*/2); + values_[layer][0] = mx::concatenate({values_[layer][0], v}, /*axis=*/2); } - return {*keys_[layer], *values_[layer]}; + return {keys_[layer][0], values_[layer][0]}; +} + +QuantizedKVSlice KVCache::update_and_fetch_quantized(int layer, const mx::array& k, + const mx::array& v) { + if (!quantized()) + throw std::logic_error("KVCache::update_and_fetch_quantized: cache is dense"); + // Block-grown storage written at offset() (the prefill/decode protocol bumps + // it via advance() once per token sweep), mirroring mlx-lm's QuantizedKVCache + // exactly — the golden gates depend on the matching buffer strides. + QuantizedKV ks = + quantize_and_update(keys_[layer], k, qcfg_.group_size, qcfg_.bits, offset_, kStep); + QuantizedKV vs = + quantize_and_update(values_[layer], v, qcfg_.group_size, qcfg_.bits, offset_, kStep); + return {ks, vs}; } } // namespace mlxforge diff --git a/src/cache/kv_cache.h b/src/cache/kv_cache.h index 4f95a08..d2b5bab 100644 --- a/src/cache/kv_cache.h +++ b/src/cache/kv_cache.h @@ -2,11 +2,20 @@ // underlies the batched cache (BatchKVCache) the server needs: prefill fills the // cache once; each decode step appends one token's K/V and attends over the // cached history. +// +// Storage is per layer a small component vector: empty until written, one array +// when dense ((1, n_kv_heads, S, head_dim) fp16), three when quantized (the +// mx::quantize triplet, see cache/kv_quant.h). The dense pair API and the +// quantized API are mutually exclusive, selected by the KVQuantConfig at +// construction; the model dispatches on quantized() (model/sdpa.h). #pragma once #include +#include #include +#include "cache/kv_quant.h" + #include "mlx/array.h" namespace mlxforge { @@ -15,8 +24,11 @@ namespace mx = mlx::core; class KVCache { public: - explicit KVCache(int n_layers) - : keys_(n_layers), values_(n_layers), conv_state_(n_layers), recur_state_(n_layers) {} + static constexpr int kStep = 256; // quantized-storage block growth size + + explicit KVCache(int n_layers, KVQuantConfig qcfg = {}) + : keys_(n_layers), values_(n_layers), conv_state_(n_layers), recur_state_(n_layers), + qcfg_(qcfg) {} // Tokens written so far (== sequence length). Used as the RoPE position // offset for the next chunk. Stays fixed across a single token's layer sweep; @@ -26,19 +38,26 @@ class KVCache { int n_layers() const { return static_cast(keys_.size()); } + bool quantized() const { return qcfg_.enabled(); } + const KVQuantConfig& quant_config() const { return qcfg_; } + // Stored K/V for a layer (each (1, n_kv_heads, offset, head_dim), no capacity - // padding). Valid only after the layer has been written. Lets a prefilled - // single sequence be handed to the batched cache for continuous-batching decode - // (BatchKVCache::from_single_sequence). - std::pair fetch(int layer) const { - return {*keys_[layer], *values_[layer]}; - } + // padding). Valid only after the layer has been written; dense caches only. + // Lets a prefilled single sequence be handed to the batched cache for + // continuous-batching decode (BatchKVCache::from_single_sequence). + std::pair fetch(int layer) const; // Append this layer's K/V (each (1, n_kv_heads, L, head_dim)) along the // sequence axis and return the full cached (keys, values) to attend over. + // Dense caches only (throws on a quantized cache). std::pair update_and_fetch(int layer, const mx::array& k, const mx::array& v); + // Quantized counterpart: quantize the incoming K/V (mx::quantize, per + // position), append each triplet component along the sequence axis, and return + // the full cached triplets for quantized_sdpa. Quantized caches only. + QuantizedKVSlice update_and_fetch_quantized(int layer, const mx::array& k, const mx::array& v); + // Gated-DeltaNet recurrent state for hybrid models (Qwen3.5): the linear // layers carry a fixed conv buffer (1, K-1, conv_dim) and a delta-rule state // (1, Hv, Dv, Dk) instead of a growing KV. Lazily set by the model on the first @@ -53,10 +72,12 @@ class KVCache { } private: - std::vector> keys_; - std::vector> values_; + // Per layer: empty until written; 1 component dense, 3 quantized. + std::vector> keys_; + std::vector> values_; std::vector> conv_state_; std::vector> recur_state_; + KVQuantConfig qcfg_; int offset_ = 0; }; diff --git a/src/cache/kv_quant.cpp b/src/cache/kv_quant.cpp new file mode 100644 index 0000000..cf6b99d --- /dev/null +++ b/src/cache/kv_quant.cpp @@ -0,0 +1,58 @@ +#include "cache/kv_quant.h" + +#include + +#include "mlx/ops.h" + +namespace mlxforge { + +mx::array slice_seq(const mx::array& a, int start, int stop) { + const auto& s = a.shape(); + return mx::slice(a, {0, 0, start, 0}, {s[0], s[1], stop, s[3]}); +} + +std::vector update_kv_components(std::vector& store, + const std::vector& in, int prev, + int step) { + const int L = in[0].shape()[2]; + const int end = prev + L; + + const int cap = store.empty() ? 0 : store[0].shape()[2]; + if (store.empty() || end > cap) { + const int n_steps = (step + L - 1) / step; + const int add = n_steps * step; + std::vector grown; + grown.reserve(in.size()); + for (size_t i = 0; i < in.size(); ++i) { + const auto& s = in[i].shape(); + mx::array fresh = mx::zeros({s[0], s[1], add, s[3]}, in[i].dtype()); + if (store.empty()) { + grown.push_back(std::move(fresh)); + } else { + mx::array cur = store[i]; + // Drop any unused tail of the last block before growing. + if (prev % step != 0) cur = slice_seq(cur, 0, prev); + grown.push_back(mx::concatenate({cur, fresh}, /*axis=*/2)); + } + } + store = std::move(grown); + } + + std::vector out; + out.reserve(in.size()); + for (size_t i = 0; i < in.size(); ++i) { + const auto& s = store[i].shape(); + store[i] = mx::slice_update(store[i], in[i], {0, 0, prev, 0}, {s[0], s[1], end, s[3]}); + out.push_back(slice_seq(store[i], 0, end)); + } + return out; +} + +QuantizedKV quantize_and_update(std::vector& store, const mx::array& x, int group_size, + int bits, int pos, int step) { + std::vector t = update_kv_components(store, mx::quantize(x, group_size, bits), pos, + step); + return {t[0], t[1], t[2]}; +} + +} // namespace mlxforge diff --git a/src/cache/kv_quant.h b/src/cache/kv_quant.h new file mode 100644 index 0000000..0904340 --- /dev/null +++ b/src/cache/kv_quant.h @@ -0,0 +1,67 @@ +// Shared types for quantized KV-cache storage. +// +// Mirrors mlx_lm/models/cache.py::QuantizedKVCache's layout: each cached K or V +// tensor is the 3-tuple mx::quantize() produces — packed uint32 words plus +// per-group fp16 scales and biases — quantized at write time, per position, so +// prefill chunking can never change the stored values. The attention math over +// this storage lives in model/sdpa.h (quantized_sdpa), ported from +// mlx_lm/models/base.py. +#pragma once + +#include + +#include "mlx/array.h" + +namespace mlxforge { + +namespace mx = mlx::core; + +// Engine-wide KV-cache quantization setting. bits == 0 keeps the cache dense +// fp16 (the default); 8 or 4 enable quantized storage. group_size must divide +// head_dim (64 and 128 both work with the default 64). +struct KVQuantConfig { + int bits = 0; + int group_size = 64; + bool enabled() const { return bits > 0; } + bool operator==(const KVQuantConfig& o) const { + return bits == o.bits && group_size == o.group_size; + } + bool operator!=(const KVQuantConfig& o) const { return !(*this == o); } +}; + +// One quantized tensor: w (..., S, head_dim*bits/32) uint32, scales/biases +// (..., S, head_dim/group_size) in the source dtype (fp16). +struct QuantizedKV { + mx::array w; + mx::array scales; + mx::array biases; +}; + +// Slice [start, stop) along the sequence axis (2), keeping axes 0/1/3 full. +// Shared by both caches' growth/eviction surgery. +mx::array slice_seq(const mx::array& a, int start, int stop); + +// The populated K/V slice of one layer, ready for quantized_sdpa. +struct QuantizedKVSlice { + QuantizedKV k; + QuantizedKV v; +}; + +// Write `in`'s components (1 dense, 3 quantized) at sequence positions +// [prev, prev + L) of `store`, growing capacity in `step` blocks (zeros + +// concatenate, trimming any unused tail of the last block first), and return +// each component's populated slice [..., :prev+L, :]. This block-grow + +// slice_update + slice-view strategy deliberately mirrors mlx-lm's caches +// bit-for-bit — the returned views' buffer shapes/strides affect kernel +// accumulation order, and the exact-token golden gates depend on it. +std::vector update_kv_components(std::vector& store, + const std::vector& in, int prev, int step); + +// Quantize `x` (mx::quantize, per position) and write the resulting triplet into +// `store` at sequence position `pos` via update_kv_components, returning the +// populated triplet for quantized_sdpa. Shared by KVCache and BatchKVCache so +// the quantize-then-update op order is identical on both paths. +QuantizedKV quantize_and_update(std::vector& store, const mx::array& x, int group_size, + int bits, int pos, int step); + +} // namespace mlxforge diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp index 712a026..9156457 100644 --- a/src/capi/mlxforge.cpp +++ b/src/capi/mlxforge.cpp @@ -138,6 +138,41 @@ mlxforge_engine* mlxforge_engine_create(const char* model_spec, return nullptr; } +mlxforge_engine* mlxforge_engine_create2(const char* model_spec, + const mlxforge_engine_opts2* opts, char** err) { + if (err) *err = nullptr; + if (!model_spec || !*model_spec) { + set_err(err, "model_spec is null or empty"); + return nullptr; + } + try { + mlxforge::EngineConfig cfg; + cfg.model_spec = model_spec; + // Read only the fields the caller's struct_size covers, so a binary built + // against this header stays correct when later versions append fields. + auto covered = [&](const void* field_end) { + return opts && static_cast(static_cast(field_end) - + reinterpret_cast(opts)) <= + opts->struct_size; + }; + if (covered(&opts->max_waiting + 1) && opts->max_waiting > 0) + cfg.max_waiting = opts->max_waiting; + if (covered(&opts->kv_bits + 1)) cfg.kv_bits = opts->kv_bits; + if (covered(&opts->kv_group_size + 1) && opts->kv_group_size > 0) + cfg.kv_group_size = opts->kv_group_size; + + auto handle = std::make_unique(); + handle->model_name = model_spec; + handle->engine = std::make_unique(std::move(cfg)); + return handle.release(); + } catch (const std::exception& e) { + set_err(err, e.what()); + } catch (...) { + set_err(err, "unknown error creating engine"); + } + return nullptr; +} + int mlxforge_engine_ready(mlxforge_engine* engine) { if (!engine || !engine->engine) return 0; try { diff --git a/src/capi/mlxforge.h b/src/capi/mlxforge.h index d09ccbf..f9edf6a 100644 --- a/src/capi/mlxforge.h +++ b/src/capi/mlxforge.h @@ -41,8 +41,10 @@ extern "C" { * v5: mlxforge_sampling.logprobs + mlxforge_request_logprobs (OpenAI per-token * log-probabilities: the chosen token's logprob and its top-N alternatives, * accumulated as the request is drained and returned as an OpenAI-shaped - * JSON array). */ -#define MLXFORGE_ABI_VERSION 5 + * JSON array). + * v6: mlxforge_engine_create2 + mlxforge_engine_opts2 (KV-cache quantization; + * opts2 carries struct_size so future fields append without a create3). */ +#define MLXFORGE_ABI_VERSION 6 typedef struct mlxforge_engine mlxforge_engine; typedef struct mlxforge_request mlxforge_request; @@ -112,6 +114,30 @@ void mlxforge_floats_free(float* p); mlxforge_engine* mlxforge_engine_create(const char* model_spec, const mlxforge_engine_opts* opts, char** err); +/* Extended engine creation options (v6+). Set struct_size = sizeof(...) and + * zero-initialize the rest for defaults; the library reads only the fields + * struct_size covers, so binaries built against this header stay correct when + * later versions append fields. + * + * kv_bits enables KV-cache quantization (engine-wide; the batched cache's + * storage is shared across rows, so it cannot be per-request): 0 = dense fp16 + * (the default), 8 or 4 store the cache as quantized triplets matching + * mlx-lm's QuantizedKVCache (8 is near-lossless at ~1.9x less cache memory; + * 4 is ~3.6x). Unsupported setups (vision-language or hybrid Qwen3.5 models, + * invalid bits/group sizes) FAIL engine creation with a clear *err — there is + * never a silent fp16 fallback. */ +typedef struct { + size_t struct_size; /* caller sets sizeof(mlxforge_engine_opts2) */ + int max_waiting; /* max queued requests; <= 0 => default (256) */ + int kv_bits; /* 0 = fp16 KV cache (default); 8 or 4 = quantized */ + int kv_group_size; /* quantization group size; <= 0 => default (64) */ +} mlxforge_engine_opts2; + +/* Create an engine with extended options (v6+). Identical contract to + * mlxforge_engine_create; `opts` may be NULL for all defaults. */ +mlxforge_engine* mlxforge_engine_create2(const char* model_spec, + const mlxforge_engine_opts2* opts, char** err); + /* Non-zero once the model has finished loading on the worker thread. Requests * may be submitted before this returns true; they are served once ready. */ int mlxforge_engine_ready(mlxforge_engine* engine); diff --git a/src/model/decoder_model.cpp b/src/model/decoder_model.cpp index b506098..f7a659a 100644 --- a/src/model/decoder_model.cpp +++ b/src/model/decoder_model.cpp @@ -7,6 +7,8 @@ #include #include "core/logging.h" +#include "model/sdpa.h" + #include "mlx/fast.h" #include "mlx/ops.h" #include "mlx/transforms.h" @@ -183,18 +185,12 @@ mx::array DecoderModel::attention(const mx::array& x, int layer, int offset, KVC const int L = x.shape()[1]; QKV qkv = attn_qkv(x, layer, offset); // q (B, n_heads, L, head_dim), k/v use n_kv_heads - if (cache) { - auto kv = cache->update_and_fetch(layer, qkv.k, qkv.v); - qkv.k = kv.first; - qkv.v = kv.second; - } const float scale = 1.0f / std::sqrt(static_cast(cfg_.head_dim)); // Multi-token chunks (prefill) are causal; a single decode token attends over // the whole cached history unmasked. GQA is handled natively by SDPA. const std::string mask_mode = L > 1 ? "causal" : ""; - mx::array out = - mx::fast::scaled_dot_product_attention(qkv.q, qkv.k, qkv.v, scale, mask_mode); + mx::array out = sdpa_with_cache(qkv.q, qkv.k, qkv.v, cache, layer, scale, mask_mode); // (B, n_heads, L, head_dim) -> (B, L, n_heads*head_dim) out = mx::reshape(mx::transpose(out, {0, 2, 1, 3}), {B, L, cfg_.n_heads * cfg_.head_dim}); @@ -246,12 +242,10 @@ mx::array DecoderModel::attention_batched(const mx::array& x, int layer, const m QKV p = project_qkv(x, layer); mx::array q = apply_rope(p.q, offset); - mx::array k = apply_rope(p.k, offset); - auto kv = cache.update_and_fetch(layer, k, p.v); // append roped K, un-roped V + mx::array k = apply_rope(p.k, offset); // append roped K, un-roped V const float scale = 1.0f / std::sqrt(static_cast(cfg_.head_dim)); - mx::array out = mx::fast::scaled_dot_product_attention(q, kv.first, kv.second, scale, - /*mask_mode=*/"", mask); + mx::array out = sdpa_with_cache(q, k, p.v, cache, layer, scale, mask); out = mx::reshape(mx::transpose(out, {0, 2, 1, 3}), {B, L, cfg_.n_heads * cfg_.head_dim}); return linear(out, layer_key(layer, "self_attn.o_proj.weight")); } diff --git a/src/model/qwen3_5.cpp b/src/model/qwen3_5.cpp index f8aa4b6..25fd9e0 100644 --- a/src/model/qwen3_5.cpp +++ b/src/model/qwen3_5.cpp @@ -4,6 +4,8 @@ #include #include +#include "model/sdpa.h" + #include "mlx/fast.h" #include "mlx/ops.h" @@ -59,15 +61,10 @@ mx::array Qwen35Model::attention(const mx::array& x, int layer, int offset, KVCa queries = partial_rope(queries, offset); keys = partial_rope(keys, offset); - if (cache) { - auto kv = cache->update_and_fetch(layer, keys, values); - keys = kv.first; - values = kv.second; - } const float scale = 1.0f / std::sqrt(static_cast(D)); const std::string mask_mode = L > 1 ? "causal" : ""; - mx::array out = mx::fast::scaled_dot_product_attention(queries, keys, values, scale, mask_mode); + mx::array out = sdpa_with_cache(queries, keys, values, cache, layer, scale, mask_mode); // (B, H, L, D) -> (B, L, H*D), apply the sigmoid output gate, then o_proj. out = mx::reshape(mx::transpose(out, {0, 2, 1, 3}), {B, L, H * D}); @@ -266,12 +263,10 @@ mx::array Qwen35Model::attention_batched_gated(const mx::array& x, int layer, {0, 2, 1, 3}); queries = partial_rope(queries, offset); - keys = partial_rope(keys, offset); - auto kv = cache.update_and_fetch(layer, keys, values); // append roped K, un-roped V + keys = partial_rope(keys, offset); // append roped K, un-roped V const float scale = 1.0f / std::sqrt(static_cast(D)); - mx::array out = mx::fast::scaled_dot_product_attention(queries, kv.first, kv.second, scale, - /*mask_mode=*/"", mask); + mx::array out = sdpa_with_cache(queries, keys, values, cache, layer, scale, mask); out = mx::reshape(mx::transpose(out, {0, 2, 1, 3}), {B, L, H * D}); // A left-padding query position attends to zero valid keys, so SDPA returns NaN diff --git a/src/model/sdpa.cpp b/src/model/sdpa.cpp new file mode 100644 index 0000000..7feede9 --- /dev/null +++ b/src/model/sdpa.cpp @@ -0,0 +1,97 @@ +#include "model/sdpa.h" + +#include "mlx/fast.h" +#include "mlx/ops.h" +#include "mlx/utils.h" + +namespace mlxforge { + +mx::array quantized_sdpa(const mx::array& q, const QuantizedKV& keys, const QuantizedKV& values, + float scale, const std::string& mask_mode, + const std::optional& mask, int group_size, int bits) { + const int B = q.shape()[0]; + const int n_q_heads = q.shape()[1]; + const int L = q.shape()[2]; + const int D = q.shape()[3]; + const int n_kv_heads = keys.w.shape()[1]; + const int n_repeats = n_q_heads / n_kv_heads; + + mx::array queries = mx::multiply(q, mx::array(scale, q.dtype())); + + QuantizedKV k = keys; + QuantizedKV v = values; + std::optional m = mask; + if (n_repeats > 1) { + // GQA: group the query heads under their kv head; the kv triplets gain a + // broadcast axis. quantized_matmul has no native GQA handling. + queries = mx::reshape(queries, {B, n_kv_heads, n_repeats, L, D}); + k = {mx::expand_dims(k.w, -3), mx::expand_dims(k.scales, -3), mx::expand_dims(k.biases, -3)}; + v = {mx::expand_dims(v.w, -3), mx::expand_dims(v.scales, -3), mx::expand_dims(v.biases, -3)}; + // Scores are now 5-D (B, n_kv, n_rep, N, T_kv); the (B, 1, N, T_kv) mask + // must gain the n_rep axis too, or trailing-axis broadcasting would align + // B against n_rep — silently wrong attention. + if (m) m = mx::expand_dims(*m, 1); + } + + mx::array scores = + mx::quantized_matmul(queries, k.w, k.scales, k.biases, /*transpose=*/true, group_size, bits); + + if (mask_mode == "causal") { + const int kL = scores.shape().back(); + const int qL = L; + // Bool is safe here: these are plain ops on the scores, not the fast::SDPA + // kernel mask (#2894). finfo(...).min (not -inf) matches mlx-lm exactly. + mx::array q_idx = mx::reshape(mx::arange(kL - qL, kL, 1, mx::int32), {qL, 1}); + mx::array k_idx = mx::reshape(mx::arange(0, kL, 1, mx::int32), {1, kL}); + mx::array causal = mx::greater_equal(q_idx, k_idx); + const float lowest = static_cast(mx::finfo(scores.dtype()).min); + scores = mx::where(causal, scores, mx::array(lowest, scores.dtype())); + } else if (m) { + // The mask must OVERRIDE masked columns, not add: a fully-masked left-pad + // query row yields NaN attention output, which poisons that position's K/V + // in later layers, and NaN + (-inf) stays NaN (the fused dense kernel + // handles this internally; this path must do it explicitly). Real columns + // are unchanged (mask 0), and exp(lowest - max) underflows to exactly 0 in + // the precise softmax, so real rows match the plain additive form. + const float lowest = static_cast(mx::finfo(scores.dtype()).min); + scores = mx::where(mx::isneginf(*m), mx::array(lowest, scores.dtype()), + mx::add(scores, *m)); + } + + scores = mx::softmax(scores, /*axis=*/-1, /*precise=*/true); + mx::array out = + mx::quantized_matmul(scores, v.w, v.scales, v.biases, /*transpose=*/false, group_size, bits); + + if (n_repeats > 1) { + out = mx::reshape(out, {B, n_q_heads, L, out.shape().back()}); + } + return out; +} + +mx::array sdpa_with_cache(const mx::array& q, const mx::array& k, const mx::array& v, + KVCache* cache, int layer, float scale, const std::string& mask_mode) { + if (cache && cache->quantized()) { + QuantizedKVSlice s = cache->update_and_fetch_quantized(layer, k, v); + const KVQuantConfig& qc = cache->quant_config(); + return quantized_sdpa(q, s.k, s.v, scale, mask_mode, std::nullopt, qc.group_size, qc.bits); + } + if (cache) { + auto kv = cache->update_and_fetch(layer, k, v); + return mx::fast::scaled_dot_product_attention(q, kv.first, kv.second, scale, mask_mode); + } + return mx::fast::scaled_dot_product_attention(q, k, v, scale, mask_mode); +} + +mx::array sdpa_with_cache(const mx::array& q, const mx::array& k, const mx::array& v, + BatchKVCache& cache, int layer, float scale, const mx::array& mask) { + if (cache.quantized()) { + QuantizedKVSlice s = cache.update_and_fetch_quantized(layer, k, v); + const KVQuantConfig& qc = cache.quant_config(); + return quantized_sdpa(q, s.k, s.v, scale, /*mask_mode=*/"", mask, qc.group_size, qc.bits); + } + auto kv = cache.update_and_fetch(layer, k, v); + return mx::fast::scaled_dot_product_attention(q, kv.first, kv.second, scale, + /*mask_mode=*/"", mask); +} + +} // namespace mlxforge diff --git a/src/model/sdpa.h b/src/model/sdpa.h new file mode 100644 index 0000000..5e642da --- /dev/null +++ b/src/model/sdpa.h @@ -0,0 +1,45 @@ +// Cache-aware scaled-dot-product attention dispatch. +// +// sdpa_with_cache() is the one seam the models call after projecting Q/K/V: it +// appends K/V to the cache and attends over the stored history, picking the +// dense fast::scaled_dot_product_attention kernel or the hand-rolled quantized +// path by the cache's KVQuantConfig — the C++ twin of mlx_lm/models/base.py:: +// scaled_dot_product_attention's hasattr(cache, "bits") dispatch. +// +// quantized_sdpa() is ported op-for-op from mlx_lm/models/base.py:: +// quantized_scaled_dot_product_attention (MLX has no fused quantized SDPA +// kernel): fp16 q*scale, GQA via a (B, n_kv_heads, n_repeats, L, D) reshape, +// mx::quantized_matmul for both the scores and the output, and a precise +// softmax. Deviating from any of those breaks the exact-token golden gate. +#pragma once + +#include +#include + +#include "cache/batch_kv_cache.h" +#include "cache/kv_cache.h" +#include "cache/kv_quant.h" + +#include "mlx/array.h" + +namespace mlxforge { + +namespace mx = mlx::core; + +// Attention over quantized K/V triplets. `mask_mode` is "causal" or "" (mirrors +// fast::scaled_dot_product_attention); `mask` is the batched additive fp16 +// (B, 1, N, T_kv) mask, exclusive with mask_mode. +mx::array quantized_sdpa(const mx::array& q, const QuantizedKV& keys, const QuantizedKV& values, + float scale, const std::string& mask_mode, + const std::optional& mask, int group_size, int bits); + +// Single-stream path (prefill is causal, decode unmasked). `cache` may be null +// (full recompute): plain fast SDPA with no cache write. +mx::array sdpa_with_cache(const mx::array& q, const mx::array& k, const mx::array& v, + KVCache* cache, int layer, float scale, const std::string& mask_mode); + +// Continuous-batching path: per-row additive fp16 mask from batch_mask(). +mx::array sdpa_with_cache(const mx::array& q, const mx::array& k, const mx::array& v, + BatchKVCache& cache, int layer, float scale, const mx::array& mask); + +} // namespace mlxforge diff --git a/src/runtime/batching.cpp b/src/runtime/batching.cpp index a0309ca..d264e80 100644 --- a/src/runtime/batching.cpp +++ b/src/runtime/batching.cpp @@ -15,7 +15,7 @@ int next_bucket(int n) { } PrefillResult prefill(const DecoderModel& model, const std::vector>& prompts, - int step_size, int pad_id) { + int step_size, int pad_id, KVQuantConfig kv_quant) { const int B = static_cast(prompts.size()); int p_max = 0; for (const auto& p : prompts) p_max = std::max(p_max, static_cast(p.size())); @@ -30,7 +30,7 @@ PrefillResult prefill(const DecoderModel& model, const std::vector>& prompts, - int step_size = kPrefillStepSize, int pad_id = 0); + int step_size = kPrefillStepSize, int pad_id = 0, + KVQuantConfig kv_quant = {}); } // namespace mlxforge diff --git a/src/runtime/engine.cpp b/src/runtime/engine.cpp index c4a9134..38b9210 100644 --- a/src/runtime/engine.cpp +++ b/src/runtime/engine.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -36,6 +37,30 @@ void detect_embedding_defaults(const std::string& dir, int& pooling, bool& add_e } } +// Validate the engine's KV-quantization request against the loaded model and +// return the Worker's KVQuantConfig. Unsupported setups are hard errors — +// never a silent fp16 fallback (the failure mode here is silent numerical +// garbage, so the caller must know exactly what storage they got). +KVQuantConfig validate_kv_quant(const EngineConfig& ec, const ModelConfig& mc) { + if (ec.kv_bits == 0) return {}; + if (ec.kv_bits != 4 && ec.kv_bits != 8) + throw std::runtime_error("kv_bits must be 0 (off), 4, or 8; got " + + std::to_string(ec.kv_bits)); + if (ec.kv_group_size != 32 && ec.kv_group_size != 64 && ec.kv_group_size != 128) + throw std::runtime_error("kv_group_size must be 32, 64, or 128; got " + + std::to_string(ec.kv_group_size)); + if (mc.head_dim % ec.kv_group_size != 0) + throw std::runtime_error("kv_group_size " + std::to_string(ec.kv_group_size) + + " does not divide head_dim " + std::to_string(mc.head_dim)); + // Golden-gated for the standard attention paths only so far: the vision + // (mlx-vlm-gated) and Qwen3.5 hybrid streams have no quantized reference yet. + if (mc.has_vision_tower()) + throw std::runtime_error("KV-cache quantization is not supported for vision-language models"); + if (mc.full_attention_interval > 0) + throw std::runtime_error("KV-cache quantization is not supported for hybrid (Qwen3.5) models"); + return {ec.kv_bits, ec.kv_group_size}; +} + } // namespace // Loads the model directory, config, and tokenizer metadata, but not weights. @@ -112,7 +137,9 @@ Engine::Engine(EngineConfig cfg, Loaded loaded) embed_add_eos_default_(loaded.embed_add_eos_default), // Pass the tokenizer so the worker can build per-token byte strings for // constrained decoding. tok_ is initialized above and outlives worker_. - worker_(make_factory(std::move(loaded.dir), loaded.is_gguf), &scheduler_, &tok_) { + // cfg_ is initialized above, so the KV-quant validation sees the model. + worker_(make_factory(std::move(loaded.dir), loaded.is_gguf), &scheduler_, &tok_, + validate_kv_quant(cfg, cfg_)) { // Configure the max waiting requests for the batch scheduler. scheduler_.set_max_waiting(cfg.max_waiting); diff --git a/src/runtime/engine.h b/src/runtime/engine.h index 1150af2..94565f3 100644 --- a/src/runtime/engine.h +++ b/src/runtime/engine.h @@ -25,6 +25,14 @@ namespace mlxforge { struct EngineConfig { std::string model_spec; // Model description: local directory, HuggingFace repo id, or .gguf file path (to be resolved internally) int max_waiting = 256; // Maximum length of the Scheduler's waiting queue; 0 disables the cap + // KV-cache quantization (engine-wide: the batched cache's storage is shared + // across rows, so this cannot be per-request). 0 = dense fp16 (default); + // 8 or 4 store the cache as mx::quantize triplets, matching mlx-lm's + // QuantizedKVCache numerics. Validated against the model at construction: + // unsupported combinations (vision/hybrid models, group_size not dividing + // head_dim) throw rather than silently falling back. + int kv_bits = 0; + int kv_group_size = 64; }; // Per-call embedding options. The two int fields are tri-state: -1 means "use diff --git a/src/runtime/single_stream.cpp b/src/runtime/single_stream.cpp index eab2205..00ae263 100644 --- a/src/runtime/single_stream.cpp +++ b/src/runtime/single_stream.cpp @@ -55,7 +55,8 @@ int greedy_last(const mx::array& logits, int top_logprobs, TokenLogprob* lp) { GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids, int max_tokens, const std::vector& eos_ids, - const std::function& on_token, int top_logprobs) { + const std::function& on_token, int top_logprobs, + KVQuantConfig kv_quant) { auto is_eos = [&](int id) { return std::find(eos_ids.begin(), eos_ids.end(), id) != eos_ids.end(); }; @@ -66,7 +67,7 @@ GenerateResult greedy_generate(const DecoderModel& model, const std::vector }; GenerateResult result; - KVCache cache(model.config().n_layers); + KVCache cache(model.config().n_layers, kv_quant); mx::array prompt(prompt_ids.data(), {1, static_cast(prompt_ids.size())}, mx::int32); // `next_lp` carries the log-prob record for `next`, pushed when (and only when) diff --git a/src/runtime/single_stream.h b/src/runtime/single_stream.h index 3000255..a242d3e 100644 --- a/src/runtime/single_stream.h +++ b/src/runtime/single_stream.h @@ -32,10 +32,12 @@ struct GenerateResult { // emitted token. Stops when an EOS token would be produced or after max_tokens. // `top_logprobs` mirrors SamplingParams: -1 = off (no logprob work); 0 = record // each emitted token's own log-prob into result.token_logprobs; N > 0 = also -// record its N most-likely alternatives. +// record its N most-likely alternatives. `kv_quant` selects the cache storage +// (dense fp16 by default; the quantized-KV golden tests and the CLI's --kv-bits +// drive the quantized path). GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids, int max_tokens, const std::vector& eos_ids, const std::function& on_token = {}, - int top_logprobs = -1); + int top_logprobs = -1, KVQuantConfig kv_quant = {}); } // namespace mlxforge diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 8340d39..db5e124 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -57,8 +57,9 @@ bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) { } } // namespace -Worker::Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok) - : factory_(std::move(factory)), sched_(scheduler), tok_(tok) {} +Worker::Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok, + KVQuantConfig kv_quant) + : factory_(std::move(factory)), sched_(scheduler), tok_(tok), kv_quant_(kv_quant) {} Worker::~Worker() { stop(); } @@ -247,7 +248,7 @@ void Worker::admit(const std::vector>& incoming) { log::debug("worker: admitting {} request(s) (batch {} -> {})", incoming.size(), reqs_.size(), reqs_.size() + incoming.size()); - PrefillResult pr = prefill(*model_, prompts); + PrefillResult pr = prefill(*model_, prompts, kPrefillStepSize, /*pad_id=*/0, kv_quant_); if (!cache_) { cache_ = std::make_unique(std::move(pr.cache)); diff --git a/src/runtime/worker.h b/src/runtime/worker.h index 9e9a213..5e62d89 100644 --- a/src/runtime/worker.h +++ b/src/runtime/worker.h @@ -35,9 +35,12 @@ class Worker { // `tok` (optional) supplies the per-token byte strings used for constrained // decoding; when null, grammar-constrained requests fall back to unconstrained. - // Defined out-of-line (with the destructor) because the unique_ptr - // member needs the complete type for cleanup. - Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok = nullptr); + // `kv_quant` selects the decode cache's storage (dense fp16 by default); the + // Engine validates it against the model before construction. Defined + // out-of-line (with the destructor) because the unique_ptr member + // needs the complete type for cleanup. + Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok = nullptr, + KVQuantConfig kv_quant = {}); ~Worker(); Worker(const Worker&) = delete; @@ -120,6 +123,7 @@ class Worker { ModelFactory factory_; Scheduler* sched_; const Tokenizer* tok_; // for per-token bytes (grammar masking); may be null + KVQuantConfig kv_quant_; // decode-cache storage (dense when bits == 0) std::vector token_bytes_; // id -> output bytes ("" for specials) bool token_bytes_built_ = false; std::unique_ptr model_; // constructed and owned on the worker thread diff --git a/src/server/config.cpp b/src/server/config.cpp index da3b11c..c76e6a9 100644 --- a/src/server/config.cpp +++ b/src/server/config.cpp @@ -57,7 +57,7 @@ ServerConfig ServerConfig::from_file(const std::string& path) { // Reject unknown keys up front so typos (e.g. "prot") fail loudly. static const std::set kKnownKeys = { - "model", "host", "port", "max_ctx", "max_waiting", "kv_budget"}; + "model", "host", "port", "max_ctx", "max_waiting", "kv_budget", "kv_bits"}; for (const auto& [key, _] : j.items()) { if (kKnownKeys.find(key) == kKnownKeys.end()) { throw std::runtime_error("config file: unknown key '" + key + "' in '" + path + "'"); @@ -86,6 +86,11 @@ ServerConfig ServerConfig::from_file(const std::string& path) { if (budget < 0) throw std::runtime_error("config file: 'kv_budget' must be >= 0"); c.kv_budget_bytes = static_cast(budget); } + if (j.contains("kv_bits")) { + c.kv_bits = require_type(j, "kv_bits"); + if (c.kv_bits != 0 && c.kv_bits != 4 && c.kv_bits != 8) + throw std::runtime_error("config file: 'kv_bits' must be 0, 4, or 8"); + } return c; } @@ -128,6 +133,7 @@ ServerConfig ServerConfig::parse(const std::vector& args) { c.max_waiting = static_cast(env_long("MLXFORGE_MAX_WAITING", c.max_waiting)); c.kv_budget_bytes = static_cast(env_long("MLXFORGE_KV_BUDGET", static_cast(c.kv_budget_bytes))); + c.kv_bits = static_cast(env_long("MLXFORGE_KV_BITS", c.kv_bits)); // Helper: extract value for a flag (accepts "--flag value" or "--flag=value") auto value_of = [&](const std::string& a, size_t& i) -> std::string { @@ -158,9 +164,13 @@ ServerConfig ServerConfig::parse(const std::vector& args) { c.max_waiting = std::stoi(value_of(a, i)); else if (flag == "--kv-budget") c.kv_budget_bytes = static_cast(std::stoll(value_of(a, i))); + else if (flag == "--kv-bits") + c.kv_bits = std::stoi(value_of(a, i)); else throw std::runtime_error("unknown flag: " + flag); } + if (c.kv_bits != 0 && c.kv_bits != 4 && c.kv_bits != 8) + throw std::runtime_error("--kv-bits must be 0, 4, or 8"); return c; } diff --git a/src/server/config.h b/src/server/config.h index fc2c0d1..d759710 100644 --- a/src/server/config.h +++ b/src/server/config.h @@ -27,17 +27,23 @@ struct ServerConfig { // Memory budget for the KV cache in bytes. 0 = unbounded (all requests admitted). std::size_t kv_budget_bytes = 0; + // KV-cache quantization bits: 0 = dense fp16 (default), 8 or 4 store the + // cache quantized (engine-wide; group size fixed at 64 for the server). + int kv_bits = 0; + // Parses command line arguments (the model via -m/--model, plus optional flags as --flag value or --flag=value), // layering configuration sources by precedence (lowest to highest): // struct defaults < config file (-c/--config) < environment variables < CLI flags. // The config file is a JSON object (see from_file); CLI flags always override it. - // Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, MLXFORGE_KV_BUDGET. + // Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, + // MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS. // Throws std::runtime_error if an unknown or malformed flag is encountered. static ServerConfig parse(const std::vector& args); // Loads and validates a JSON config file into a fully-populated ServerConfig, // with struct defaults filling any keys the file omits. Recognized keys - // (snake_case): "model", "host", "port", "max_ctx", "max_waiting", "kv_budget". + // (snake_case): "model", "host", "port", "max_ctx", "max_waiting", "kv_budget", + // "kv_bits". // Validates before applying: rejects unknown keys, wrong types, and out-of-range // values. Throws std::runtime_error (with the file path / offending key) on any // failure to open, parse, or validate. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ba8fc2a..c816559 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,7 +17,10 @@ add_executable(mlxforge_tests cache/kv_cache_test.cpp cache/batch_kv_cache_test.cpp cache/batch_kv_filter_merge_test.cpp + cache/kv_quant_cache_test.cpp + cache/kv_cache_quantized_test.cpp cache/kv_budget_test.cpp + model/quantized_sdpa_test.cpp sample/sampler_test.cpp sample/json_grammar_test.cpp runtime/single_stream_test.cpp diff --git a/tests/cache/kv_budget_test.cpp b/tests/cache/kv_budget_test.cpp index 92f7498..26a77fa 100644 --- a/tests/cache/kv_budget_test.cpp +++ b/tests/cache/kv_budget_test.cpp @@ -42,3 +42,13 @@ TEST_CASE("a zero budget is treated as unbounded") { KVBudget budget(llama32_1b(), /*budget_bytes=*/0); CHECK(budget.can_admit(131072, 4096, 256)); } + +TEST_CASE("quantized KV accounting: packed words + per-group scales/biases") { + // Llama-3.2-1B, D=64, group 64: one K-or-V head row is 64*bits/8 packed bytes + // plus one fp16 scale + bias (4 bytes). 8-bit: 68 B vs 128 B fp16 (1.88x); + // 4-bit: 36 B (3.56x). + KVBudget q8(llama32_1b(), 0, KVQuantConfig{8, 64}); + CHECK(q8.bytes_per_token() == 2u * 16 * 8 * 68); + KVBudget q4(llama32_1b(), 0, KVQuantConfig{4, 64}); + CHECK(q4.bytes_per_token() == 2u * 16 * 8 * 36); +} diff --git a/tests/cache/kv_cache_quantized_test.cpp b/tests/cache/kv_cache_quantized_test.cpp new file mode 100644 index 0000000..30c025e --- /dev/null +++ b/tests/cache/kv_cache_quantized_test.cpp @@ -0,0 +1,224 @@ +// Quantized-KV golden gates. The correctness check for the quantized cache + +// quantized SDPA: a teacher-forced greedy walk over mlx-lm's QuantizedKVCache +// reference stream (fixtures greedy_tokens_kvq{8,4}.npy + per-step top-2 +// margins, dumped by reference/dump_ref.py) must reproduce the reference token +// at every step whose margin clears the cross-context noise. Exact full-stream +// equality is NOT a sound gate here: quantized matmuls are fusion-context +// sensitive (the same mlx-lm function on the same values shifts by ~1 logit +// between lazy and materialized inputs), so knife-edge steps may flip without +// either side being wrong. Teacher-forcing keeps every step's cache content +// identical to the reference run, so each step is gated independently. +// Llama covers GQA n_repeats=4 / head_dim=64; Qwen3 covers QK-Norm / +// head_dim=128. The batched case is gated exactly (C++ vs C++): the quantized +// BatchKVCache must reproduce the single-stream quantized rows. +#include + +#include +#include +#include +#include +#include + +#include "capi/mlxforge.h" +#include "cache/batch_kv_cache.h" +#include "cache/kv_cache.h" +#include "runtime/engine.h" +#include "support/model_fixture.h" +#include "support/reference.h" + +#include "mlx/ops.h" +#include "mlx/transforms.h" + +using namespace mlxforge::test; + +namespace { + +// Greedy next-token of the last position of each batch row: (B, L, vocab) -> (B,). +mx::array greedy_last(const mx::array& logits) { + const int L = logits.shape()[1]; + const int vocab = logits.shape()[2]; + mx::array last = mx::reshape( + mx::slice(logits, {0, L - 1, 0}, {logits.shape()[0], L, vocab}), {logits.shape()[0], vocab}); + return mx::astype(mx::argmax(last, /*axis=*/-1), mx::int32); // (B,) +} + +std::vector to_vec(const mx::array& a) { + mx::array c = mx::contiguous(mx::astype(a, mx::int32)); + mx::eval(c); + const int32_t* p = c.data(); + return std::vector(p, p + c.size()); +} + +// Greedy run for one prompt through a (possibly quantized) single-stream cache. +std::vector solo_run(const mlxforge::DecoderModel& model, const std::vector& ids, + int steps, mlxforge::KVQuantConfig qc) { + mlxforge::KVCache cache(model.config().n_layers, qc); + mx::array prompt(ids.data(), {1, static_cast(ids.size())}, mx::int32); + int next = to_vec(greedy_last(model.forward(prompt, &cache)))[0]; + std::vector out = {next}; + for (int s = 1; s < steps; ++s) { + mx::array step(&next, {1, 1}, mx::int32); + next = to_vec(greedy_last(model.forward(step, &cache)))[0]; + out.push_back(next); + } + return out; +} + +std::vector load_floats(const std::string& path) { + mx::array a = mx::contiguous(mx::astype(mx::load(path), mx::float32)); + mx::eval(a); + return std::vector(a.data(), a.data() + a.size()); +} + +// Steps whose reference top-2 margin is below this are knife edges that the +// fusion-context noise (~1 logit) can legitimately flip; above it a mismatch is +// a real bug. +constexpr float kMarginThreshold = 2.0f; + +// Teacher-forced margin-gated gate: walk the reference stream feeding the +// REFERENCE token each step (so the cache content tracks the reference run +// regardless of our own argmax), asserting our prediction equals the reference +// wherever its margin clears the threshold. +void gated_stream_check(const mlxforge::DecoderModel& model, const std::vector& prompt_ids, + const std::vector& ref, const std::vector& gaps, int bits) { + mlxforge::KVCache cache(model.config().n_layers, mlxforge::KVQuantConfig{bits, 64}); + mx::array prompt(prompt_ids.data(), {1, static_cast(prompt_ids.size())}, mx::int32); + int pred = to_vec(greedy_last(model.forward(prompt, &cache)))[0]; + int asserted = 0; + for (size_t i = 0; i < ref.size(); ++i) { + if (gaps[i] >= kMarginThreshold) { + CAPTURE(i); + CAPTURE(gaps[i]); + CHECK(pred == ref[i]); + ++asserted; + } + if (i + 1 == ref.size()) break; + int forced = ref[i]; // teacher-force the reference's own token + mx::array step(&forced, {1, 1}, mx::int32); + pred = to_vec(greedy_last(model.forward(step, &cache)))[0]; + } + CHECK(asserted > 0); // the fixture must gate something +} + +} // namespace + +TEST_CASE("Llama: quantized-KV stream matches mlx-lm's QuantizedKVCache (margin-gated)") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + mlxforge::LlamaModel& model = shared_model(); + std::vector ids = load_token_ids("prompt_0_ids.npy"); + + for (int bits : {8, 4}) { + CAPTURE(bits); + const std::string suffix = bits == 8 ? "kvq8.npy" : "kvq4.npy"; + gated_stream_check(model, ids, load_token_ids("greedy_tokens_" + suffix), + load_floats(ref_path("greedy_gaps_" + suffix)), bits); + } +} + +TEST_CASE("Qwen3: quantized-KV stream matches mlx-lm's QuantizedKVCache (margin-gated)") { + if (!qwen3_model_available() || !std::ifstream(qwen3_ref_path("greedy_tokens_kvq8.npy")).good()) { + MESSAGE("Qwen3 model/quantized fixtures not present; skipping"); + return; + } + mlxforge::Qwen3Model& model = shared_qwen3_model(); + std::vector ids = load_qwen3_token_ids("prompt_0_ids.npy"); + + for (int bits : {8, 4}) { + CAPTURE(bits); + const std::string suffix = bits == 8 ? "kvq8.npy" : "kvq4.npy"; + gated_stream_check(model, ids, load_qwen3_token_ids("greedy_tokens_" + suffix), + load_floats(qwen3_ref_path("greedy_gaps_" + suffix)), bits); + } +} + +TEST_CASE("batched quantized decode matches single-stream quantized runs") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + mlxforge::LlamaModel& model = shared_model(); + const mlxforge::KVQuantConfig qc{8, 64}; + constexpr int kSteps = 8; + + // Three fixed prompts of different lengths (ragged -> left-padding). + std::vector> prompts = { + load_token_ids("prompt_0_ids.npy"), + load_token_ids("prompt_1_ids.npy"), + load_token_ids("prompt_2_ids.npy"), + }; + const int B = static_cast(prompts.size()); + + std::vector> solo(B); + for (int b = 0; b < B; ++b) solo[b] = solo_run(model, prompts[b], kSteps, qc); + + int p_max = 0; + for (auto& p : prompts) p_max = std::max(p_max, static_cast(p.size())); + std::vector left_padding(B); + std::vector padded(B * p_max, 0); // pad id 0 (masked out) + for (int b = 0; b < B; ++b) { + const int pad = p_max - static_cast(prompts[b].size()); + left_padding[b] = pad; + for (size_t j = 0; j < prompts[b].size(); ++j) padded[b * p_max + pad + j] = prompts[b][j]; + } + + mlxforge::BatchKVCache cache(model.config().n_layers, left_padding, qc); + mx::array tokens(padded.data(), {B, p_max}, mx::int32); + + mx::array next = greedy_last(model.forward(tokens, cache)); + mx::eval(next); + std::vector> batched(B); + for (int b = 0; b < B; ++b) batched[b].push_back(to_vec(next)[b]); + + for (int s = 1; s < kSteps; ++s) { + mx::array step = mx::reshape(next, {B, 1}); + next = greedy_last(model.forward(step, cache)); + mx::eval(next); // single eval per decode step, whole batch + std::vector row = to_vec(next); + for (int b = 0; b < B; ++b) batched[b].push_back(row[b]); + } + + for (int b = 0; b < B; ++b) { + INFO("row " << b); + assert_tokens_equal(batched[b], solo[b]); + } +} + +TEST_CASE("engine rejects invalid or unsupported KV quantization (no silent fallback)") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + // Validation runs in the Engine constructor after the (cheap) config/tokenizer + // head-load and before any weights load, so these are fast. + auto cfg_with = [&](int bits, int group = 64) { + mlxforge::EngineConfig ec; + ec.model_spec = model_dir(); + ec.kv_bits = bits; + ec.kv_group_size = group; + return ec; + }; + CHECK_THROWS_AS(mlxforge::Engine(cfg_with(5)), std::runtime_error); + CHECK_THROWS_AS(mlxforge::Engine(cfg_with(8, 48)), std::runtime_error); + + // The hybrid (Qwen3.5) family has no quantized golden reference yet. + if (qwen3_5_model_available()) { + mlxforge::EngineConfig ec; + ec.model_spec = qwen3_5_model_dir(); + ec.kv_bits = 8; + CHECK_THROWS_AS(mlxforge::Engine(std::move(ec)), std::runtime_error); + } + + // Same contract through the C ABI: create2 fails with an error message. + mlxforge_engine_opts2 opts = {}; + opts.struct_size = sizeof(opts); + opts.kv_bits = 5; + char* err = nullptr; + mlxforge_engine* eng = mlxforge_engine_create2(model_dir().c_str(), &opts, &err); + CHECK(eng == nullptr); + REQUIRE(err != nullptr); + CHECK(std::string(err).find("kv_bits") != std::string::npos); + mlxforge_string_free(err); +} diff --git a/tests/cache/kv_quant_cache_test.cpp b/tests/cache/kv_quant_cache_test.cpp new file mode 100644 index 0000000..f8d4380 --- /dev/null +++ b/tests/cache/kv_quant_cache_test.cpp @@ -0,0 +1,164 @@ +// Quantized (triplet) KV storage: growth, round-trip, and batch surgery +// (filter/merge/pad_dummies) on tiny tensors — no model required. The gate in +// every case: the quantized cache dequantizes to exactly what quantizing the +// same dense values would give (quantization is per position, so storage +// surgery and quantization commute). +#include + +#include + +#include "cache/batch_kv_cache.h" +#include "cache/kv_cache.h" + +#include "mlx/ops.h" +#include "mlx/transforms.h" + +using namespace mlxforge; +namespace mx = mlx::core; + +namespace { + +// Deterministic varied values (constants would quantize trivially). +mx::array varied(int B, int H, int L, int D, float phase) { + mx::array a = mx::arange(static_cast(B * H * L * D)); + a = mx::sin(mx::add(mx::multiply(a, mx::array(0.37f)), mx::array(phase))); + return mx::astype(mx::reshape(a, {B, H, L, D}), mx::float16); +} + +// Quantize-then-dequantize of dense values: the exact fp16 content a quantized +// cache must reproduce for the same written positions. +mx::array qdq(const mx::array& x, const KVQuantConfig& qc) { + std::vector t = mx::quantize(x, qc.group_size, qc.bits); + return mx::dequantize(t[0], t[1], t[2], qc.group_size, qc.bits); +} + +bool same(const mx::array& a, const mx::array& b) { + mx::array eq = mx::allclose(a, b, /*rtol=*/0.0, /*atol=*/1e-6); + mx::eval(eq); + return eq.item(); +} + +std::vector read_ints(const mx::array& a) { + mx::array c = mx::contiguous(mx::astype(a, mx::int32)); + mx::eval(c); + const int32_t* p = c.data(); + return std::vector(p, p + c.size()); +} + +constexpr int kD = 64; // head_dim must be a multiple of group_size + +} // namespace + +TEST_CASE("quantized KVCache round-trips appended K/V") { + for (int bits : {8, 4}) { + CAPTURE(bits); + const KVQuantConfig qc{bits, 64}; + KVCache cache(/*n_layers=*/1, qc); + CHECK(cache.quantized()); + + mx::array k1 = varied(1, 2, 5, kD, 0.1f), v1 = varied(1, 2, 5, kD, 0.2f); + mx::array k2 = varied(1, 2, 1, kD, 0.3f), v2 = varied(1, 2, 1, kD, 0.4f); + cache.update_and_fetch_quantized(0, k1, v1); + cache.advance(5); + QuantizedKVSlice s = cache.update_and_fetch_quantized(0, k2, v2); + cache.advance(1); + + mx::array got_k = mx::dequantize(s.k.w, s.k.scales, s.k.biases, qc.group_size, qc.bits); + mx::array got_v = mx::dequantize(s.v.w, s.v.scales, s.v.biases, qc.group_size, qc.bits); + CHECK(same(got_k, qdq(mx::concatenate({k1, k2}, 2), qc))); + CHECK(same(got_v, qdq(mx::concatenate({v1, v2}, 2), qc))); + } +} + +TEST_CASE("dense and quantized cache APIs are mutually exclusive") { + mx::array k = varied(1, 1, 1, kD, 0.0f); + KVCache dense(1); + CHECK_THROWS(dense.update_and_fetch_quantized(0, k, k)); + KVCache quant(1, KVQuantConfig{8, 64}); + CHECK_THROWS(quant.update_and_fetch(0, k, k)); + CHECK_THROWS(quant.fetch(0)); + + BatchKVCache bdense(1, {0}); + CHECK_THROWS(bdense.update_and_fetch_quantized(0, k, k)); + BatchKVCache bquant(1, {0}, KVQuantConfig{8, 64}); + CHECK_THROWS(bquant.update_and_fetch(0, k, k)); +} + +TEST_CASE("quantized BatchKVCache grows across a 256 boundary and preserves contents") { + const KVQuantConfig qc{8, 64}; + BatchKVCache cache(/*n_layers=*/1, /*left_padding=*/{0}, qc); + + mx::array k1 = varied(1, 1, 200, kD, 0.1f), v1 = varied(1, 1, 200, kD, 0.2f); + cache.update_and_fetch_quantized(0, k1, v1); + cache.advance(200); + CHECK(cache.s_cap() == 256); + + mx::array k2 = varied(1, 1, 100, kD, 0.3f), v2 = varied(1, 1, 100, kD, 0.4f); + cache.update_and_fetch_quantized(0, k2, v2); + cache.advance(100); + CHECK(cache.s_cap() == 456); // exactly one growth: 200 (trimmed) + 256 + + auto [k, v] = cache.fetch_dequantized(0); + CHECK(k.shape() == mx::Shape{1, 1, 300, kD}); + CHECK(same(k, qdq(mx::concatenate({k1, k2}, 2), qc))); + CHECK(same(v, qdq(mx::concatenate({v1, v2}, 2), qc))); +} + +TEST_CASE("quantized batch surgery (pad_dummies/filter/merge) matches dense + quantize") { + const KVQuantConfig qc{8, 64}; + + // Mirror every write and surgery op on a dense cache and a quantized one. + BatchKVCache dense(1, {1, 0}); + BatchKVCache quant(1, {1, 0}, qc); + mx::array k = varied(2, 1, 4, kD, 0.1f), v = varied(2, 1, 4, kD, 0.2f); + dense.update_and_fetch(0, k, v); + quant.update_and_fetch_quantized(0, k, v); + dense.advance(4); + quant.advance(4); + + auto check_equal = [&](const char* what) { + CAPTURE(what); + auto [dk, dv] = dense.fetch(0); + auto [qk, qv] = quant.fetch_dequantized(0); + CHECK(same(qk, qdq(dk, qc))); + CHECK(same(qv, qdq(dv, qc))); + CHECK(quant.idx() == dense.idx()); + CHECK(quant.batch_size() == dense.batch_size()); + CHECK(read_ints(quant.offset()) == read_ints(dense.offset())); + CHECK(read_ints(quant.left_padding()) == read_ints(dense.left_padding())); + }; + + dense.pad_dummies(2); + quant.pad_dummies(2); + check_equal("pad_dummies"); + + dense.filter({0, 1}); + quant.filter({0, 1}); + check_equal("filter trims dummies"); + + // Admit a freshly prefilled pair of rows with a different length. + BatchKVCache dense_in(1, {0, 2}); + BatchKVCache quant_in(1, {0, 2}, qc); + mx::array k2 = varied(2, 1, 7, kD, 0.3f), v2 = varied(2, 1, 7, kD, 0.4f); + dense_in.update_and_fetch(0, k2, v2); + quant_in.update_and_fetch_quantized(0, k2, v2); + dense_in.advance(7); + quant_in.advance(7); + + dense.merge(dense_in); + quant.merge(quant_in); + check_equal("merge right-justifies and concatenates"); + + dense.filter({1, 2}); + quant.filter({1, 2}); + check_equal("filter drops the common left padding"); + + // eval_state must materialize every triplet component without error. + quant.eval_state(); +} + +TEST_CASE("merge rejects mismatched KV quantization configs") { + BatchKVCache a(1, {0}, KVQuantConfig{8, 64}); + BatchKVCache b(1, {0}); + CHECK_THROWS(a.merge(b)); +} diff --git a/tests/model/quantized_sdpa_test.cpp b/tests/model/quantized_sdpa_test.cpp new file mode 100644 index 0000000..4fb59ea --- /dev/null +++ b/tests/model/quantized_sdpa_test.cpp @@ -0,0 +1,107 @@ +// quantized_sdpa vs fast::scaled_dot_product_attention over the dequantized +// K/V — the same math modulo accumulation order, so fp16-tolerance agreement +// gates the hand-rolled quantized path (and proves this MLX pin broadcasts +// quantized_matmul over the 5-D GQA shapes). No model required. +#include + +#include +#include +#include +#include + +#include "model/sdpa.h" + +#include "mlx/fast.h" +#include "mlx/ops.h" +#include "mlx/transforms.h" + +using namespace mlxforge; +namespace mx = mlx::core; + +namespace { + +constexpr int kD = 64; // head_dim (a multiple of the group size) + +// Deterministic varied fp16 values. +mx::array varied(int B, int H, int L, int D, float phase) { + mx::array a = mx::arange(static_cast(B * H * L * D)); + a = mx::sin(mx::add(mx::multiply(a, mx::array(0.61f)), mx::array(phase))); + return mx::astype(mx::reshape(a, {B, H, L, D}), mx::float16); +} + +QuantizedKV quantize_kv(const mx::array& x, const KVQuantConfig& qc) { + std::vector t = mx::quantize(x, qc.group_size, qc.bits); + return {t[0], t[1], t[2]}; +} + +mx::array dequant(const QuantizedKV& t, const KVQuantConfig& qc) { + return mx::dequantize(t.w, t.scales, t.biases, qc.group_size, qc.bits); +} + +void check_close(const mx::array& actual, const mx::array& expected) { + REQUIRE(actual.shape() == expected.shape()); + mx::array diff = mx::max(mx::abs(mx::subtract(mx::astype(actual, mx::float32), + mx::astype(expected, mx::float32)))); + mx::eval(diff); + CHECK(diff.item() < 2e-2f); +} + +// One comparison: quantized_sdpa over triplets vs fast SDPA over the +// dequantized K/V (so quantization error itself cancels out). +void compare(int B, int n_q_heads, int n_kv_heads, int L, int S, int bits, + const std::string& mask_mode, const std::optional& mask) { + CAPTURE(B); + CAPTURE(n_q_heads); + CAPTURE(n_kv_heads); + CAPTURE(L); + CAPTURE(S); + CAPTURE(bits); + const KVQuantConfig qc{bits, 64}; + const float scale = 1.0f / std::sqrt(static_cast(kD)); + + mx::array q = varied(B, n_q_heads, L, kD, 0.1f); + QuantizedKV k = quantize_kv(varied(B, n_kv_heads, S, kD, 0.2f), qc); + QuantizedKV v = quantize_kv(varied(B, n_kv_heads, S, kD, 0.3f), qc); + + mx::array got = quantized_sdpa(q, k, v, scale, mask_mode, mask, qc.group_size, qc.bits); + mx::array want = mx::fast::scaled_dot_product_attention(q, dequant(k, qc), dequant(v, qc), + scale, mask_mode, mask); + check_close(got, want); +} + +// Additive fp16 left-padding mask (B, 1, L, S), as batch_mask() builds it. +mx::array left_pad_mask(const std::vector& left_padding, int L, int S) { + const int B = static_cast(left_padding.size()); + mx::array lp(left_padding.data(), {B}, mx::int32); + mx::array kpos = mx::arange(0, S, 1, mx::int32); + mx::array valid = mx::less_equal(mx::reshape(lp, {B, 1, 1, 1}), mx::reshape(kpos, {1, 1, 1, S})); + valid = mx::broadcast_to(valid, {B, 1, L, S}); + const float ninf = -std::numeric_limits::infinity(); + return mx::where(valid, mx::array(0.0f, mx::float16), mx::array(ninf, mx::float16)); +} + +} // namespace + +TEST_CASE("quantized_sdpa matches fast SDPA over dequantized K/V") { + for (int bits : {8, 4}) { + // No GQA: causal prefill and unmasked decode. + compare(1, 2, 2, 8, 8, bits, "causal", std::nullopt); + compare(1, 2, 2, 1, 16, bits, "", std::nullopt); + // GQA (n_repeats=4): the 5-D quantized_matmul broadcast path. + compare(1, 8, 2, 8, 8, bits, "causal", std::nullopt); + compare(1, 8, 2, 1, 24, bits, "", std::nullopt); + // Bottom-right-aligned causal over a longer cached history. + compare(1, 8, 2, 4, 12, bits, "causal", std::nullopt); + } +} + +TEST_CASE("quantized_sdpa applies the batched additive mask under GQA") { + // B=2 with n_repeats=4 specifically exercises the (B,1,1,L,S) mask reshape: + // without it, trailing-axis broadcasting would align B against n_repeats. + for (int bits : {8, 4}) { + const int B = 2, L = 1, S = 16; + mx::array mask = left_pad_mask({3, 0}, L, S); + compare(B, 8, 2, L, S, bits, "", mask); + compare(B, 4, 4, L, S, bits, "", mask); // and the no-GQA masked path + } +}