15 changes: 15 additions & 0 deletions CLAUDE.md
@@ -149,6 +149,21 @@ reference/.venv/bin/python reference/dump_ref.py
math.
- **Decode-with-cache vs full-recompute logits differ by fp16 accumulation
order** — compare argmax / exact tokens, not raw logits at tight tolerance.
- **Quantized KV (kv_bits 8|4) mirrors mlx-lm's QuantizedKVCache** — triplet
storage quantized at write time (`cache/kv_quant`), attention via the
hand-rolled `quantized_sdpa` (`model/sdpa`, a port of mlx_lm base.py; MLX has
no fused quantized SDPA kernel). Three traps: (1) the batched additive mask
must be reshaped `(B,1,N,T)→(B,1,1,N,T)` under GQA, and masked columns must be
**overridden** with `finfo(fp16).min`, never added — a fully-masked left-pad
row makes NaN that `+(-inf)` cannot cancel; (2) quantized matmuls are
**fusion-context-sensitive** (~1 logit shift between lazy and materialized
inputs — mlx-lm disagrees with itself across graph contexts), so the golden
gates are teacher-forced and margin-gated (`greedy_gaps_kvq*.npy`), never raw
exact-stream asserts; (3) both caches deliberately share the block-grow +
`slice_update` storage writer (`update_kv_components`) — buffer strides
affect kernel accumulation order. Engine-wide setting, default off;
vision/hybrid models are rejected at engine creation (no silent fp16
fallback).
- **Qwen3-VL interleaved M-RoPE can't use `fast::rope`** (it takes a 1D offset,
not 3D `(t,h,w)` positions). `Qwen3VLModel` hand-rolls a half-split rotation
with a per-frequency t/h/w selector; text tokens have `t==h==w` so it reduces to
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -56,6 +56,7 @@ add_library(mlxforge_core STATIC
src/core/gguf.cpp
src/core/model_source.cpp
src/model/decoder_model.cpp
src/model/sdpa.cpp
src/model/qwen3.cpp
src/model/qwen3_moe.cpp
src/model/qwen3_5.cpp
@@ -66,6 +67,7 @@ add_library(mlxforge_core STATIC
src/model/model_factory.cpp
src/cache/kv_cache.cpp
src/cache/batch_kv_cache.cpp
src/cache/kv_quant.cpp
src/cache/kv_budget.cpp
src/sample/sampler.cpp
src/sample/json_grammar.cpp
5 changes: 4 additions & 1 deletion README.md
@@ -51,7 +51,10 @@ the C-ABI / Node quickstart.
caller of `eval`/`async_eval`; exactly **one `async_eval` per decode step** over the
whole batch, with batch-size bucketing.
- **KV cache** — single-sequence and batched (`BatchKVCache`), left-padded, grown in
256-token blocks, with `filter` (eviction) / `merge` (admission).
256-token blocks, with `filter` (eviction) / `merge` (admission). Optional
**KV-cache quantization** (`--kv-bits 8|4`, default fp16): mlx-lm-matching quantized
storage + attention for ~1.9×/~3.6× less cache memory — including the active
continuous-decode batch, which no other MLX server quantizes.
- **Sampling as graph ops** — greedy, temperature, top-k, top-p (no host readback).
- **Embeddings** — `engine.embed` runs the decoder to its final hidden states, pools
(mean or last-token) and L2-normalizes. **Qwen3-Embedding** is first-class. Exposed
14 changes: 9 additions & 5 deletions apps/mlxforge.cpp
@@ -72,12 +72,13 @@ void print_help() {
" --max-ctx <N> max prompt length in tokens (default 8192)\n"
" --max-waiting <N> max queued requests (default 256)\n"
" --kv-budget <B> KV cache budget in bytes, 0 = unbounded (default 0)\n"
" --kv-bits <N> KV cache quantization: 0 = fp16, 8 or 4 (default 0)\n"
" -h, --help show this help and exit\n"
"\n"
"The model may be given via -m or the config file's \"model\" key.\n"
"Config precedence (low to high): defaults < config file < env vars < CLI flags.\n"
"Env vars: MLXFORGE_HOST, MLXFORGE_PORT, MLXFORGE_MAX_CTX, MLXFORGE_MAX_WAITING, "
"MLXFORGE_KV_BUDGET.");
"MLXFORGE_KV_BUDGET, MLXFORGE_KV_BITS.");
std::fflush(stdout);
}

@@ -126,8 +127,11 @@ int main(int argc, char** argv) {
// engine boundary; the server below is just one consumer of it.
std::unique_ptr<mlxforge::Engine> engine;
try {
engine = std::make_unique<mlxforge::Engine>(
mlxforge::EngineConfig{sc.model_dir, sc.max_waiting});
mlxforge::EngineConfig ec;
ec.model_spec = sc.model_dir;
ec.max_waiting = sc.max_waiting;
ec.kv_bits = sc.kv_bits;
engine = std::make_unique<mlxforge::Engine>(std::move(ec));
} catch (const std::exception& e) {
mlxforge::log::error("model error: {}", e.what());
return 2;
@@ -163,8 +167,8 @@ int main(int argc, char** argv) {
std::signal(SIGTERM, on_signal);

// Info log: server has started, print bind details and config bounds.
mlxforge::log::info("mlxforge serving on http://{}:{} (max_ctx={} max_waiting={})",
sc.host, sc.port, sc.max_ctx, sc.max_waiting);
mlxforge::log::info("mlxforge serving on http://{}:{} (max_ctx={} max_waiting={} kv_bits={})",
sc.host, sc.port, sc.max_ctx, sc.max_waiting, sc.kv_bits);

// Run the server's request loop (blocks until stop() is called).
server.listen(sc.host, sc.port);
27 changes: 20 additions & 7 deletions apps/mlxforge_cli.cpp
@@ -6,11 +6,12 @@
// mlxforge-cli dump-weights <dir>
// - Loads a model's weights from the supplied directory, prints key/shape/dtype for each tensor,
// asserts that all tensors are fp16, and reports the peak resident memory used.
// mlxforge-cli generate <model> <prompt> [max_tokens] [--logprobs [N]]
// mlxforge-cli generate <model> <prompt> [max_tokens] [--logprobs [N]] [--kv-bits N]
// - Runs greedy single-stream generation: pre-fills the prompt (as raw text using the chat template
// or as a pre-tokenized .npy of ids), then streams the detokenized text to stdout until EOS or
// max_tokens. With --logprobs [N], each emitted token's log-prob (and its N most-likely
// alternatives) is printed to stderr after generation; stdout stays the generated text.
// --kv-bits 8|4 stores the KV cache quantized (the manual harness for the quantized path).
// mlxforge-cli bench <model> [max_tokens] [runs]
// - Repeatable throughput benchmark over a fixed prompt: one discarded warmup run, then `runs`
// timed runs (defaults: max_tokens=128, runs=3) reporting time-to-first-token and decode tok/s.
@@ -181,7 +182,7 @@ std::string show_token(const std::string& s) {
// `top_logprobs` mirrors the engine knob: -1 = off; 0 = each token's own log-prob;
// N > 0 = also its N most-likely alternatives (printed to stderr after generation).
int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens,
int top_logprobs = -1) {
int top_logprobs = -1, int kv_bits = 0) {
// Resolve and load the model (GGUF file or safetensors dir; downloads if needed)
LoadedModel lm = load_for_inference(spec);
mlxforge::DecoderModel& model = *lm.model;
@@ -212,7 +213,7 @@ int run_generate(const std::string& spec, const std::string& prompt_arg, int max
std::string piece = detok.add(id);
std::fwrite(piece.data(), 1, piece.size(), stdout);
std::fflush(stdout);
}, top_logprobs);
}, top_logprobs, mlxforge::KVQuantConfig{kv_bits, 64});

// Output any final detokenized tail remaining in the streaming detokenizer
std::string tail = detok.finish();
@@ -370,25 +371,37 @@ int main(int argc, char** argv) {
if (argc < 4) {
std::fprintf(stderr,
"usage: mlxforge-cli generate <model_dir> <prompt_ids.npy> [max_tokens] "
"[--logprobs [N]]\n");
"[--logprobs [N]] [--kv-bits N]\n");
return 2;
}
// Positional [max_tokens] (default 64) plus an optional --logprobs [N] flag (N
// alternatives, default 0 = the chosen token's own log-prob only).
// Positional [max_tokens] (default 64), an optional --logprobs [N] flag (N
// alternatives, default 0 = the chosen token's own log-prob only), and an
// optional --kv-bits N (0 = fp16 cache, 8 or 4 = quantized).
int max_tokens = 64;
int top_logprobs = -1;
int kv_bits = 0;
for (int i = 4; i < argc; ++i) {
const std::string a = argv[i];
if (a == "--logprobs") {
if (i + 1 < argc && std::isdigit(static_cast<unsigned char>(argv[i + 1][0])))
top_logprobs = std::stoi(argv[++i]);
else
top_logprobs = 0;
} else if (a == "--kv-bits") {
if (i + 1 >= argc) {
std::fprintf(stderr, "error: --kv-bits needs a value (0, 4, or 8)\n");
return 2;
}
kv_bits = std::stoi(argv[++i]);
if (kv_bits != 0 && kv_bits != 4 && kv_bits != 8) {
std::fprintf(stderr, "error: --kv-bits must be 0, 4, or 8\n");
return 2;
}
} else {
max_tokens = std::stoi(a);
}
}
return run_generate(argv[2], argv[3], max_tokens, top_logprobs);
return run_generate(argv[2], argv[3], max_tokens, top_logprobs, kv_bits);
}
if (cmd == "image") {
// Vision-language generation: describe / answer about an image.
9 changes: 9 additions & 0 deletions bindings/node/index.d.ts
@@ -3,6 +3,15 @@
export interface EngineOptions {
/** Max queued requests before submit is rejected (default 256). */
maxWaiting?: number;
/**
* KV-cache quantization bits (engine-wide): 0 (default) keeps the cache
* dense fp16; 8 is near-lossless at ~1.9x less cache memory; 4 is ~3.6x.
* Unsupported models (vision-language, hybrid Qwen3.5) fail engine creation
* rather than silently falling back to fp16.
*/
kvBits?: 0 | 4 | 8;
/** Quantization group size (default 64; must divide the model's head_dim). */
kvGroupSize?: number;
}

export interface SamplingOptions {
9 changes: 7 additions & 2 deletions bindings/node/src/addon.cc
@@ -251,15 +251,20 @@ class EngineWrap : public Napi::ObjectWrap<EngineWrap> {
throw Napi::TypeError::New(env, "new Engine(spec, opts?) requires a model spec string");
std::string spec = info[0].As<Napi::String>().Utf8Value();

mlxforge_engine_opts opts = {};
mlxforge_engine_opts2 opts = {};
opts.struct_size = sizeof(opts);
if (info.Length() >= 2 && info[1].IsObject()) {
Napi::Object o = info[1].As<Napi::Object>();
if (o.Has("maxWaiting") && o.Get("maxWaiting").IsNumber())
opts.max_waiting = o.Get("maxWaiting").As<Napi::Number>().Int32Value();
if (o.Has("kvBits") && o.Get("kvBits").IsNumber())
opts.kv_bits = o.Get("kvBits").As<Napi::Number>().Int32Value();
if (o.Has("kvGroupSize") && o.Get("kvGroupSize").IsNumber())
opts.kv_group_size = o.Get("kvGroupSize").As<Napi::Number>().Int32Value();
}

char* err = nullptr;
eng_ = mlxforge_engine_create(spec.c_str(), &opts, &err);
eng_ = mlxforge_engine_create2(spec.c_str(), &opts, &err);
if (!eng_) {
std::string msg = err ? err : "failed to create engine";
mlxforge_string_free(err);
1 change: 1 addition & 0 deletions cmake/abi-baseline.txt
@@ -2,6 +2,7 @@ mlxforge_abi_version
mlxforge_embed
mlxforge_embed_ex
mlxforge_engine_create
mlxforge_engine_create2
mlxforge_engine_free
mlxforge_engine_model_name
mlxforge_engine_ready
1 change: 1 addition & 0 deletions doc/applications.md
@@ -57,6 +57,7 @@ with environment-variable fallbacks (`server/config`):
| `--max-ctx` | `MLXFORGE_MAX_CTX` | `8192` | Reject prompts longer than this → `400`. |
| `--max-waiting` | `MLXFORGE_MAX_WAITING` | `256` | Bounded waiting queue → `429` on overflow. |
| `--kv-budget` | `MLXFORGE_KV_BUDGET` | `0` (unbounded) | KV-memory admission budget in bytes. |
| `--kv-bits` | `MLXFORGE_KV_BITS` | `0` (fp16) | KV-cache quantization: `8` (~1.9× less cache memory, near-lossless) or `4` (~3.6×). Unsupported models (vision, hybrid Qwen3.5) fail startup rather than silently falling back. |

### Logging

37 changes: 36 additions & 1 deletion doc/architecture.md
@@ -171,6 +171,39 @@ of `B` sequences each reaching `max_len + max_new` tokens is refused/queued if i
would exceed the configured `--kv-budget`. Combined with the bounded waiting
queue (which returns `429` on overflow), this is the real OOM defence.

## KV-cache quantization

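`--kv-bits 8|4` (engine option `kv_bits`; default 0 = dense fp16) stores the KV
cache quantized, cutting its memory ~1.9× (8-bit, near-lossless) or ~3.6×
(4-bit). The port mirrors the design of mlx-lm's `QuantizedKVCache`; outputs
match at the token level rather than bit-for-bit (see **Gating**):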

- **Storage** (`cache/kv_quant`): each cached K/V tensor is the `mx::quantize`
triplet — packed uint32 words plus per-group fp16 scales and biases
(group size 64) — quantized **at write time**, per position, so prefill
chunking cannot change stored values. Both `KVCache` and `BatchKVCache` hold
per-layer component vectors (1 array dense, 3 quantized); all batch surgery
(`filter`/`merge`/`pad_dummies`, block growth) runs per component unchanged.
- **Attention** (`model/sdpa`): MLX has no fused quantized SDPA kernel, so
`quantized_sdpa` ports `mlx_lm/models/base.py` op-for-op — `quantized_matmul`
for the scores and the output, GQA via a `(B, n_kv, n_rep, L, D)` reshape,
precise softmax. `sdpa_with_cache` is the dispatch seam every model attention
call site uses (dense fast-kernel vs quantized path, by the cache's config).
- **Setting scope**: engine-wide, never per-request — the batched cache's
storage is physically shared across rows. Unsupported setups (vision-language
and hybrid Qwen3.5 models, which have no quantized golden reference yet;
group sizes that don't divide `head_dim`) **fail engine creation**; there is
no silent fp16 fallback.
- **Gating**: teacher-forced greedy walks against mlx-lm `QuantizedKVCache`
streams (Llama + Qwen3, 8- and 4-bit), asserting token equality at every step
whose reference top-2 margin clears the fusion-context noise (quantized
matmuls shift ~1 logit between lazy and materialized inputs, so bit-exact
cross-implementation gating is unsound); plus an exact batched-vs-single-
stream coherence gate.
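
The margin gate in the last bullet can be sketched in a few lines of Python.
This is an illustration only, not the project's C++ test: the function name
`margin_gated_mismatches`, the `noise_floor` value of 1.0, and the sample
streams are assumptions. It also assumes the candidate tokens were produced
teacher-forced, i.e. each step was computed from the reference prefix.

```python
from typing import List, Sequence


def margin_gated_mismatches(
    ref_tokens: Sequence[int],
    ref_gaps: Sequence[float],
    got_tokens: Sequence[int],
    noise_floor: float,
) -> List[int]:
    """Return the steps that count as real failures.

    A step fails only if the candidate token differs from the reference
    AND the reference top-2 logit margin is larger than the noise floor.
    Steps with a smaller margin are skipped, because a shift of about one
    logit could legitimately flip the argmax there.
    """
    if not (len(ref_tokens) == len(ref_gaps) == len(got_tokens)):
        raise ValueError("all three streams must have the same length")
    return [
        step
        for step, (ref, gap, got) in enumerate(zip(ref_tokens, ref_gaps, got_tokens))
        if gap > noise_floor and ref != got
    ]


if __name__ == "__main__":
    # Made-up values for illustration (not taken from the fixtures).
    ref = [5, 9, 2, 7]
    gaps = [3.2, 0.4, 2.5, 1.8]
    got = [5, 8, 2, 7]  # differs only at step 1, where the margin is 0.4
    print(margin_gated_mismatches(ref, gaps, got, noise_floor=1.0))
```

With these sample values the only disagreement sits at a low-margin step, so
the gate reports no failure; a disagreement at a step with a wide margin would
be reported.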

The per-token budget figure adjusts accordingly: a K-or-V head row is
`head_dim × bits/8` packed bytes plus an fp16 scale and an fp16 bias per group
(for `head_dim` 64 and group size 64: 68 B at 8-bit, 36 B at 4-bit, vs 128 B
at fp16).
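
The arithmetic behind those figures can be checked with a short Python sketch.
The helper name `kv_row_bytes` is hypothetical and the code is not the
project's budget implementation; it only restates the formula above (packed
values plus one fp16 scale and one fp16 bias per group).

```python
FP16_BYTES = 2  # bytes per fp16 value


def kv_row_bytes(head_dim: int, bits: int, group_size: int = 64) -> int:
    """Bytes for one K-or-V head row at a single token position.

    bits == 0 means the dense fp16 cache; 4 or 8 means packed quantized
    values plus one fp16 scale and one fp16 bias per group.
    """
    if bits == 0:
        return head_dim * FP16_BYTES
    if bits not in (4, 8):
        raise ValueError("bits must be 0, 4, or 8")
    if head_dim % group_size != 0:
        raise ValueError("group_size must divide head_dim")
    packed = head_dim * bits // 8          # quantized payload
    groups = head_dim // group_size        # one (scale, bias) pair per group
    return packed + groups * 2 * FP16_BYTES


if __name__ == "__main__":
    dense = kv_row_bytes(64, 0)
    for bits in (8, 4):
        row = kv_row_bytes(64, bits)
        print(f"{bits}-bit: {row} B per row, {dense / row:.2f}x smaller than fp16")
```

For `head_dim` 64 and group size 64 this gives 68 B and 36 B against 128 B, so
ratios of about 1.88 and 3.56, which round to the quoted ~1.9× and ~3.6×.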

## Module map

Source lives under `src/`, grouped by responsibility. Tests mirror the module
@@ -186,7 +219,9 @@ path under `tests/`.
| `model/` | The transformer: `DecoderModel` base (embedding, RMSNorm, RoPE, GQA SDPA, SwiGLU, LM head; fp16 and quantized paths; single-stream and batched forward) with `LlamaModel`/`Qwen3Model`/`Qwen3MoeModel` subclasses and a `create_model` factory. |
| `cache/kv_cache` | Single-sequence KV cache (the simplest prefill/decode split). |
| `cache/batch_kv_cache` | Batched, left-padded KV cache: `update_and_fetch`, `filter` (evict), `merge` (admit), `pad_dummies` (bucketing). |
| `cache/kv_budget` | KV memory projection / admission gate. |
| `cache/kv_quant` | Quantized-KV shared types (`KVQuantConfig`, triplets) + the block-grow component writer both caches use. |
| `cache/kv_budget` | KV memory projection / admission gate (fp16 and quantized accounting). |
| `model/sdpa` | Cache-aware SDPA dispatch: dense fast kernel vs the hand-rolled quantized path (mlx-lm port). |
| `sample/sampler` | greedy / temperature / top-k / top-p, all as MLX graph ops. |
| `scheduler/request` | The `Request` struct and the bounded, blocking `TokenQueue`. |
| `scheduler/scheduler` | The waiting queue + worker handoff (mutex + condition variable). |
4 changes: 4 additions & 0 deletions doc/embedding.md
@@ -91,6 +91,10 @@ typedef struct mlxforge_request mlxforge_request;
// Create one engine; it owns the GPU worker thread and the batching scheduler.
mlxforge_engine* eng = mlxforge_engine_create("mlx-community/Llama-3.2-1B-Instruct-4bit",
/*opts=*/NULL, &err);
// Or with extended options (ABI v6+): a quantized KV cache cuts the dominant
// growing allocation ~1.9x (8-bit, near-lossless) or ~3.6x (4-bit).
// mlxforge_engine_opts2 opts = { .struct_size = sizeof(opts), .kv_bits = 8 };
// eng = mlxforge_engine_create2(spec, &opts, &err);
while (!mlxforge_engine_ready(eng)) { /* model still loading on the worker thread */ }

// Submit many requests concurrently — they share one batched engine. This is the
33 changes: 33 additions & 0 deletions reference/dump_ref.py
@@ -192,6 +192,33 @@ def _dump_greedy(model, save, prompt_ids):
save("greedy_tokens", np.array(greedy, dtype=np.int32))


def _dump_greedy_quantized(model, save, prompt_ids, bits, group_size=64):
"""Greedy continuation with mlx-lm's QuantizedKVCache (prefill + cached
decode). The oracle for the C++ quantized-KV path, which ports the same
triplet storage and quantized_matmul SDPA.

Besides the tokens, the per-step top-2 logit margin is dumped: 4-bit
quantized matmuls at head_dim 128 are fusion-context-sensitive (the same
mlx-lm function on the same values shifts by ~1 logit depending on whether
its inputs are materialized or lazy), so the C++ gate for such combinations
is teacher-forced and asserts token equality only where the margin exceeds
that noise — exact gating would be asserting kernel fusion, not math."""
from mlx_lm.models.cache import QuantizedKVCache

cache = [QuantizedKVCache(group_size=group_size, bits=bits) for _ in model.layers]
logits = model(mx.array(prompt_ids, dtype=mx.int32)[None], cache=cache)[:, -1, :]
greedy, gaps = [], []
for i in range(GREEDY_MAX_NEW):
top2 = mx.sort(logits.astype(mx.float32), axis=-1)[0, -2:]
gaps.append(float(top2[1] - top2[0]))
nxt = int(mx.argmax(logits, axis=-1).item())
greedy.append(nxt)
if i + 1 < GREEDY_MAX_NEW:
logits = model(mx.array([[nxt]], dtype=mx.int32), cache=cache)[:, -1, :]
save(f"greedy_tokens_kvq{bits}", np.array(greedy, dtype=np.int32))
save(f"greedy_gaps_kvq{bits}", np.array(gaps, dtype=np.float32))


def _write_manifest(fixtures_dir, manifest, tok):
manifest["eos_token_ids"] = sorted(int(x) for x in tok.eos_token_ids)
with open(os.path.join(fixtures_dir, "manifest.json"), "w") as f:
@@ -638,6 +665,12 @@ def save(name, arr):

# --- Greedy token stream (full-recompute reference; gates the greedy decode stream) ---
_dump_greedy(model, save, prompt_id_lists[0])

# Quantized-KV greedy streams (8- and 4-bit): gate the C++ quantized cache +
# quantized SDPA end to end against mlx-lm's QuantizedKVCache.
for bits in (8, 4):
_dump_greedy_quantized(model, save, prompt_id_lists[0], bits)

_write_manifest(FIXTURES_DIR, manifest, tok)


Binary file added reference/fixtures/greedy_gaps_kvq4.npy
Binary file not shown.
Binary file added reference/fixtures/greedy_gaps_kvq8.npy
Binary file not shown.
Binary file added reference/fixtures/greedy_tokens_kvq4.npy
Binary file not shown.
Binary file added reference/fixtures/greedy_tokens_kvq8.npy
Binary file not shown.
24 changes: 24 additions & 0 deletions reference/fixtures/manifest.json
@@ -123,6 +123,30 @@
20
],
"dtype": "int32"
},
"greedy_tokens_kvq8": {
"shape": [
20
],
"dtype": "int32"
},
"greedy_gaps_kvq8": {
"shape": [
20
],
"dtype": "float32"
},
"greedy_tokens_kvq4": {
"shape": [
20
],
"dtype": "int32"
},
"greedy_gaps_kvq4": {
"shape": [
20
],
"dtype": "float32"
}
},
"eos_token_ids": [
Binary file added reference/fixtures_qwen3/greedy_gaps_kvq4.npy
Binary file not shown.
Binary file added reference/fixtures_qwen3/greedy_gaps_kvq8.npy
Binary file not shown.
Binary file added reference/fixtures_qwen3/greedy_tokens_kvq4.npy
Binary file not shown.
Binary file added reference/fixtures_qwen3/greedy_tokens_kvq8.npy
Binary file not shown.
24 changes: 24 additions & 0 deletions reference/fixtures_qwen3/manifest.json
@@ -122,6 +122,30 @@
20
],
"dtype": "int32"
},
"greedy_tokens_kvq8": {
"shape": [
20
],
"dtype": "int32"
},
"greedy_gaps_kvq8": {
"shape": [
20
],
"dtype": "float32"
},
"greedy_tokens_kvq4": {
"shape": [
20
],
"dtype": "int32"
},
"greedy_gaps_kvq4": {
"shape": [
20
],
"dtype": "float32"
}
},
"eos_token_ids": [