diff --git a/doc/whitepaper/mlxforge-whitepaper.tex b/doc/whitepaper/mlxforge-whitepaper.tex index 267e3b5..e8599f8 100644 --- a/doc/whitepaper/mlxforge-whitepaper.tex +++ b/doc/whitepaper/mlxforge-whitepaper.tex @@ -92,7 +92,7 @@ \vspace{2cm} {\large Technical White Paper\par} \vspace{0.3cm} - {\normalsize Engine version: continuous-batching core, C ABI v5\par} + {\normalsize Engine version: continuous-batching core, C ABI v7\par} \vfill {\small This document consolidates the design of the \code{mlxforge} engine: the product thesis, the threading and continuous-batching model, the mathematics of @@ -111,7 +111,12 @@ the Metal backend. It serves LLaMA-family decoder models (Llama-3.2, Qwen3 dense/MoE, Qwen3.5 hybrid) and the Qwen3-VL vision-language model, with \emph{continuous batching}: many concurrent requests share one resident model and -one GPU worker, dynamically admitted and evicted from an active batch. +one GPU worker, dynamically admitted and evicted from an active batch. The KV +cache can be stored quantized (8- or 4-bit, mirroring \code{mlx-lm}'s +\code{QuantizedKVCache}), and an opt-in prefix cache pools immutable KV blocks +keyed by a salted chain hash of the token prefix --- with an SSD spill tier that +survives engine restarts --- so requests sharing a system prompt or a +conversation history skip recomputing the shared span. The engine occupies a specific gap in the Apple MLX ecosystem. Single-stream libraries (\code{node-mlx}, Apple's \code{MLXLLM}) cannot batch; batched servers @@ -243,7 +248,9 @@ \section{Module map} chat-template rendering; streaming detokenization. \\ \code{model/} & The transformer: \code{DecoderModel} base plus family subclasses; \code{vision/} holds the ViT encoder. \\ -\code{cache/} & Single-sequence and batched KV caches; the KV memory budget. \\ +\code{cache/} & Single-sequence and batched KV caches; quantized KV storage; + the prefix-cache block pool and its SSD spill tier; the KV + memory budget. \\ \code{sample/} & Sampling (greedy/temperature/top-k/top-p/min-p/penalties), log-probabilities, and JSON-grammar constrained decoding. \\ \code{scheduler/} & The thread-safe request queue and the \code{Request} struct @@ -479,7 +486,11 @@ \section{Memory admission gate} \end{equation} where the factor 2 accounts for keys and values. For Llama-3.2-1B ($n_{\text{layers}}=16$, $n_{\text{kv\,heads}}=8$, $d_{\text{head}}=64$) this is -$2\cdot16\cdot8\cdot64\cdot2 = 32{,}768$ bytes ($32$\,KiB) per token. A batch whose +$2\cdot16\cdot8\cdot64\cdot2 = 32{,}768$ bytes ($32$\,KiB) per token. Under KV-cache +quantization (Section~\ref{sec:kvquant}) the per-head row cost changes to +$d\cdot b/8$ packed bytes plus an fp16 scale and bias per group: with $d=64$ and +group size 64, 68 bytes at 8-bit and 36 bytes at 4-bit versus 128 bytes fp16, +and the budget projection accounts for it. A batch whose projected footprint would exceed the configured budget (\code{src/cache/kv\_budget}) is refused. Together with the bounded waiting queue (which returns \code{429} on overflow), this keeps the engine from running out of @@ -689,7 +700,9 @@ \section{Batched cache: left-padded and contiguous} serves both efficiency and multi-user serving. It is \emph{left-padded and contiguous}, not paged: MLX's C++ surface has no paged-attention primitive, and SDPA wants contiguous K/V. At the $\sim$1B parameter scale the padding waste is -acceptable, and the contiguous layout keeps the kernel simple. The cache tracks: +acceptable, and the contiguous layout keeps the kernel simple. (Prefix sharing is +layered on top through a block pool rather than through paging; see +Section~\ref{sec:prefixcache}.) The cache tracks: \begin{itemize}[nosep] \item \code{idx}: the populated sequence length (physical write position); \item \code{offset}: a per-row $(B,)$ RoPE position, which can differ from @@ -722,6 +735,131 @@ \section{Batch surgery} (Chapter~\ref{ch:batching}). \end{itemize} +\section{Quantized KV storage} +\label{sec:kvquant} +The engine option \code{kv\_bits} (default 0, i.e.\ dense fp16; 8 or 4 to enable) +stores the KV cache quantized, cutting its memory by $\sim$1.9$\times$ at 8-bit +(near-lossless) or $\sim$3.6$\times$ at 4-bit. The implementation +(\code{src/cache/kv\_quant}) deliberately mirrors \code{mlx-lm}'s +\code{QuantizedKVCache}: + +\begin{itemize}[nosep] + \item \textbf{Triplet storage, quantized at write time.} Each cached K or V + tensor is the 3-tuple \code{mx::quantize} produces: packed + \code{uint32} words plus per-group fp16 scales and biases (group size 64, + which must divide $d$). Quantization happens per position as tokens are + written, so prefill chunking can never change the stored values. Both + \code{KVCache} and \code{BatchKVCache} hold per-layer \emph{component + vectors} (one array dense, three quantized); all batch surgery + (\code{filter}/\code{merge}/\code{pad\_dummies}, block growth) runs per + component unchanged. + \item \textbf{Hand-rolled quantized attention.} MLX has no fused quantized SDPA + kernel, so \code{quantized\_sdpa} (\code{src/model/sdpa}) ports + \code{mlx\_lm/models/base.py} op-for-op: \code{quantized\_matmul} for the + scores and the output, GQA via a $(B, n_{kv}, n_{\text{rep}}, L, d)$ + reshape, precise softmax. \code{sdpa\_with\_cache} is the dispatch seam + every model attention call site uses, selecting the dense fast kernel or + the quantized path by the cache's configuration. + \item \textbf{Engine-wide scope, no silent fallback.} The batched cache's + storage is physically shared across rows, so the setting cannot be + per-request. Unsupported setups (vision-language and hybrid Qwen3.5 + models, which have no quantized golden reference yet; group sizes that do + not divide $d$) fail engine creation rather than falling back to fp16. +\end{itemize} + +Three numerical traps shaped the implementation. First, under GQA the batched +additive mask must be reshaped $(B,1,N,T)\to(B,1,1,N,T)$, and masked columns +must be \emph{overridden} with $\min(\text{fp16})$, never added: a fully-masked +left-pad row produces NaN that adding $-\infty$ cannot cancel. Second, quantized +matmuls are \emph{fusion-context-sensitive}: the same matmul shifts by $\sim$1 +logit between lazy and materialized inputs, and \code{mlx-lm} disagrees with +itself across graph contexts, so bit-exact cross-implementation gating is +unsound. The golden gates are therefore teacher-forced and \emph{margin-gated}: +token equality is asserted at every step whose reference top-2 logit margin +clears the fusion-context noise, plus an exact batched-versus-single-stream +coherence gate. Third, both caches share the block-grow $+$ +\code{slice\_update} storage writer (\code{update\_kv\_components}) because +buffer shapes and strides affect kernel accumulation order; the exact-token +gates depend on the layouts matching \code{mlx-lm} bit-for-bit. + +\section{Prefix cache: block-pool KV storage} +\label{sec:prefixcache} +The engine option \code{prefix\_cache} (default off) reuses K/V across requests +that share a token prefix --- the shared-system-prompt and multi-turn +conversation shapes. On a 2048-token shared prefix the warm time-to-first-token +drops $\sim$20$\times$ (measured by \code{mlxforge-cli bench-prefix}); decode +throughput is unchanged. + +The design is \emph{gather-on-admit, not paged attention}. vLLM's +PagedAttention~\citep{kwon2023vllm} runs attention directly over scattered +pages, but MLX has no paged SDPA kernel and \code{mlx-lm} has no paged reference +to gate one against, so the decode batch stays the contiguous left-padded +\code{BatchKVCache} of this chapter. The \emph{pages} live in a pool instead, +and matched pages are copied into a row's cache on admission --- cheap under +Apple Silicon's unified memory next to the prefill they replace. + +\begin{itemize}[nosep] + \item \textbf{Block pool.} \code{BlockPool} (\code{src/cache/block\_pool}) + holds immutable \code{KVBlock}s of \code{kv\_block\_size} tokens (default + 256; all layers, the same dense or quantized component vectors the caches + store), LRU-evicted under a configurable byte budget. Each block is keyed + by a \emph{chain hash}: an FNV-1a-64 hash of the block's own token ids + chained onto the previous block's key, so a key identifies the + \emph{entire} token prefix up to the block's end --- two prompts share a + block only if they share every token before it. Keys are salted with the + model fingerprint and storage configuration, so a persisted block can + never cross models or quantization settings. + \item \textbf{Matching and admission.} \code{PrefixCache} + (\code{src/cache/prefix\_cache}) matches a prompt to its longest chain of + consecutive cached full blocks, clamped to $\text{prompt\_len}-1$ tokens + so the admission still produces next-token logits (the last prompt token + is always recomputed --- the same rule vLLM and + SGLang~\citep{zheng2024sglang} use). Matched blocks seed a batch-1 cache + via \code{BatchKVCache::from\_prefix}, written through the standard + block-grow storage writer so the buffer layout matches a cold prefill + (strides are load-bearing, per Section~\ref{sec:kvquant}); only the + suffix is prefilled, and the row then merges into the decode batch like + any single-row admission. + \item \textbf{Harvest is prompt-only.} When a row finishes, only its + \emph{prompt} span is sealed into the pool --- never decode-produced K/V. + Decode-with-cache K/V differs from a recompute by fp16 accumulation order + (the decode-versus-recompute gap of Chapter~\ref{ch:testing}) and + demonstrably flips later greedy choices; prefill-produced K/V is the + proven exact-stable class, so pooling only it keeps the feature's gate + (warm $=$ cold, token-exact) sound. Multi-turn reuse still converges: + the next turn's prompt contains the prior answer as text, so its + (prefix-seeded) prefill recomputes that span once and pools it. Harvested + slices are materialized (\code{mx::contiguous} $+$ eval) so the pool + never pins the batch cache's buffers, and multimodal rows are never + harvested or matched (a token-id hash cannot identify image content or + 3D positions). +\end{itemize} + +Like \code{kv\_bits}, the setting is engine-wide (the pool stores one storage +layout), and hybrid and vision-language models reject it at engine creation. + +\section{SSD spill tier} +\label{sec:spill} +An optional spill directory (\code{kv\_spill\_dir}) adds a second cache level +under the RAM pool (\code{src/cache/block\_store}). Blocks LRU-evicted from RAM +are serialized and written as one file per block (\code{.kvb}, +created \code{0600} since the cache holds conversation content, written +tmp-then-rename); a pool miss revives the file synchronously --- an SSD read of +a few megabytes replaces a far more expensive prefill. The directory is +rescanned at construction, so the prefix cache survives engine restarts, and an +on-disk byte budget LRU-deletes files beyond it. + +The threading split follows the thread-bound-arrays rule +(Chapter~\ref{ch:threading}): \code{BlockStore} itself never touches MLX arrays +--- its asynchronous writer thread and file index handle only raw byte buffers, +while the array$\leftrightarrow$bytes conversions run on the worker thread, +which owns every pooled array. The writer keeps a queued block visible to +\code{get()}/\code{contains()} until its file lands. The versioned on-disk +format embeds the salt, verified on load (any mismatch or truncation is treated +as a plain miss), and its serialize order --- per layer, K components then V --- +is gated by an exact-token spill test, because an order mismatch produced +silent garbage. + % =========================================================================== \chapter{Sampling and Constrained Decoding} \label{ch:sampling} @@ -941,6 +1079,9 @@ \chapter{Quantization} transparently. \end{itemize} +This chapter covers \emph{weight} quantization; the KV cache can independently be +stored quantized at 8 or 4 bits (Section~\ref{sec:kvquant}). + % =========================================================================== \chapter{Tokenizers} \label{ch:tok} @@ -995,7 +1136,7 @@ \section{The boundary} \end{itemize} \section{Append-only versioning} -\code{MLXFORGE\_ABI\_VERSION} is currently 5. The surface is append-only; each +\code{MLXFORGE\_ABI\_VERSION} is currently 7. The surface is append-only; each version added capability without removing symbols (Table~\ref{tab:abi}). The guard \code{scripts/check-abi.sh} enforces two invariants against \code{cmake/abi-baseline.txt}: the baseline symbols remain present (no breaking @@ -1016,6 +1157,11 @@ \section{Append-only versioning} v3 & \code{mlxforge\_submit\_image} (single image). \\ v4 & \code{mlxforge\_image} + \code{mlxforge\_submit\_images} ($N$ images). \\ v5 & \code{mlxforge\_sampling.logprobs} + \code{mlxforge\_request\_logprobs}. \\ +v6 & \code{mlxforge\_engine\_create2} + \code{mlxforge\_engine\_opts2} + (KV-cache quantization: \code{kv\_bits}, \code{kv\_group\_size}). \\ +v7 & \code{mlxforge\_engine\_opts2} prefix-cache fields (\code{prefix\_cache}, + \code{kv\_block\_size}, \code{kv\_pool\_bytes}, \code{kv\_spill\_dir}, + \code{kv\_spill\_bytes}), appended struct-size-gated. \\ \bottomrule \end{tabular} \end{table} @@ -1113,12 +1259,16 @@ \section{Two test tiers} A green \code{ctest} without the model present only exercised the pure-logic units; the numerical and scheduler paths require the model to be downloaded. -\section{Two comparison modes} +\section{Comparison modes} \begin{itemize}[nosep] \item \code{assert\_close}: elementwise allclose at fp16 relative tolerance $\sim$1e-2, comparing in fp32 to avoid rounding in the comparison itself, reporting the first divergent coordinate. \item \code{assert\_tokens\_equal}: exact token-sequence equality. + \item Margin-gated teacher-forced walks (quantized KV): token equality asserted + only at steps whose reference top-2 logit margin clears the + fusion-context noise of quantized matmuls, since \code{mlx-lm} itself is + not bit-stable across graph contexts (Section~\ref{sec:kvquant}). \end{itemize} Decode-with-cache and full-recompute logits differ by fp16 accumulation order, so those paths are compared by $\arg\max$ / exact tokens, not by raw logits at tight @@ -1126,6 +1276,15 @@ \section{Two comparison modes} \code{dump\_ref.py} to emit the intermediate tensor and assert against it. That is how the front-half embedding/post-norm/RoPE'd-$Q/K$ bugs were originally found. +\section{Equivalence gates without new fixtures} +The prefix cache needs no new \code{mlx-lm} fixtures: the cold path is already +golden-gated, and prefix reuse is an engine-internal \emph{equivalence} property. +The gate is warm $=$ cold --- reuse may change speed, never tokens --- enforced +exactly (including through an SSD spill and reload, and across an engine +restart) by \code{tests/scheduler/prefix\_reuse\_test.cpp} and +\code{tests/scheduler/prefix\_spill\_test.cpp}. The prompt-only harvest rule of +Section~\ref{sec:prefixcache} is what keeps this gate sound. + \section{Hardening} Beyond correctness gates, the C ABI has fuzz tests (random/hostile sampling params), endurance tests (long-running stress), and the ABI guard described in @@ -1151,7 +1310,8 @@ \section{HTTP server} \section{CLI} The CLI (\code{apps/mlxforge\_cli.cpp}) is the golden-reference and weight-inspection smoke test, with subcommands: \code{generate} (single-stream -greedy), \code{bench} (TTFT and decode tokens/s), \code{embed} (pooled embeddings), +greedy), \code{bench} (TTFT and decode tokens/s), \code{bench-prefix} (cold- +versus warm-prefix TTFT for the prefix cache), \code{embed} (pooled embeddings), and \code{dump-weights} (every tensor's shape/dtype, fp16 assertion, peak memory). % =========================================================================== @@ -1164,9 +1324,13 @@ \chapter{Conclusion and Future Work} are thread-bound, so a single worker owns the GPU. The discipline: silent numerical error is the enemy, so every sensitive stage is golden-gated. -A few directions remain open. Paged attention (once an MLX primitive exists) would -remove the left-padding waste and enable prefix sharing. Per-row 3D positions in the -batched cache would let vision \emph{prefill} batch too, not just decode. And new +A few directions remain open. Prefix sharing already exists via the +gather-on-admit block pool (Section~\ref{sec:prefixcache}); a true paged +attention (once an MLX primitive exists) would additionally remove the +left-padding waste and the admission-time gather copies. Per-row 3D positions in +the batched cache would let vision \emph{prefill} batch too, not just decode. +Extending the quantized-KV and prefix-cache golden gates to the hybrid and +vision-language families would lift their engine-creation rejections. And new model families can slot in behind the same \code{DecoderModel} hooks. % =========================================================================== @@ -1208,6 +1372,14 @@ \chapter{Glossary} \item[SPSC] Single-producer single-consumer (the bounded token queue). \item[TTFT] Time to first token. \item[GGUF] The llama.cpp universal model file format. + \item[Quantized KV] KV cache stored as \code{mx::quantize} triplets (packed + words + per-group scales/biases), 8- or 4-bit, quantized at write time. + \item[Chain hash] A block key hashing the block's own token ids onto the + previous block's key, identifying the whole prefix up to the block's end. + \item[Block pool] LRU pool of immutable fixed-size KV blocks backing the + prefix cache. + \item[Spill tier] SSD second level under the block pool; one salted, versioned + file per evicted block, rescanned at startup. \end{description} % =========================================================================== diff --git a/doc/whitepaper/references.bib b/doc/whitepaper/references.bib index 1e303bc..51775c2 100644 --- a/doc/whitepaper/references.bib +++ b/doc/whitepaper/references.bib @@ -173,6 +173,13 @@ @inproceedings{yu2022orca year = {2022} } +@inproceedings{zheng2024sglang, + title = {SGLang: Efficient Execution of Structured Language Model Programs}, + author = {Zheng, Lianmin and Yin, Liangsheng and Xie, Zhiqiang and Sun, Chuyue and Huang, Jeff and Yu, Cody Hao and Cao, Shiyi and Kozyrakis, Christos and Stoica, Ion and Gonzalez, Joseph E. and Barrett, Clark and Sheng, Ying}, + booktitle = {Advances in Neural Information Processing Systems (NeurIPS)}, + year = {2024} +} + % --------------------------------------------------------------------------- % Tokenization % --------------------------------------------------------------------------- diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 93c2d36..b9a5e1b 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -6,6 +6,7 @@ #include #include "cache/block_store.h" +#include "core/env.h" #include "core/logging.h" #include "model/qwen3_vl.h" #include "model/vision/vit.h" @@ -61,7 +62,11 @@ bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) { Worker::Worker(ModelFactory factory, Scheduler* scheduler, const Tokenizer* tok, KVQuantConfig kv_quant, PrefixCacheConfig prefix) : factory_(std::move(factory)), sched_(scheduler), tok_(tok), kv_quant_(kv_quant), - prefix_cfg_(prefix) {} + prefix_cfg_(prefix), + prefill_chunk_(static_cast(env_long("MLXFORGE_PREFILL_CHUNK", 0))) { + if (prefill_chunk_ > 0) + log::info("worker: EXPERIMENTAL interleaved prefill on (chunk={} tokens)", prefill_chunk_); +} Worker::~Worker() { stop(); } @@ -226,15 +231,17 @@ void Worker::run() { while (true) { std::vector> incoming; - if (reqs_.empty()) { + if (reqs_.empty() && !pending_) { auto r = sched_->next_waiting(); // block until work or stop+drained if (!r) break; incoming.push_back(r); auto more = sched_->take_waiting(kPrefillBatchSize - 1); incoming.insert(incoming.end(), more.begin(), more.end()); - } else { + } else if (!pending_) { incoming = sched_->take_waiting(kPrefillBatchSize); // non-blocking top-up } + // With a prefill in flight, no new admissions are taken: one chunk advances + // per iteration, with a decode step in between (interleaved mode only). try { if (!incoming.empty()) { @@ -249,8 +256,15 @@ void Worker::run() { else if (r->is_multimodal()) admit_multimodal(r); else gen.push_back(std::move(r)); } - if (!gen.empty()) admit(gen); + // Interleaved mode hands cold admissions to the chunked state machine; + // with the prefix cache on (heterogeneous warm suffixes) it falls back + // to the monolithic path. + if (!gen.empty()) { + if (prefill_chunk_ > 0 && !prefix_) start_chunked_prefill(gen); + else admit(gen); + } } + if (pending_) advance_chunked_prefill(); evict_finished(); // a row may finish on its very first token if (reqs_.empty()) continue; @@ -317,6 +331,54 @@ void Worker::admit(const std::vector>& incoming) { } } +void Worker::start_chunked_prefill(const std::vector>& cold) { + // Same left-padding as prefill() (batching.cpp); only the chunk loop is + // spread across worker iterations so decode steps can interleave. + const int B = static_cast(cold.size()); + int p_max = 0; + for (const auto& r : cold) p_max = std::max(p_max, static_cast(r->prompt_ids.size())); + + std::vector left_padding(B); + std::vector padded(static_cast(B) * p_max, 0); + for (int b = 0; b < B; ++b) { + const auto& ids = cold[b]->prompt_ids; + const int pad = p_max - static_cast(ids.size()); + left_padding[b] = pad; + for (size_t j = 0; j < ids.size(); ++j) padded[b * p_max + pad + j] = ids[j]; + } + + auto pending = std::make_unique(PendingPrefill{ + cold, mx::array(padded.data(), {B, p_max}, mx::int32), + std::make_unique(model_->config().n_layers, left_padding, kv_quant_), + p_max, /*pos=*/0}); + pending_ = std::move(pending); + log::debug("worker: chunked prefill started ({} rows, {} tokens, chunk={})", B, p_max, + prefill_chunk_); +} + +void Worker::advance_chunked_prefill() { + PendingPrefill& p = *pending_; + const int B = static_cast(p.reqs.size()); + const int n = std::min(prefill_chunk_, p.p_max - p.pos); + mx::array chunk = mx::slice(p.tokens, {0, p.pos}, {B, p.pos + n}); + mx::array logits = model_->forward(chunk, *p.cache); + p.cache->eval_state(); // same per-chunk materialization as prefill() + p.pos += n; + if (p.pos < p.p_max) return; + + // Final chunk: every row's last real token is at p_max-1, the last column. + const int n_last = logits.shape()[1]; + const int vocab = logits.shape()[2]; + mx::array last = + mx::reshape(mx::slice(logits, {0, n_last - 1, 0}, {B, n_last, vocab}), {B, vocab}); + mx::eval(last); + + if (!cache_) cache_ = std::move(p.cache); + else cache_->merge(*p.cache); + register_rows(p.reqs, last); + pending_.reset(); +} + void Worker::register_rows(const std::vector>& incoming, const mx::array& last_logits) { // Register the new rows before sampling so sample_rows() can read their params, diff --git a/src/runtime/worker.h b/src/runtime/worker.h index 081be28..cb347f3 100644 --- a/src/runtime/worker.h +++ b/src/runtime/worker.h @@ -86,6 +86,21 @@ class Worker { // push each row's token, marking finished rows. void decode_step(); + // EXPERIMENTAL chunked-prefill interleaving (MLXFORGE_PREFILL_CHUNK > 0, + // prefix cache off): a cold admission's prefill advances one chunk per loop + // iteration with a decode step in between, so in-flight rows keep producing + // tokens during long or queued prefills instead of stalling completely. + // Off by default (0): admissions prefill monolithically, exactly as before. + struct PendingPrefill { + std::vector> reqs; + mx::array tokens; // (B, p_max) left-padded prompt ids + std::unique_ptr cache; + int p_max = 0; + int pos = 0; // prompt tokens consumed so far + }; + void start_chunked_prefill(const std::vector>& cold); + void advance_chunked_prefill(); // one chunk; merges + registers rows when done + // Result of sampling the active batch in one graph: the chosen tokens, plus — // for the rows that requested log-probs (params.top_logprobs >= 0) — their // per-row log-prob arrays. All are built into the same graph so one async_eval @@ -156,6 +171,10 @@ class Worker { std::vector> history_; // prompt+generated ids per row (penalties) std::vector rng_keys_; // per-row RNG key, advanced each step + // Interleaved-prefill state (worker thread only; null when idle or feature off). + std::unique_ptr pending_; + int prefill_chunk_ = 0; // from MLXFORGE_PREFILL_CHUNK; 0 = monolithic admits + std::atomic decode_steps_{0}; std::atomic ready_{false};