From babce0aff209ca20b3d67df10360fc444e2e76ff Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 22:10:03 +0100 Subject: [PATCH 1/6] server: full multi-turn history + prompt_tokens for image requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix the two multimodal-serving caveats by separating prompt construction (CPU: template + tokenize, needs only the image dimensions) from generation (worker: decode + ViT + generate): - image_info(): read an encoded image's H×W without decoding (stbi_info, no MLX), so a non-worker thread can size the placeholder run. - image_token_count(h, w, cfg): the <|image_pad|> count from dimensions alone (smart_resize grid / merge²). - generate_multimodal(): new core that generates from an already-templated prompt_ids + raw image (preprocess → ViT → assert the placeholder count matches the merged-patch count → M-RoPE → decode). generate_from_image is now a single-turn wrapper that builds prompt_ids and delegates. - worker.handle_multimodal: use req.prompt_ids when set (server full history), else mm_text (C ABI / CLI single turn). - http_server.make_request: for an image request, render the FULL chat history (system + prior turns), sizing the placeholder run from the image dimensions and attaching it to the last user turn -> req.prompt_ids (+ raw bytes). Errors if the model has no vision tower. This restores the system prompt + multi-turn context on the vision path and makes usage.prompt_tokens correct (it was 0). Verified live: a system + user-with-image request obeys the system instruction and reports prompt_tokens 95. Unit-gated (image_info, image_token_count); full suite green (191). 
Co-Authored-By: Claude Opus 4.8 --- src/runtime/multimodal_stream.cpp | 45 ++++++++++++++++++++++--------- src/runtime/multimodal_stream.h | 12 +++++++++ src/runtime/worker.cpp | 26 ++++++++++++------ src/server/http_server.cpp | 25 ++++++++++++----- src/vision/image_decode.cpp | 11 ++++++++ src/vision/image_decode.h | 6 +++++ src/vision/preprocess.cpp | 7 +++++ src/vision/preprocess.h | 6 +++++ tests/vision/preprocess_test.cpp | 25 +++++++++++++++++ 9 files changed, 136 insertions(+), 27 deletions(-) diff --git a/src/runtime/multimodal_stream.cpp b/src/runtime/multimodal_stream.cpp index db2a79b..b1d019b 100644 --- a/src/runtime/multimodal_stream.cpp +++ b/src/runtime/multimodal_stream.cpp @@ -2,6 +2,7 @@ #include #include +#include #include "cache/kv_cache.h" #include "sample/sampler.h" @@ -67,10 +68,9 @@ GenerateResult greedy_generate_multimodal(const Qwen3VLModel& model, return result; } -GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, - const Tokenizer& tokenizer, const std::string& user_text, - const mx::array& image_rgb, int max_tokens, - const std::vector& eos_ids, +GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, + const std::vector& prompt_ids, const mx::array& image_rgb, + int max_tokens, const std::vector& eos_ids, const std::function& on_token, const PreprocessConfig* pcfg) { const ModelConfig& cfg = model.config(); @@ -81,19 +81,38 @@ GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& mx::array grid(pre.grid_thw.data(), {1, 3}, mx::int32); VitEncoder::Output v = vit.forward(pre.pixel_values, grid); - // Prompt: one image whose placeholder count is the collapsed patch count. - const int merge = cfg.vision->merge_unit(); - const int image_tokens = pre.grid_thw[0] * pre.grid_thw[1] * pre.grid_thw[2] / merge; + // The prompt's placeholder run must match the merged-patch count, or the + // feature scatter (merge_image_features) misaligns. 
+ const int merged = pre.grid_thw[0] * pre.grid_thw[1] * pre.grid_thw[2] / cfg.vision->merge_unit(); + const int pads = + static_cast(std::count(prompt_ids.begin(), prompt_ids.end(), cfg.image_token_id)); + if (pads != merged) { + throw std::runtime_error("multimodal prompt has " + std::to_string(pads) + + " image placeholder(s) but the image yields " + std::to_string(merged)); + } + + mx::array pos = mrope_position_ids(prompt_ids, {pre.grid_thw}, cfg); + return greedy_generate_multimodal(model, prompt_ids, v.hidden, v.deepstack, pos, max_tokens, + eos_ids, on_token); +} + +GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, + const Tokenizer& tokenizer, const std::string& user_text, + const mx::array& image_rgb, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token, + const PreprocessConfig* pcfg) { + // Single-turn convenience: size the placeholder run from the image dimensions + // (CPU math), render a one-user-message prompt, then generate. The full chat + // history is handled by the caller building prompt_ids for generate_multimodal. + PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*model.config().vision); + const int n = image_token_count(image_rgb.shape()[0], image_rgb.shape()[1], pc); Tokenizer::Message msg; msg.role = "user"; msg.content = user_text; - msg.image_token_counts = {image_tokens}; + msg.image_token_counts = {n}; std::vector ids = tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); - - // M-RoPE positions, then generate. 
- mx::array pos = mrope_position_ids(ids, {pre.grid_thw}, cfg); - return greedy_generate_multimodal(model, ids, v.hidden, v.deepstack, pos, max_tokens, eos_ids, - on_token); + return generate_multimodal(model, vit, ids, image_rgb, max_tokens, eos_ids, on_token, &pc); } } // namespace mlxforge diff --git a/src/runtime/multimodal_stream.h b/src/runtime/multimodal_stream.h index 12c926e..f854996 100644 --- a/src/runtime/multimodal_stream.h +++ b/src/runtime/multimodal_stream.h @@ -35,6 +35,18 @@ GenerateResult greedy_generate_multimodal(const Qwen3VLModel& model, const std::vector& eos_ids, const std::function& on_token = {}); +// Generate from an already-templated multimodal prompt: `prompt_ids` must already +// contain the image placeholder run (image_token_id × N) at the right position — +// the caller renders the full chat history (system + turns) and sizes N from the +// image dimensions. Preprocesses + ViT-encodes `image_rgb`, checks N matches the +// merged-patch count, builds M-RoPE positions, and greedily generates. `pcfg` +// overrides the preprocessing config; NULL uses the model's defaults. +GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, + const std::vector& prompt_ids, const mx::array& image_rgb, + int max_tokens, const std::vector& eos_ids, + const std::function& on_token = {}, + const PreprocessConfig* pcfg = nullptr); + // High-level single-image orchestration: smart-resize + preprocess `image_rgb` // (decoded H×W×3 uint8), encode the ViT, render the ChatML prompt with the right // number of image placeholders, build M-RoPE positions, and greedily generate. 
diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 16c40c8..7a9c361 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -69,7 +69,7 @@ void Worker::handle_embedding(Request& req) { void Worker::handle_multimodal(Request& req) { try { auto* vl = dynamic_cast(model_.get()); - if (vl == nullptr || !model_->config().has_vision_tower() || tok_ == nullptr) { + if (vl == nullptr || !model_->config().has_vision_tower()) { throw std::runtime_error("loaded model is not a vision-language model"); } // The ViT borrows the model's weights; build it once and reuse it. @@ -79,13 +79,23 @@ void Worker::handle_multimodal(Request& req) { mx::array image = decode_image(req.mm_image.data(), req.mm_image.size()); int produced = 0; - GenerateResult r = generate_from_image( - *vl, *vit_, *tok_, req.mm_text, image, req.max_tokens, req.eos_ids, - [&](int id) { - if (produced == 0) req.first_token_time = Request::Clock::now(); - req.tokens.push(id); - ++produced; - }); + auto on_token = [&](int id) { + if (produced == 0) req.first_token_time = Request::Clock::now(); + req.tokens.push(id); + ++produced; + }; + // A caller that pre-rendered the full chat history (the server) supplies + // prompt_ids with the image placeholders already expanded; the simple path + // (C ABI / CLI) supplies just mm_text and we render one user turn. + GenerateResult r; + if (!req.prompt_ids.empty()) { + r = generate_multimodal(*vl, *vit_, req.prompt_ids, image, req.max_tokens, req.eos_ids, + on_token); + } else { + if (tok_ == nullptr) throw std::runtime_error("multimodal text prompt needs a tokenizer"); + r = generate_from_image(*vl, *vit_, *tok_, req.mm_text, image, req.max_tokens, req.eos_ids, + on_token); + } req.finish_reason = r.hit_eos ? 
"stop" : "length"; } catch (const std::exception& e) { log::error("worker: multimodal error: {}", e.what()); diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp index a93e91a..60c4f5b 100644 --- a/src/server/http_server.cpp +++ b/src/server/http_server.cpp @@ -1,11 +1,15 @@ #include "server/http_server.h" +#include +#include #include #include #include #include "core/logging.h" #include "server/anthropic.h" +#include "vision/image_decode.h" +#include "vision/preprocess.h" namespace mlxforge { @@ -50,17 +54,26 @@ std::shared_ptr HttpServer::make_request(const ChatRequest& cr) const { req->eos_ids = cfg_.eos_token_ids; // Multimodal (Qwen3-VL): an attached image makes this a single-stream vision - // turn. The worker decodes the image, runs the ViT, and renders the prompt - // itself, so we hand it the raw bytes + the latest user text (not a tokenized - // prompt). Requires a vision model; otherwise the worker finishes it as an error. + // turn. We render the FULL chat history (system + prior turns) here, sizing the + // <|image_pad|> run from the image's dimensions (a CPU probe — no decode), and + // attach it to the last user turn. The worker decodes + ViT-encodes the bytes + // and generates from these prompt_ids. 
if (!cr.image.empty()) { - req->mm_image.assign(cr.image.begin(), cr.image.end()); - for (auto it = cr.messages.rbegin(); it != cr.messages.rend(); ++it) { + if (!cfg_.has_vision_tower()) + throw std::runtime_error("this model does not support image input"); + const std::array hw = + image_info(reinterpret_cast(cr.image.data()), cr.image.size()); + const int n = image_token_count(hw[0], hw[1], PreprocessConfig::from(*cfg_.vision)); + + std::vector msgs = cr.messages; // copy: set the placeholder count + for (auto it = msgs.rbegin(); it != msgs.rend(); ++it) { if (it->role == "user") { - req->mm_text = it->content; + it->image_token_counts = {n}; break; } } + req->prompt_ids = tok_->apply_chat_template(msgs, true, "", {}, cr.enable_thinking); + req->mm_image.assign(cr.image.begin(), cr.image.end()); return req; } diff --git a/src/vision/image_decode.cpp b/src/vision/image_decode.cpp index 532f343..9704b9a 100644 --- a/src/vision/image_decode.cpp +++ b/src/vision/image_decode.cpp @@ -1,5 +1,6 @@ #include "vision/image_decode.h" +#include #include #include #include @@ -24,6 +25,16 @@ mx::array decode_image(const uint8_t* data, std::size_t len) { return mx::array(buf.data(), {h, w, 3}, mx::uint8); } +std::array image_info(const uint8_t* data, std::size_t len) { + int w = 0, h = 0, comp = 0; + if (!stbi_info_from_memory(data, static_cast(len), &w, &h, &comp)) { + const char* reason = stbi_failure_reason(); + throw std::runtime_error(std::string("image_info: ") + + (reason ? reason : "not a decodable image")); + } + return {h, w}; +} + mx::array decode_image_file(const std::string& path) { std::ifstream f(path, std::ios::binary); if (!f) throw std::runtime_error("decode_image_file: cannot open '" + path + "'"); diff --git a/src/vision/image_decode.h b/src/vision/image_decode.h index da8af3e..b676d51 100644 --- a/src/vision/image_decode.h +++ b/src/vision/image_decode.h @@ -5,6 +5,7 @@ // so it never leaks into the public headers. 
#pragma once +#include #include #include #include @@ -19,6 +20,11 @@ namespace mx = mlx::core; // (with stb's reason) if the bytes are not a decodable image. mx::array decode_image(const uint8_t* data, std::size_t len); +// Read just an encoded image's pixel dimensions, without decoding it (CPU only, +// no MLX) -> {height, width}. Lets a non-worker thread compute the image-token +// count for prompt templating. Throws if the bytes are not a recognizable image. +std::array image_info(const uint8_t* data, std::size_t len); + // Read and decode an image file to (H, W, 3) uint8 RGB. Throws if the file // cannot be read or decoded. mx::array decode_image_file(const std::string& path); diff --git a/src/vision/preprocess.cpp b/src/vision/preprocess.cpp index 530089c..be4dbbf 100644 --- a/src/vision/preprocess.cpp +++ b/src/vision/preprocess.cpp @@ -78,6 +78,13 @@ Preprocessed patchify_image(const mx::array& image_rgb, const PreprocessConfig& return {pixel_values, {1, gh, gw}}; } +int image_token_count(int height, int width, const PreprocessConfig& cfg) { + const int factor = cfg.patch_size * cfg.merge_size; + const std::array hw = smart_resize(height, width, factor, cfg.min_pixels, cfg.max_pixels); + const int gh = hw[0] / cfg.patch_size, gw = hw[1] / cfg.patch_size; // patch grid + return (gh * gw) / (cfg.merge_size * cfg.merge_size); // grid_t == 1 +} + Preprocessed preprocess_image(const mx::array& image_rgb, const PreprocessConfig& cfg) { const int H = image_rgb.shape()[0], W = image_rgb.shape()[1]; const int factor = cfg.patch_size * cfg.merge_size; diff --git a/src/vision/preprocess.h b/src/vision/preprocess.h index 1818f11..b5e3674 100644 --- a/src/vision/preprocess.h +++ b/src/vision/preprocess.h @@ -70,4 +70,10 @@ Preprocessed preprocess_image(const mx::array& image_rgb, const PreprocessConfig std::array smart_resize(int height, int width, int factor, int min_pixels, int max_pixels); +// Number of <|image_pad|> placeholder tokens an image of (height, width) 
pixels +// expands to under `cfg`: the smart-resized patch grid collapsed by merge_size² +// (grid_t == 1 for a still image). Computed from dimensions alone, so a non-worker +// thread can size the chat-template expansion without decoding the image. +int image_token_count(int height, int width, const PreprocessConfig& cfg); + } // namespace mlxforge diff --git a/tests/vision/preprocess_test.cpp b/tests/vision/preprocess_test.cpp index 69be7df..9957d56 100644 --- a/tests/vision/preprocess_test.cpp +++ b/tests/vision/preprocess_test.cpp @@ -2,6 +2,11 @@ // Pure logic: the committed image_rgb -> pixel_values, no model needed. #include +#include +#include +#include +#include + #include "mlx/ops.h" #include "mlx/transforms.h" @@ -43,6 +48,26 @@ TEST_CASE("Qwen3-VL smart_resize rounds and rescales like the HF reference") { CHECK(smart_resize(64, 64, F, 256, 4096) == std::array{64, 64}); } +TEST_CASE("Qwen3-VL image_info reads dimensions without decoding") { + std::ifstream f(qwen3_vl_ref_path("image.png"), std::ios::binary); + std::vector bytes((std::istreambuf_iterator(f)), + std::istreambuf_iterator()); + std::array hw = image_info(bytes.data(), bytes.size()); + CHECK(hw[0] == 64); // height + CHECK(hw[1] == 64); // width +} + +TEST_CASE("Qwen3-VL image_token_count from dimensions matches the grid") { + PreprocessConfig prod; // production defaults (min 65536, max 16777216) + // 64x64 is below min_pixels -> upscaled to 256x256 -> 16x16 patches -> /merge². + CHECK(image_token_count(64, 64, prod) == 64); + PreprocessConfig tiny = prod; + tiny.min_pixels = 256; + tiny.max_pixels = 4096; + // The fixtures' bounds leave 64x64 untouched -> 4x4 patches -> /merge². + CHECK(image_token_count(64, 64, tiny) == 4); +} + TEST_CASE("Qwen3-VL image decode: PNG file decodes to the reference RGB") { // image.png and image_rgb.npy are the same picture (PNG is lossless), so the // decoded pixels must match the committed RGB array exactly. 
From c5e67ded510f78f65014e6fcb6eca80d57ef443f Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 22:54:01 +0100 Subject: [PATCH 2/6] multimodal: multiple images per request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Support N images in one request end to end. The numeric layer already generalized (mrope_position_ids takes a vector of grids; merge / DeepStack scatter features in order across all <|image_pad|> runs) — this plumbs a list of images through the rest: - Request: mm_image -> mm_images (vector of encoded-byte vectors, in order). - generate_multimodal: take a vector of decoded images; preprocess + ViT-encode each, concatenate the merged features and each DeepStack layer in order, assert the total placeholder count matches, M-RoPE over all grids, generate. generate_from_image wraps a single image. - worker.handle_multimodal: decode every mm_images entry (server path), or the one image (mm_text path). - C ABI submit_image: one image -> mm_images of size 1 (unchanged surface). - server (OpenAI + Anthropic): collect all image parts/blocks in order; make_request sizes each placeholder run from its image's dimensions and attaches the runs (in order) to the last user turn. Gated: parse collects multiple images in order; generate_multimodal runs with two images (feature/DeepStack concat + multi-image M-RoPE). Verified live: a two-image /v1/chat/completions describes each image correctly. Full suite green (193). Note: all images attach to the last user turn (the "compare these N images" case); images spread across different turns are a later refinement. 
Co-Authored-By: Claude Opus 4.8 --- src/capi/mlxforge.cpp | 2 +- src/runtime/multimodal_stream.cpp | 57 +++++++++++++++++++++++-------- src/runtime/multimodal_stream.h | 16 +++++---- src/runtime/worker.cpp | 14 ++++---- src/scheduler/request.h | 19 ++++++----- src/server/anthropic.cpp | 11 +++--- src/server/http_server.cpp | 29 +++++++++------- src/server/openai.cpp | 20 +++++------ src/server/openai.h | 9 ++--- tests/model/qwen3_vl_test.cpp | 30 +++++++++++++++- tests/server/anthropic_test.cpp | 3 +- tests/server/openai_test.cpp | 21 ++++++++++-- 12 files changed, 158 insertions(+), 73 deletions(-) diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp index 4bf5039..6ab6213 100644 --- a/src/capi/mlxforge.cpp +++ b/src/capi/mlxforge.cpp @@ -276,7 +276,7 @@ mlxforge_request* mlxforge_submit_image(mlxforge_engine* engine, const char* pro try { auto req = std::make_shared(); req->mm_text = prompt ? prompt : ""; - req->mm_image.assign(image_data, image_data + image_len); + req->mm_images.emplace_back(image_data, image_data + image_len); req->params = to_params(sampling); req->max_tokens = sampling_max_tokens(sampling); req->eos_ids = engine->engine->config().eos_token_ids; diff --git a/src/runtime/multimodal_stream.cpp b/src/runtime/multimodal_stream.cpp index b1d019b..9c8d89e 100644 --- a/src/runtime/multimodal_stream.cpp +++ b/src/runtime/multimodal_stream.cpp @@ -1,8 +1,10 @@ #include "runtime/multimodal_stream.h" #include +#include #include #include +#include #include "cache/kv_cache.h" #include "sample/sampler.h" @@ -69,31 +71,56 @@ GenerateResult greedy_generate_multimodal(const Qwen3VLModel& model, } GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, - const std::vector& prompt_ids, const mx::array& image_rgb, - int max_tokens, const std::vector& eos_ids, + const std::vector& prompt_ids, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, const std::function& on_token, const PreprocessConfig* 
pcfg) { const ModelConfig& cfg = model.config(); + if (images_rgb.empty()) throw std::runtime_error("generate_multimodal: no images"); + const PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*cfg.vision); - // Smart-resize + preprocess -> ViT encode. - PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*cfg.vision); - Preprocessed pre = preprocess_image(image_rgb, pc); - mx::array grid(pre.grid_thw.data(), {1, 3}, mx::int32); - VitEncoder::Output v = vit.forward(pre.pixel_values, grid); + // Smart-resize + preprocess + ViT-encode each image; collect per-image merged + // features, per-layer DeepStack features, and the patch grids (in order). + std::vector> grids; + std::vector hidden_parts; + std::vector> deepstack_parts; // [image][layer] + int merged_total = 0; + for (const auto& rgb : images_rgb) { + Preprocessed pre = preprocess_image(rgb, pc); + mx::array grid(pre.grid_thw.data(), {1, 3}, mx::int32); + VitEncoder::Output v = vit.forward(pre.pixel_values, grid); + grids.push_back(pre.grid_thw); + hidden_parts.push_back(v.hidden); + deepstack_parts.push_back(v.deepstack); + merged_total += pre.grid_thw[0] * pre.grid_thw[1] * pre.grid_thw[2] / cfg.vision->merge_unit(); + } - // The prompt's placeholder run must match the merged-patch count, or the + // The prompt's placeholder runs must total the merged-patch count, or the // feature scatter (merge_image_features) misaligns. - const int merged = pre.grid_thw[0] * pre.grid_thw[1] * pre.grid_thw[2] / cfg.vision->merge_unit(); const int pads = static_cast(std::count(prompt_ids.begin(), prompt_ids.end(), cfg.image_token_id)); - if (pads != merged) { + if (pads != merged_total) { throw std::runtime_error("multimodal prompt has " + std::to_string(pads) + - " image placeholder(s) but the image yields " + std::to_string(merged)); + " image placeholder(s) but the image(s) yield " + + std::to_string(merged_total)); + } + + // Concatenate features (and each DeepStack layer) across images, in order. 
+ auto cat = [](const std::vector& parts) { + return parts.size() == 1 ? parts[0] : mx::concatenate(parts, /*axis=*/0); + }; + mx::array features = cat(hidden_parts); + std::vector deepstack; + for (size_t layer = 0; layer < deepstack_parts[0].size(); ++layer) { + std::vector layer_parts; + for (const auto& per_image : deepstack_parts) layer_parts.push_back(per_image[layer]); + deepstack.push_back(cat(layer_parts)); } - mx::array pos = mrope_position_ids(prompt_ids, {pre.grid_thw}, cfg); - return greedy_generate_multimodal(model, prompt_ids, v.hidden, v.deepstack, pos, max_tokens, - eos_ids, on_token); + mx::array pos = mrope_position_ids(prompt_ids, grids, cfg); + return greedy_generate_multimodal(model, prompt_ids, features, deepstack, pos, max_tokens, eos_ids, + on_token); } GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, @@ -112,7 +139,7 @@ GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& msg.content = user_text; msg.image_token_counts = {n}; std::vector ids = tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); - return generate_multimodal(model, vit, ids, image_rgb, max_tokens, eos_ids, on_token, &pc); + return generate_multimodal(model, vit, ids, {image_rgb}, max_tokens, eos_ids, on_token, &pc); } } // namespace mlxforge diff --git a/src/runtime/multimodal_stream.h b/src/runtime/multimodal_stream.h index f854996..aeebe05 100644 --- a/src/runtime/multimodal_stream.h +++ b/src/runtime/multimodal_stream.h @@ -36,14 +36,16 @@ GenerateResult greedy_generate_multimodal(const Qwen3VLModel& model, const std::function& on_token = {}); // Generate from an already-templated multimodal prompt: `prompt_ids` must already -// contain the image placeholder run (image_token_id × N) at the right position — -// the caller renders the full chat history (system + turns) and sizes N from the -// image dimensions. 
Preprocesses + ViT-encodes `image_rgb`, checks N matches the -// merged-patch count, builds M-RoPE positions, and greedily generates. `pcfg` -// overrides the preprocessing config; NULL uses the model's defaults. +// contain the image placeholder runs (image_token_id × Nᵢ) in order — the caller +// renders the full chat history and sizes each Nᵢ from that image's dimensions. +// Preprocesses + ViT-encodes each of `images_rgb`, concatenates their features / +// DeepStack outputs in order, checks the total placeholder count matches, builds +// M-RoPE positions over all images, and greedily generates. `pcfg` overrides the +// preprocessing config; NULL uses the model's defaults. GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, - const std::vector& prompt_ids, const mx::array& image_rgb, - int max_tokens, const std::vector& eos_ids, + const std::vector& prompt_ids, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, const std::function& on_token = {}, const PreprocessConfig* pcfg = nullptr); diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 7a9c361..803f526 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -76,7 +76,9 @@ void Worker::handle_multimodal(Request& req) { if (!vit_) { vit_ = std::make_unique(*model_->config().vision, model_->weights()); } - mx::array image = decode_image(req.mm_image.data(), req.mm_image.size()); + std::vector images; + images.reserve(req.mm_images.size()); + for (const auto& bytes : req.mm_images) images.push_back(decode_image(bytes.data(), bytes.size())); int produced = 0; auto on_token = [&](int id) { @@ -85,16 +87,16 @@ void Worker::handle_multimodal(Request& req) { ++produced; }; // A caller that pre-rendered the full chat history (the server) supplies - // prompt_ids with the image placeholders already expanded; the simple path - // (C ABI / CLI) supplies just mm_text and we render one user turn. 
+ // prompt_ids with the image placeholders already expanded (any number of + // images); the simple path (C ABI / CLI) supplies just mm_text + one image. GenerateResult r; if (!req.prompt_ids.empty()) { - r = generate_multimodal(*vl, *vit_, req.prompt_ids, image, req.max_tokens, req.eos_ids, + r = generate_multimodal(*vl, *vit_, req.prompt_ids, images, req.max_tokens, req.eos_ids, on_token); } else { if (tok_ == nullptr) throw std::runtime_error("multimodal text prompt needs a tokenizer"); - r = generate_from_image(*vl, *vit_, *tok_, req.mm_text, image, req.max_tokens, req.eos_ids, - on_token); + r = generate_from_image(*vl, *vit_, *tok_, req.mm_text, images.front(), req.max_tokens, + req.eos_ids, on_token); } req.finish_reason = r.hit_eos ? "stop" : "length"; } catch (const std::exception& e) { diff --git a/src/scheduler/request.h b/src/scheduler/request.h index acfb765..89ea733 100644 --- a/src/scheduler/request.h +++ b/src/scheduler/request.h @@ -87,14 +87,17 @@ struct Request { bool embedding_normalize = true; std::vector embedding_result; - // Multimodal (Qwen3-VL) one-shot generation: when `mm_image` is non-empty the - // worker decodes the image, runs the ViT, renders the chat prompt from - // `mm_text` with the image placeholders, and streams generated tokens - // single-stream (not merged into the continuous-decode batch). `prompt_ids` is - // unused on this path — the worker builds the prompt itself. - std::string mm_text; // the user's text prompt - std::vector mm_image; // raw encoded image bytes (JPEG/PNG/…) - bool is_multimodal() const { return !mm_image.empty(); } + // Multimodal (Qwen3-VL) one-shot generation: when `mm_images` is non-empty the + // worker decodes each image, runs the ViT, and streams generated tokens + // single-stream (not merged into the continuous-decode batch). Two prompt forms: + // - `prompt_ids` set (the server): already chat-templated with the image + // placeholder runs expanded in order — the worker uses it as-is. 
+ // - `prompt_ids` empty (C ABI / CLI): `mm_text` is a single user turn the + // worker renders itself (one image). + // Images are consumed in `mm_images` order, matching the <|image_pad|> runs. + std::string mm_text; // single-turn user text (mm_text path) + std::vector> mm_images; // raw encoded image bytes, in order + bool is_multimodal() const { return !mm_images.empty(); } // Set by the submitting thread (e.g. client disconnect); read by the worker. std::atomic cancelled{false}; diff --git a/src/server/anthropic.cpp b/src/server/anthropic.cpp index 6ac1e4f..dc2ec1f 100644 --- a/src/server/anthropic.cpp +++ b/src/server/anthropic.cpp @@ -35,7 +35,8 @@ std::string render_tool_use(const json& block) { // content yields a single message; a block array is split so tool_use becomes an // assistant tool-call turn and tool_result becomes a "tool" turn (matching the // roles the chat template understands). -void append_message(const json& m, std::vector& out, std::string& image_out) { +void append_message(const json& m, std::vector& out, + std::vector& images_out) { if (!m.contains("role")) throw std::runtime_error("each message needs a 'role'"); const std::string role = m.at("role").get(); auto it = m.find("content"); @@ -49,7 +50,7 @@ void append_message(const json& m, std::vector& out, std::st // tool_result blocks (carried on a user turn) are fed back first, then any // free text; tool_use blocks (on an assistant turn) become tool-call turns; - // the first base64 image block is decoded into image_out. + // base64 image blocks are decoded into images_out, in order. 
std::string text; std::vector tool_uses; for (const auto& b : *it) { @@ -61,10 +62,10 @@ void append_message(const json& m, std::vector& out, std::st tool_uses.push_back(b); } else if (type == "tool_result") { out.push_back({"tool", text_of(b.value("content", json())), ""}); - } else if (type == "image" && image_out.empty()) { + } else if (type == "image") { const auto src = b.find("source"); if (src != b.end() && src->is_object() && src->value("type", std::string()) == "base64") - image_out = base64_decode(src->value("data", std::string())); + images_out.push_back(base64_decode(src->value("data", std::string()))); } } if (!text.empty()) out.push_back({role, text, ""}); @@ -122,7 +123,7 @@ ChatRequest parse_messages_request(const nlohmann::json& body) { auto it = body.find("messages"); if (it == body.end() || !it->is_array() || it->empty()) throw std::runtime_error("'messages' must be a non-empty array"); - for (const auto& m : *it) append_message(m, r.messages, r.image); + for (const auto& m : *it) append_message(m, r.messages, r.images); r.params.temperature = body.value("temperature", 1.0f); if (r.params.temperature < 0.0f) throw std::runtime_error("'temperature' must be >= 0"); diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp index 60c4f5b..a405279 100644 --- a/src/server/http_server.cpp +++ b/src/server/http_server.cpp @@ -53,27 +53,30 @@ std::shared_ptr HttpServer::make_request(const ChatRequest& cr) const { req->max_tokens = cr.max_tokens; req->eos_ids = cfg_.eos_token_ids; - // Multimodal (Qwen3-VL): an attached image makes this a single-stream vision - // turn. We render the FULL chat history (system + prior turns) here, sizing the - // <|image_pad|> run from the image's dimensions (a CPU probe — no decode), and - // attach it to the last user turn. The worker decodes + ViT-encodes the bytes - // and generates from these prompt_ids. 
- if (!cr.image.empty()) { + // Multimodal (Qwen3-VL): attached image(s) make this a single-stream vision + // turn. We render the FULL chat history (system + prior turns) here, sizing each + // <|image_pad|> run from that image's dimensions (a CPU probe — no decode), and + // attach the runs (in order) to the last user turn. The worker decodes + + // ViT-encodes the bytes and generates from these prompt_ids. + if (!cr.images.empty()) { if (!cfg_.has_vision_tower()) throw std::runtime_error("this model does not support image input"); - const std::array hw = - image_info(reinterpret_cast(cr.image.data()), cr.image.size()); - const int n = image_token_count(hw[0], hw[1], PreprocessConfig::from(*cfg_.vision)); - - std::vector msgs = cr.messages; // copy: set the placeholder count + const PreprocessConfig pc = PreprocessConfig::from(*cfg_.vision); + std::vector counts; + for (const auto& img : cr.images) { + const std::array hw = + image_info(reinterpret_cast(img.data()), img.size()); + counts.push_back(image_token_count(hw[0], hw[1], pc)); + req->mm_images.emplace_back(img.begin(), img.end()); + } + std::vector msgs = cr.messages; // copy: set placeholder counts for (auto it = msgs.rbegin(); it != msgs.rend(); ++it) { if (it->role == "user") { - it->image_token_counts = {n}; + it->image_token_counts = counts; break; } } req->prompt_ids = tok_->apply_chat_template(msgs, true, "", {}, cr.enable_thinking); - req->mm_image.assign(cr.image.begin(), cr.image.end()); return req; } diff --git a/src/server/openai.cpp b/src/server/openai.cpp index 7343669..c5e8f1a 100644 --- a/src/server/openai.cpp +++ b/src/server/openai.cpp @@ -78,9 +78,9 @@ std::string decode_image_url(const std::string& url) { } // Parse an OpenAI `content` value (a string, or an array of {type:"text"} / -// {type:"image_url"} parts) into its text. The first image found is decoded into -// `image_out` (left untouched if there is none). 
-std::string parse_content(const json& content, std::string& image_out) { +// {type:"image_url"} parts) into its text. Each image found is decoded and +// appended to `images_out`, in order. +std::string parse_content(const json& content, std::vector& images_out) { if (content.is_string()) return content.get(); if (!content.is_array()) throw std::runtime_error("'content' must be a string or an array of parts"); @@ -90,21 +90,21 @@ std::string parse_content(const json& content, std::string& image_out) { const std::string type = part.value("type", std::string()); if (type == "text") { text += part.value("text", std::string()); - } else if (type == "image_url" && image_out.empty()) { + } else if (type == "image_url") { const auto iu = part.find("image_url"); std::string url; if (iu != part.end() && iu->is_object()) url = iu->value("url", std::string()); else if (iu != part.end() && iu->is_string()) url = iu->get(); - if (!url.empty()) image_out = decode_image_url(url); + if (!url.empty()) images_out.push_back(decode_image_url(url)); } } return text; } // Parse one chat message into our Message, handling assistant tool_calls (content -// may be null/absent), array content with images, and tool-result messages. The -// first image across the conversation is decoded into `image_out`. -Tokenizer::Message parse_message(const json& m, std::string& image_out) { +// may be null/absent), array content with images, and tool-result messages. Any +// images are decoded and appended to `images_out`, in order. 
+Tokenizer::Message parse_message(const json& m, std::vector& images_out) { if (!m.contains("role")) throw std::runtime_error("each message needs a 'role'"); Tokenizer::Message msg; msg.role = m.at("role").get(); @@ -114,7 +114,7 @@ Tokenizer::Message parse_message(const json& m, std::string& image_out) { } auto c = m.find("content"); if (c == m.end() || c->is_null()) throw std::runtime_error("each message needs 'content'"); - msg.content = parse_content(*c, image_out); + msg.content = parse_content(*c, images_out); return msg; } @@ -167,7 +167,7 @@ ChatRequest parse_chat_request(const nlohmann::json& body) { auto it = body.find("messages"); if (it == body.end() || !it->is_array() || it->empty()) throw std::runtime_error("'messages' must be a non-empty array"); - for (const auto& m : *it) r.messages.push_back(parse_message(m, r.image)); + for (const auto& m : *it) r.messages.push_back(parse_message(m, r.images)); parse_common(body, r); return r; } diff --git a/src/server/openai.h b/src/server/openai.h index 9a7fcbb..569914f 100644 --- a/src/server/openai.h +++ b/src/server/openai.h @@ -18,10 +18,11 @@ struct ChatRequest { std::string model; std::vector messages; // chat; for /v1/completions a single user msg bool is_chat = true; // chat vs raw completion - // Decoded bytes of the first image attached to the conversation (empty = none). - // When set, the request is served as a single-stream Qwen3-VL multimodal turn: - // the worker decodes the image, runs the ViT, and generates from the user text. - std::string image; + // Decoded bytes of each image attached to the conversation, in order (empty = + // none). When non-empty the request is served as a single-stream Qwen3-VL + // multimodal turn: each image is decoded, ViT-encoded, and its placeholder run + // expanded into the prompt. 
+ std::vector images; SamplingParams params; int max_tokens = 128; bool stream = false; diff --git a/tests/model/qwen3_vl_test.cpp b/tests/model/qwen3_vl_test.cpp index 54099e8..478b9cd 100644 --- a/tests/model/qwen3_vl_test.cpp +++ b/tests/model/qwen3_vl_test.cpp @@ -253,6 +253,34 @@ TEST_CASE("Qwen3-VL: generate_from_image composes the full pipeline") { assert_tokens_equal(r.tokens, load_qwen3_vl_token_ids("greedy_tokens.npy")); } +TEST_CASE("Qwen3-VL: generate_multimodal handles multiple images") { + if (!qwen3_vl_model_available()) { + MESSAGE("Qwen3-VL model not found in HF cache; skipping multi-image test"); + return; + } + // Two images in one user turn: the prompt carries two image_pad runs and the + // ViT features / DeepStack are concatenated in order. We can't gate exact + // tokens (no two-image fixture), so this checks the plumbing runs end to end. + const Qwen3VLModel& m = shared_qwen3_vl_model(); + const VitEncoder& vit = shared_qwen3_vl_vit(); + Tokenizer tok = Tokenizer::from_file(qwen3_vl_model_dir() + "/tokenizer.json", -1, ChatFormat::Qwen3); + mx::array rgb = load_qwen3_vl_npy("image_rgb.npy"); // 64x64 + + PreprocessConfig pc = PreprocessConfig::from(*qwen3_vl_config().vision); + pc.min_pixels = 256; + pc.max_pixels = 4096; // identity resize -> grid (1,4,4) -> 4 tokens/image + const int n = image_token_count(64, 64, pc); + Tokenizer::Message msg; + msg.role = "user"; + msg.content = "Compare these two images."; + msg.image_token_counts = {n, n}; // two image runs + std::vector ids = tok.apply_chat_template({msg}, /*add_generation_prompt=*/true); + + GenerateResult r = + generate_multimodal(m, vit, ids, {rgb, rgb}, /*max_tokens=*/8, /*eos=*/{}, /*on_token=*/{}, &pc); + CHECK(r.tokens.size() == 8); // concatenation + scatter + M-RoPE over two images ran +} + TEST_CASE("Qwen3-VL: worker serves a multimodal request from another thread") { if (!qwen3_vl_model_available()) { MESSAGE("Qwen3-VL model not found in HF cache; skipping multimodal worker 
test"); @@ -280,7 +308,7 @@ TEST_CASE("Qwen3-VL: worker serves a multimodal request from another thread") { auto req = std::make_shared(); req->mm_text = "What is in this image?"; - req->mm_image = img_bytes; + req->mm_images.push_back(img_bytes); req->max_tokens = 10; // Empty eos_ids so the run matches the fixed-length reference greedy stream. sched.submit(req); diff --git a/tests/server/anthropic_test.cpp b/tests/server/anthropic_test.cpp index e6a6fe2..7085e6a 100644 --- a/tests/server/anthropic_test.cpp +++ b/tests/server/anthropic_test.cpp @@ -182,5 +182,6 @@ TEST_CASE("parse_messages_request extracts a base64 image block as bytes") { "max_tokens": 8 })"); ChatRequest r = parse_messages_request(body); - CHECK(r.image == "hello"); + REQUIRE(r.images.size() == 1); + CHECK(r.images[0] == "hello"); } diff --git a/tests/server/openai_test.cpp b/tests/server/openai_test.cpp index 4cd53dd..3c63622 100644 --- a/tests/server/openai_test.cpp +++ b/tests/server/openai_test.cpp @@ -217,7 +217,8 @@ TEST_CASE("parse_chat_request extracts an image_url data URI as bytes") { ChatRequest r = parse_chat_request(body); REQUIRE(r.messages.size() == 1); CHECK(r.messages[0].content == "what is this?"); - CHECK(r.image == "hello"); // routed to the multimodal path by make_request + REQUIRE(r.images.size() == 1); + CHECK(r.images[0] == "hello"); // routed to the multimodal path by make_request } TEST_CASE("parse_chat_request rejects a non-data: image URL") { @@ -233,5 +234,21 @@ TEST_CASE("parse_chat_request leaves image empty for text-only content") { json body = json::parse(R"({ "messages": [{"role": "user", "content": "just text"}], "max_tokens": 8 })"); - CHECK(parse_chat_request(body).image.empty()); + CHECK(parse_chat_request(body).images.empty()); +} + +TEST_CASE("parse_chat_request collects multiple images in order") { + // base64 "aGVsbG8=" -> "hello", "d29ybGQ=" -> "world". 
+ json body = json::parse(R"({ + "messages": [{"role": "user", "content": [ + {"type": "text", "text": "compare these"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,aGVsbG8="}}, + {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,d29ybGQ="}} + ]}], + "max_tokens": 8 + })"); + ChatRequest r = parse_chat_request(body); + REQUIRE(r.images.size() == 2); + CHECK(r.images[0] == "hello"); + CHECK(r.images[1] == "world"); } From 00711fb808cdb2d32e41f413e6c6c848611d6ece Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 23:04:41 +0100 Subject: [PATCH 3/6] capi+bindings: multi-image submit (mlxforge_submit_images, images()) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend the multi-image path through the embeddable surface so it's no longer server-only: - C ABI (v4, append-only): add mlxforge_image + mlxforge_submit_images (prompt + N images, expanded/attended in array order). mlxforge_submit_image stays; both share a submit_mm helper. Baseline refreshed; the guard passes (18 symbols). - engine: generate_from_images (single-turn, N images) — generate_from_image is now a one-image wrapper; the worker's mm_text path uses all images. - bindings: add images() to Rust (Engine::images(&[&[u8]])), Swift (images([[UInt8]]) — flattened into one stable buffer), and Node (engine.images(Buffer[]) via submitImages in the N-API addon + d.ts). Gated: the C ABI test submits a two-image request and streams a finished response; the Rust crate type-checks against the updated FFI. Full suite green (193). 
Co-Authored-By: Claude Opus 4.8 --- bindings/node/index.d.ts | 6 ++ bindings/node/index.js | 9 +++ bindings/node/src/addon.cc | 29 ++++++++ bindings/rust/src/lib.rs | 46 ++++++++++++ .../swift/Sources/MLXForge/MLXForge.swift | 33 +++++++++ cmake/abi-baseline.txt | 1 + src/capi/mlxforge.cpp | 74 +++++++++++++++---- src/capi/mlxforge.h | 21 +++++- src/runtime/multimodal_stream.cpp | 37 +++++++--- src/runtime/multimodal_stream.h | 19 +++-- src/runtime/worker.cpp | 4 +- tests/capi/capi_test.cpp | 9 +++ 12 files changed, 252 insertions(+), 36 deletions(-) diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts index 3911e3e..fe8982f 100644 --- a/bindings/node/index.d.ts +++ b/bindings/node/index.d.ts @@ -64,6 +64,12 @@ export class Engine { * be a vision-language checkpoint (e.g. Qwen3-VL). */ image(prompt: string, image: Uint8Array, sampling?: SamplingOptions): Stream; + /** + * Stream a vision-language completion over several images (raw encoded bytes + * each), expanded into the prompt in order. The model must be a vision-language + * checkpoint (e.g. Qwen3-VL). + */ + images(prompt: string, images: Uint8Array[], sampling?: SamplingOptions): Stream; /** Run a chat to completion and return the full string. */ complete(messages: ChatMessage[], sampling?: SamplingOptions): Promise; /** diff --git a/bindings/node/index.js b/bindings/node/index.js index d3bd181..87c3b67 100644 --- a/bindings/node/index.js +++ b/bindings/node/index.js @@ -87,6 +87,15 @@ class Engine { return streamRequest(this._h.submitImage(prompt, imageBuffer, sampling)); } + /** + * Stream a vision-language completion over several images (an array of + * Buffers/Uint8Arrays of raw encoded bytes), expanded into the prompt in + * order. The loaded model must be a vision-language checkpoint (e.g. Qwen3-VL). + */ + images(prompt, imageBuffers, sampling = {}) { + return streamRequest(this._h.submitImages(prompt, imageBuffers, sampling)); + } + /** Run a chat to completion and return the full string. 
*/ complete(messages, sampling = {}) { return collect(this.chat(messages, sampling)); diff --git a/bindings/node/src/addon.cc b/bindings/node/src/addon.cc index 77dc58a..f6f498c 100644 --- a/bindings/node/src/addon.cc +++ b/bindings/node/src/addon.cc @@ -220,6 +220,7 @@ class EngineWrap : public Napi::ObjectWrap { InstanceMethod("submitChat", &EngineWrap::SubmitChat), InstanceMethod("submitText", &EngineWrap::SubmitText), InstanceMethod("submitImage", &EngineWrap::SubmitImage), + InstanceMethod("submitImages", &EngineWrap::SubmitImages), InstanceMethod("embed", &EngineWrap::Embed), InstanceMethod("dispose", &EngineWrap::Dispose), }); @@ -344,6 +345,34 @@ class EngineWrap : public Napi::ObjectWrap { return finish_submit(env, req, err); } + Napi::Value SubmitImages(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + if (!eng_) throw Napi::Error::New(env, "engine is disposed"); + if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray()) + throw Napi::TypeError::New(env, "submitImages(prompt, imageBuffers[], sampling?)"); + std::string prompt = info[0].As().Utf8Value(); + Napi::Array arr = info[1].As(); + // Pointers into the JS Buffers; valid for the synchronous submit call (the + // engine copies the bytes before returning). The Buffers stay alive via `arr`. + std::vector images; + images.reserve(arr.Length()); + for (uint32_t i = 0; i < arr.Length(); ++i) { + Napi::Value v = arr.Get(i); + if (!v.IsBuffer()) throw Napi::TypeError::New(env, "each image must be a Buffer"); + Napi::Buffer buf = v.As>(); + images.push_back(mlxforge_image{buf.Data(), buf.Length()}); + } + std::string schema; // kept alive across the submit call + mlxforge_sampling s = (info.Length() >= 3 && info[2].IsObject()) + ? 
parse_sampling(info[2].As(), schema) + : mlxforge_sampling{}; + if (!schema.empty()) s.json_schema = schema.c_str(); + char* err = nullptr; + mlxforge_request* req = + mlxforge_submit_images(eng_, prompt.c_str(), images.data(), images.size(), &s, &err); + return finish_submit(env, req, err); + } + Napi::Value Embed(const Napi::CallbackInfo& info) { Napi::Env env = info.Env(); auto deferred = Napi::Promise::Deferred::New(env); diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index d2f7f14..53e5e1f 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -30,6 +30,12 @@ struct Msg { content: *const c_char, } +#[repr(C)] +struct CImage { + data: *const u8, + len: usize, +} + #[repr(C)] struct CSampling { temperature: c_float, @@ -106,6 +112,14 @@ extern "C" { sampling: *const CSampling, err: *mut *mut c_char, ) -> *mut mlxforge_request; + fn mlxforge_submit_images( + e: *mut mlxforge_engine, + prompt: *const c_char, + images: *const CImage, + n_images: usize, + sampling: *const CSampling, + err: *mut *mut c_char, + ) -> *mut mlxforge_request; fn mlxforge_request_next(r: *mut mlxforge_request, text: *mut *mut c_char) -> c_int; fn mlxforge_request_finish_reason(r: *mut mlxforge_request) -> *const c_char; fn mlxforge_request_free(r: *mut mlxforge_request); @@ -291,6 +305,38 @@ impl Engine { Ok(drain(req)) } + /// Run a vision-language completion over several images: a text `prompt` + /// about `images` (each raw encoded bytes), expanded into the prompt in + /// order. The loaded model must be a vision-language checkpoint (Qwen3-VL). + pub fn images(&self, prompt: &str, images: &[&[u8]], sampling: &Sampling) + -> Result + { + let cprompt = CString::new(prompt).map_err(|_| "prompt contains NUL".to_string())?; + // Raw pointers into the borrowed slices; valid for the duration of the + // submit call (the engine copies the bytes synchronously). 
+ let cimgs: Vec = images + .iter() + .map(|b| CImage { data: b.as_ptr(), len: b.len() }) + .collect(); + let mut schema_keep: Option = None; + let cs = Self::c_sampling(sampling, &mut schema_keep); + let mut err: *mut c_char = ptr::null_mut(); + let req = unsafe { + mlxforge_submit_images( + self.handle, + cprompt.as_ptr(), + cimgs.as_ptr(), + cimgs.len(), + &cs, + &mut err, + ) + }; + if req.is_null() { + return Err(unsafe { take_string(err) }.unwrap_or_else(|| "submit failed".into())); + } + Ok(drain(req)) + } + /// Embed text into a unit-normalized vector. `pooling`: 0 = mean, 1 = last. /// Simple form (no EOS/instruction); for Qwen3-Embedding conventions or to /// let the model pick its defaults, use [`Engine::embed_with`]. diff --git a/bindings/swift/Sources/MLXForge/MLXForge.swift b/bindings/swift/Sources/MLXForge/MLXForge.swift index ee58be4..a09eeea 100644 --- a/bindings/swift/Sources/MLXForge/MLXForge.swift +++ b/bindings/swift/Sources/MLXForge/MLXForge.swift @@ -126,6 +126,39 @@ public final class Engine { return Engine.stream(req) } + /// Stream a vision-language completion over several images (each raw encoded + /// bytes), expanded into the prompt in order. The loaded model must be a + /// vision-language checkpoint (e.g. Qwen3-VL). + public func images(_ prompt: String, _ imagesBytes: [[UInt8]], sampling: Sampling = .greedy) + throws -> AsyncThrowingStream + { + var s = sampling.c + let schemaC = sampling.jsonSchema.map { strdup($0) } ?? nil + if let p = schemaC { s.json_schema = UnsafePointer(p) } + defer { if let p = schemaC { free(p) } } + var err: UnsafeMutablePointer? + // Flatten into one contiguous buffer so each mlxforge_image points into stable + // storage for the call (the engine copies the bytes synchronously). + let flat = imagesBytes.flatMap { $0 } + let reqOpt = flat.withUnsafeBufferPointer { (fb) -> OpaquePointer? 
in + var cImages: [mlxforge_image] = [] + var offset = 0 + for img in imagesBytes { + cImages.append(mlxforge_image(data: fb.baseAddress.map { $0 + offset }, len: img.count)) + offset += img.count + } + return cImages.withUnsafeBufferPointer { cb in + mlxforge_submit_images(handle, prompt, cb.baseAddress, cb.count, &s, &err) + } + } + guard let req = reqOpt else { + let message = err.map { String(cString: $0) } ?? "submit failed" + mlxforge_string_free(err) + throw MLXForgeError(message: message) + } + return Engine.stream(req) + } + /// Run a chat to completion and return the full string. public func complete(_ messages: [ChatMessage], sampling: Sampling = .greedy) async throws -> String diff --git a/cmake/abi-baseline.txt b/cmake/abi-baseline.txt index 9adc997..aa934c5 100644 --- a/cmake/abi-baseline.txt +++ b/cmake/abi-baseline.txt @@ -13,5 +13,6 @@ mlxforge_request_next mlxforge_string_free mlxforge_submit_chat mlxforge_submit_image +mlxforge_submit_images mlxforge_submit_text mlxforge_version diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp index 6ab6213..f96f53e 100644 --- a/src/capi/mlxforge.cpp +++ b/src/capi/mlxforge.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include "core/config.h" @@ -261,6 +262,30 @@ mlxforge_request* mlxforge_submit_chat(mlxforge_engine* engine, return nullptr; } +namespace { +// Shared multimodal submit: build a single-turn Request from `prompt` + raw image +// byte-spans, attach sampling/limits, enqueue, and return the handle. The worker +// renders the prompt (one user turn with one block per image) and generates. +mlxforge_request* submit_mm(mlxforge_engine* engine, const char* prompt, + std::vector> images, + const mlxforge_sampling* sampling, char** err) { + auto req = std::make_shared(); + req->mm_text = prompt ? 
prompt : ""; + req->mm_images = std::move(images); + req->params = to_params(sampling); + req->max_tokens = sampling_max_tokens(sampling); + req->eos_ids = engine->engine->config().eos_token_ids; + if (!engine->engine->scheduler().submit(req)) { + set_err(err, "request rejected: waiting queue is full"); + return nullptr; + } + auto handle = std::make_unique(); + handle->req = req; + handle->detok = std::make_unique(engine->engine->tokenizer()); + return handle.release(); +} +} // namespace + mlxforge_request* mlxforge_submit_image(mlxforge_engine* engine, const char* prompt, const unsigned char* image_data, size_t image_len, const mlxforge_sampling* sampling, char** err) { @@ -274,21 +299,9 @@ mlxforge_request* mlxforge_submit_image(mlxforge_engine* engine, const char* pro return nullptr; } try { - auto req = std::make_shared(); - req->mm_text = prompt ? prompt : ""; - req->mm_images.emplace_back(image_data, image_data + image_len); - req->params = to_params(sampling); - req->max_tokens = sampling_max_tokens(sampling); - req->eos_ids = engine->engine->config().eos_token_ids; - if (!engine->engine->scheduler().submit(req)) { - set_err(err, "request rejected: waiting queue is full"); - return nullptr; - } - auto handle = std::make_unique(); - handle->req = req; - handle->detok = - std::make_unique(engine->engine->tokenizer()); - return handle.release(); + std::vector> images; + images.emplace_back(image_data, image_data + image_len); + return submit_mm(engine, prompt, std::move(images), sampling, err); } catch (const std::exception& e) { set_err(err, e.what()); } catch (...) 
{ @@ -297,6 +310,37 @@ mlxforge_request* mlxforge_submit_image(mlxforge_engine* engine, const char* pro return nullptr; } +mlxforge_request* mlxforge_submit_images(mlxforge_engine* engine, const char* prompt, + const mlxforge_image* images, size_t n_images, + const mlxforge_sampling* sampling, char** err) { + if (err) *err = nullptr; + if (!engine || !engine->engine) { + set_err(err, "engine is null"); + return nullptr; + } + if (!images || n_images == 0) { + set_err(err, "images is empty"); + return nullptr; + } + try { + std::vector> imgs; + imgs.reserve(n_images); + for (size_t i = 0; i < n_images; ++i) { + if (!images[i].data || images[i].len == 0) { + set_err(err, "an image is empty"); + return nullptr; + } + imgs.emplace_back(images[i].data, images[i].data + images[i].len); + } + return submit_mm(engine, prompt, std::move(imgs), sampling, err); + } catch (const std::exception& e) { + set_err(err, e.what()); + } catch (...) { + set_err(err, "unknown error submitting images"); + } + return nullptr; +} + mlxforge_request* mlxforge_submit_text(mlxforge_engine* engine, const char* prompt, const mlxforge_sampling* sampling, char** err) { if (err) *err = nullptr; diff --git a/src/capi/mlxforge.h b/src/capi/mlxforge.h index 56aa7de..4bc14dd 100644 --- a/src/capi/mlxforge.h +++ b/src/capi/mlxforge.h @@ -36,8 +36,9 @@ extern "C" { * v2: mlxforge_embed_ex + mlxforge_embed_opts (Qwen3-Embedding conventions: * last-token pooling, trailing EOS, instruction prefix). * v3: mlxforge_submit_image (Qwen3-VL vision-language: a prompt + one image, - * served single-stream). */ -#define MLXFORGE_ABI_VERSION 3 + * served single-stream). + * v4: mlxforge_image + mlxforge_submit_images (a prompt + N images). 
*/ +#define MLXFORGE_ABI_VERSION 4 typedef struct mlxforge_engine mlxforge_engine; typedef struct mlxforge_request mlxforge_request; @@ -179,6 +180,22 @@ mlxforge_request* mlxforge_submit_image(mlxforge_engine* engine, const char* pro const unsigned char* image_data, size_t image_len, const mlxforge_sampling* sampling, char** err); +/* One image as raw encoded bytes, for mlxforge_submit_images. */ +typedef struct { + const unsigned char* data; + size_t len; +} mlxforge_image; + +/* Submit a multimodal request with N images: a text `prompt` plus `images[0..n-1]` + * (each raw encoded bytes). The images are expanded into the prompt — and attended + * over — in array order. Otherwise identical to mlxforge_submit_image (single-turn, + * served single-stream; requires a vision-language model). `n_images` must be >= 1. + * + * Returns a request handle, or NULL on failure (sets *err). */ +mlxforge_request* mlxforge_submit_images(mlxforge_engine* engine, const char* prompt, + const mlxforge_image* images, size_t n_images, + const mlxforge_sampling* sampling, char** err); + /* Pull the next chunk of generated text. Blocks until decoded text is available * or the request finishes. The detokenizer is UTF-8-safe: a chunk is always a * run of complete characters (never a split multi-byte sequence). 
diff --git a/src/runtime/multimodal_stream.cpp b/src/runtime/multimodal_stream.cpp index 9c8d89e..36ff9cd 100644 --- a/src/runtime/multimodal_stream.cpp +++ b/src/runtime/multimodal_stream.cpp @@ -123,23 +123,38 @@ GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& on_token); } +GenerateResult generate_from_images(const Qwen3VLModel& model, const VitEncoder& vit, + const Tokenizer& tokenizer, const std::string& user_text, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token, + const PreprocessConfig* pcfg) { + // Single-turn convenience: size each placeholder run from its image's + // dimensions (CPU math), render a one-user-message prompt with that many image + // blocks, then generate. The full chat history is handled by the caller building + // prompt_ids directly for generate_multimodal (the server path). + PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*model.config().vision); + std::vector counts; + counts.reserve(images_rgb.size()); + for (const auto& rgb : images_rgb) + counts.push_back(image_token_count(rgb.shape()[0], rgb.shape()[1], pc)); + + Tokenizer::Message msg; + msg.role = "user"; + msg.content = user_text; + msg.image_token_counts = counts; + std::vector ids = tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); + return generate_multimodal(model, vit, ids, images_rgb, max_tokens, eos_ids, on_token, &pc); +} + GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, const Tokenizer& tokenizer, const std::string& user_text, const mx::array& image_rgb, int max_tokens, const std::vector& eos_ids, const std::function& on_token, const PreprocessConfig* pcfg) { - // Single-turn convenience: size the placeholder run from the image dimensions - // (CPU math), render a one-user-message prompt, then generate. The full chat - // history is handled by the caller building prompt_ids for generate_multimodal. 
- PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*model.config().vision); - const int n = image_token_count(image_rgb.shape()[0], image_rgb.shape()[1], pc); - Tokenizer::Message msg; - msg.role = "user"; - msg.content = user_text; - msg.image_token_counts = {n}; - std::vector ids = tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); - return generate_multimodal(model, vit, ids, {image_rgb}, max_tokens, eos_ids, on_token, &pc); + return generate_from_images(model, vit, tokenizer, user_text, {image_rgb}, max_tokens, eos_ids, + on_token, pcfg); } } // namespace mlxforge diff --git a/src/runtime/multimodal_stream.h b/src/runtime/multimodal_stream.h index aeebe05..d1ef155 100644 --- a/src/runtime/multimodal_stream.h +++ b/src/runtime/multimodal_stream.h @@ -49,12 +49,19 @@ GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& const std::function& on_token = {}, const PreprocessConfig* pcfg = nullptr); -// High-level single-image orchestration: smart-resize + preprocess `image_rgb` -// (decoded H×W×3 uint8), encode the ViT, render the ChatML prompt with the right -// number of image placeholders, build M-RoPE positions, and greedily generate. -// Ties the whole vision pipeline together behind one call (the CLI's image-to- -// text core). `pcfg` overrides the preprocessing config (resize bounds etc.); -// NULL uses the model's defaults. +// High-level single-turn orchestration: render a one-user-message ChatML prompt +// for `user_text` with a placeholder run per image (sized from each image's +// dimensions), then generate over all `images_rgb` (decoded H×W×3 uint8). Ties +// the whole vision pipeline together behind one call (the CLI / C-ABI image-to- +// text core). `pcfg` overrides the preprocessing config; NULL uses the defaults. 
+GenerateResult generate_from_images(const Qwen3VLModel& model, const VitEncoder& vit, + const Tokenizer& tokenizer, const std::string& user_text, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token = {}, + const PreprocessConfig* pcfg = nullptr); + +// Single-image convenience: generate_from_images with one image. GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, const Tokenizer& tokenizer, const std::string& user_text, const mx::array& image_rgb, int max_tokens, diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 803f526..de69152 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -95,8 +95,8 @@ void Worker::handle_multimodal(Request& req) { on_token); } else { if (tok_ == nullptr) throw std::runtime_error("multimodal text prompt needs a tokenizer"); - r = generate_from_image(*vl, *vit_, *tok_, req.mm_text, images.front(), req.max_tokens, - req.eos_ids, on_token); + r = generate_from_images(*vl, *vit_, *tok_, req.mm_text, images, req.max_tokens, req.eos_ids, + on_token); } req.finish_reason = r.hit_eos ? "stop" : "length"; } catch (const std::exception& e) { diff --git a/tests/capi/capi_test.cpp b/tests/capi/capi_test.cpp index 620452f..549cbde 100644 --- a/tests/capi/capi_test.cpp +++ b/tests/capi/capi_test.cpp @@ -74,6 +74,15 @@ TEST_CASE("C ABI submits a multimodal request and streams a response") { CHECK(text.size() > 0); CHECK(std::string(mlxforge_request_finish_reason(r)) == "length"); + // The same conversation with two images via mlxforge_submit_images. + mlxforge_image two[2] = {{img.data(), img.size()}, {img.data(), img.size()}}; + mlxforge_request* r2 = + mlxforge_submit_images(eng, "Compare these images.", two, 2, &s, &err); + REQUIRE_MESSAGE(r2 != nullptr, (err ? 
err : "submit_images failed")); + CHECK(drain(r2).size() > 0); + CHECK(std::string(mlxforge_request_finish_reason(r2)) == "length"); + mlxforge_request_free(r2); + mlxforge_request_free(r); mlxforge_engine_free(eng); } From 846d850410f0745081e7a5582233e162a16667f3 Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 23:16:25 +0100 Subject: [PATCH 4/6] server: place images on their own turn (multi-turn vision history) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously all images attached to the last user turn; an image sent on an earlier turn rendered in the wrong position. Track images per message and place each placeholder run on the turn it belongs to: - ChatRequest: images -> message_images (aligned 1:1 with messages; one entry per turn, in order). has_images() helper. - OpenAI / Anthropic parsers: collect each turn's images into its own slot (Anthropic keeps message_images in lockstep with the messages it emits, and now emits an image-only user turn so its image renders). - make_request: walk the messages in the renderer's vision-block order (skip a leading system turn and tool/assistant turns, which emit no block), sizing each run from its image's dimensions and collecting mm_images in that same order — so the placeholder count and the image features stay aligned. Gated: parse keeps an image on the first user turn across a 3-message exchange. Verified live: an image sent on turn 1 and asked about on turn 3 ("what animal was in the picture I sent earlier?") is answered correctly (tiger). Full suite green (194). 
Co-Authored-By: Claude Opus 4.8 --- src/server/anthropic.cpp | 29 +++++++++++++++++++++------- src/server/http_server.cpp | 32 ++++++++++++++++--------------- src/server/openai.cpp | 6 +++++- src/server/openai.h | 16 +++++++++++----- tests/server/anthropic_test.cpp | 5 +++-- tests/server/openai_test.cpp | 34 +++++++++++++++++++++++++++------ 6 files changed, 86 insertions(+), 36 deletions(-) diff --git a/src/server/anthropic.cpp b/src/server/anthropic.cpp index dc2ec1f..d35eca0 100644 --- a/src/server/anthropic.cpp +++ b/src/server/anthropic.cpp @@ -36,23 +36,26 @@ std::string render_tool_use(const json& block) { // assistant tool-call turn and tool_result becomes a "tool" turn (matching the // roles the chat template understands). void append_message(const json& m, std::vector& out, - std::vector& images_out) { + std::vector>& message_images) { if (!m.contains("role")) throw std::runtime_error("each message needs a 'role'"); const std::string role = m.at("role").get(); auto it = m.find("content"); if (it == m.end() || it->is_null()) throw std::runtime_error("each message needs 'content'"); + // Every push to `out` gets a matching `message_images` entry (1:1). if (it->is_string()) { out.push_back({role, it->get(), ""}); + message_images.push_back({}); return; } if (!it->is_array()) throw std::runtime_error("'content' must be a string or array of blocks"); // tool_result blocks (carried on a user turn) are fed back first, then any // free text; tool_use blocks (on an assistant turn) become tool-call turns; - // base64 image blocks are decoded into images_out, in order. + // base64 image blocks decode onto this turn (in order). 
std::string text; std::vector tool_uses; + std::vector images; for (const auto& b : *it) { if (!b.is_object()) continue; const std::string type = b.value("type", std::string()); @@ -62,14 +65,23 @@ void append_message(const json& m, std::vector& out, tool_uses.push_back(b); } else if (type == "tool_result") { out.push_back({"tool", text_of(b.value("content", json())), ""}); + message_images.push_back({}); } else if (type == "image") { const auto src = b.find("source"); if (src != b.end() && src->is_object() && src->value("type", std::string()) == "base64") - images_out.push_back(base64_decode(src->value("data", std::string()))); + images.push_back(base64_decode(src->value("data", std::string()))); } } - if (!text.empty()) out.push_back({role, text, ""}); - for (const auto& tu : tool_uses) out.push_back({"assistant", "", render_tool_use(tu)}); + // The role turn carries this turn's text + images (emit it even for an + // image-only turn so the image renders). + if (!text.empty() || !images.empty()) { + out.push_back({role, text, ""}); + message_images.push_back(std::move(images)); + } + for (const auto& tu : tool_uses) { + out.push_back({"assistant", "", render_tool_use(tu)}); + message_images.push_back({}); + } } // Convert an Anthropic tool definition {name, description, input_schema} into the @@ -117,13 +129,16 @@ ChatRequest parse_messages_request(const nlohmann::json& body) { // template expects it first, especially for Qwen3). 
if (auto s = body.find("system"); s != body.end() && !s->is_null()) { const std::string sys = text_of(*s); - if (!sys.empty()) r.messages.push_back({"system", sys, ""}); + if (!sys.empty()) { + r.messages.push_back({"system", sys, ""}); + r.message_images.push_back({}); // keep aligned with messages + } } auto it = body.find("messages"); if (it == body.end() || !it->is_array() || it->empty()) throw std::runtime_error("'messages' must be a non-empty array"); - for (const auto& m : *it) append_message(m, r.messages, r.images); + for (const auto& m : *it) append_message(m, r.messages, r.message_images); r.params.temperature = body.value("temperature", 1.0f); if (r.params.temperature < 0.0f) throw std::runtime_error("'temperature' must be >= 0"); diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp index a405279..bfffc56 100644 --- a/src/server/http_server.cpp +++ b/src/server/http_server.cpp @@ -55,25 +55,27 @@ std::shared_ptr HttpServer::make_request(const ChatRequest& cr) const { // Multimodal (Qwen3-VL): attached image(s) make this a single-stream vision // turn. We render the FULL chat history (system + prior turns) here, sizing each - // <|image_pad|> run from that image's dimensions (a CPU probe — no decode), and - // attach the runs (in order) to the last user turn. The worker decodes + - // ViT-encodes the bytes and generates from these prompt_ids. - if (!cr.images.empty()) { + // <|image_pad|> run from its image's dimensions (a CPU probe — no decode) and + // attaching it to the turn the image belongs to, so images placed across + // different turns land at the right positions. The worker decodes + ViT-encodes + // the bytes (in order) and generates from these prompt_ids. 
+ if (cr.has_images()) { if (!cfg_.has_vision_tower()) throw std::runtime_error("this model does not support image input"); const PreprocessConfig pc = PreprocessConfig::from(*cfg_.vision); - std::vector counts; - for (const auto& img : cr.images) { - const std::array hw = - image_info(reinterpret_cast(img.data()), img.size()); - counts.push_back(image_token_count(hw[0], hw[1], pc)); - req->mm_images.emplace_back(img.begin(), img.end()); - } std::vector msgs = cr.messages; // copy: set placeholder counts - for (auto it = msgs.rbegin(); it != msgs.rend(); ++it) { - if (it->role == "user") { - it->image_token_counts = counts; - break; + + // Mirror render_qwen3's vision-block emission order: a leading system turn and + // any tool/assistant turns render no vision block, so we skip their images + // (and never let them desync the placeholder count). + const bool have_system = !msgs.empty() && msgs.front().role == "system"; + for (size_t i = (have_system ? 1 : 0); i < msgs.size(); ++i) { + if (msgs[i].role == "tool" || msgs[i].role == "assistant") continue; + for (const auto& img : cr.message_images[i]) { + const std::array hw = + image_info(reinterpret_cast(img.data()), img.size()); + msgs[i].image_token_counts.push_back(image_token_count(hw[0], hw[1], pc)); + req->mm_images.emplace_back(img.begin(), img.end()); } } req->prompt_ids = tok_->apply_chat_template(msgs, true, "", {}, cr.enable_thinking); diff --git a/src/server/openai.cpp b/src/server/openai.cpp index c5e8f1a..0ee4095 100644 --- a/src/server/openai.cpp +++ b/src/server/openai.cpp @@ -167,7 +167,11 @@ ChatRequest parse_chat_request(const nlohmann::json& body) { auto it = body.find("messages"); if (it == body.end() || !it->is_array() || it->empty()) throw std::runtime_error("'messages' must be a non-empty array"); - for (const auto& m : *it) r.messages.push_back(parse_message(m, r.images)); + for (const auto& m : *it) { + std::vector imgs; // this message's images, in order + 
r.messages.push_back(parse_message(m, imgs)); + r.message_images.push_back(std::move(imgs)); + } parse_common(body, r); return r; } diff --git a/src/server/openai.h b/src/server/openai.h index 569914f..fbd5940 100644 --- a/src/server/openai.h +++ b/src/server/openai.h @@ -18,11 +18,17 @@ struct ChatRequest { std::string model; std::vector messages; // chat; for /v1/completions a single user msg bool is_chat = true; // chat vs raw completion - // Decoded bytes of each image attached to the conversation, in order (empty = - // none). When non-empty the request is served as a single-stream Qwen3-VL - // multimodal turn: each image is decoded, ViT-encoded, and its placeholder run - // expanded into the prompt. - std::vector images; + // Decoded image bytes per message, aligned 1:1 with `messages` (each entry is + // that turn's images, in order; most are empty). When any are present the + // request is served as a single-stream Qwen3-VL multimodal turn: each image is + // decoded, ViT-encoded, and its placeholder run expanded into the prompt at the + // position of the message it belongs to. 
+ std::vector> message_images; + bool has_images() const { + for (const auto& imgs : message_images) + if (!imgs.empty()) return true; + return false; + } SamplingParams params; int max_tokens = 128; bool stream = false; diff --git a/tests/server/anthropic_test.cpp b/tests/server/anthropic_test.cpp index 7085e6a..74c7a16 100644 --- a/tests/server/anthropic_test.cpp +++ b/tests/server/anthropic_test.cpp @@ -182,6 +182,7 @@ TEST_CASE("parse_messages_request extracts a base64 image block as bytes") { "max_tokens": 8 })"); ChatRequest r = parse_messages_request(body); - REQUIRE(r.images.size() == 1); - CHECK(r.images[0] == "hello"); + REQUIRE(r.message_images.size() == 1); + REQUIRE(r.message_images[0].size() == 1); + CHECK(r.message_images[0][0] == "hello"); } diff --git a/tests/server/openai_test.cpp b/tests/server/openai_test.cpp index 3c63622..7b1b310 100644 --- a/tests/server/openai_test.cpp +++ b/tests/server/openai_test.cpp @@ -217,8 +217,9 @@ TEST_CASE("parse_chat_request extracts an image_url data URI as bytes") { ChatRequest r = parse_chat_request(body); REQUIRE(r.messages.size() == 1); CHECK(r.messages[0].content == "what is this?"); - REQUIRE(r.images.size() == 1); - CHECK(r.images[0] == "hello"); // routed to the multimodal path by make_request + REQUIRE(r.message_images.size() == 1); + REQUIRE(r.message_images[0].size() == 1); + CHECK(r.message_images[0][0] == "hello"); // on the user turn; routed to the multimodal path } TEST_CASE("parse_chat_request rejects a non-data: image URL") { @@ -234,7 +235,7 @@ TEST_CASE("parse_chat_request leaves image empty for text-only content") { json body = json::parse(R"({ "messages": [{"role": "user", "content": "just text"}], "max_tokens": 8 })"); - CHECK(parse_chat_request(body).images.empty()); + CHECK_FALSE(parse_chat_request(body).has_images()); } TEST_CASE("parse_chat_request collects multiple images in order") { @@ -248,7 +249,28 @@ TEST_CASE("parse_chat_request collects multiple images in order") { "max_tokens": 
8 })"); ChatRequest r = parse_chat_request(body); - REQUIRE(r.images.size() == 2); - CHECK(r.images[0] == "hello"); - CHECK(r.images[1] == "world"); + REQUIRE(r.message_images.size() == 1); + REQUIRE(r.message_images[0].size() == 2); + CHECK(r.message_images[0][0] == "hello"); + CHECK(r.message_images[0][1] == "world"); +} + +TEST_CASE("parse_chat_request keeps each image on its own turn") { + // Image on the FIRST user turn; a later user turn only references it. + json body = json::parse(R"({ + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "here is a picture"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,aGVsbG8="}}]}, + {"role": "assistant", "content": "ok"}, + {"role": "user", "content": "what was in it?"} + ], + "max_tokens": 8 + })"); + ChatRequest r = parse_chat_request(body); + REQUIRE(r.message_images.size() == 3); + REQUIRE(r.message_images[0].size() == 1); // image stays on the first user turn + CHECK(r.message_images[0][0] == "hello"); + CHECK(r.message_images[1].empty()); // assistant turn + CHECK(r.message_images[2].empty()); // later user turn } From cffe624bf846e4ac324f14902fd0c1fb149bbaa6 Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 23:21:33 +0100 Subject: [PATCH 5/6] swift: runtime-validate the vision bindings (image / images) Add a Swift test exercising engine.image (one image) and engine.images (several) against a Qwen3-VL checkpoint, gated on MLXFORGE_VL_MODEL and reading the committed fixture image. The binding now builds and runs: `swift build` compiles the images() wrapper and links the dylib (the C shim header is a symlink to the v4 C ABI, so it's already current), `swift test --filter testImagesIfVLModelPresent` passes against Qwen3-VL-4B, and scripts/make_xcframework.sh repackages the distribution xcframework. No longer "source-only". 
Co-Authored-By: Claude Opus 4.8 --- .../Tests/MLXForgeTests/MLXForgeTests.swift | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/bindings/swift/Tests/MLXForgeTests/MLXForgeTests.swift b/bindings/swift/Tests/MLXForgeTests/MLXForgeTests.swift index 0cf6ca1..3b66ae0 100644 --- a/bindings/swift/Tests/MLXForgeTests/MLXForgeTests.swift +++ b/bindings/swift/Tests/MLXForgeTests/MLXForgeTests.swift @@ -60,4 +60,34 @@ final class MLXForgeTests: XCTestCase { XCTAssertEqual(dot(a, a), 1.0, accuracy: 0.02) // unit-normalized XCTAssertGreaterThan(dot(a, b), dot(a, c)) // semantic ordering } + + // Vision-language: one image via image(), and several via images(). Gated on a + // dedicated env var pointing at a Qwen3-VL checkpoint; reads the committed test + // image from the repo's fixtures. + func testImagesIfVLModelPresent() async throws { + guard let dir = ProcessInfo.processInfo.environment["MLXFORGE_VL_MODEL"], !dir.isEmpty else { + throw XCTSkip("MLXFORGE_VL_MODEL not set; skipping vision test") + } + let engine = try await Engine.load(dir) + + var repo = URL(fileURLWithPath: #filePath) // …/bindings/swift/Tests/MLXForgeTests/ + for _ in 0..<5 { repo.deleteLastPathComponent() } + let imgURL = repo.appendingPathComponent("reference/fixtures_qwen3_vl/image.png") + let bytes = [UInt8](try Data(contentsOf: imgURL)) + + var s = Sampling.greedy + s.maxTokens = 8 + + var single = "" + for try await chunk in try engine.image("What is in this image?", bytes, sampling: s) { + single += chunk + } + XCTAssertFalse(single.isEmpty) + + var multi = "" + for try await chunk in try engine.images("Compare these images.", [bytes, bytes], sampling: s) { + multi += chunk + } + XCTAssertFalse(multi.isEmpty) + } } From 67ca4997c36d6fd1e22b0b723934046c16a08007 Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Tue, 9 Jun 2026 23:41:37 +0100 Subject: [PATCH 6/6] multimodal: prefill-single, decode-batched VL serving (vLLM/omlx-style) Qwen3-VL requests now join the 
continuous-batching decode pool instead of running single-stream inline on the worker thread. The ViT + image-merge + 3D-M-RoPE prefill still runs single (the ViT can't batch ragged grids), but its K/V is adopted into a batch-1 BatchKVCache and merged into the shared decode batch, where the prompt's (pure-text) generated tokens decode alongside text rows through the ordinary batched forward. The key enabler: in the batched forward the attention mask works in physical cache-slot space (idx/left_padding) while RoPE uses a *separate* per-row offset. So a VL row carries offset = max(3D position)+1 -- well below its image-padded token count -- without the cache needing per-row 3D positions. A generated VL token has t==h==w, and build_inv_freq/apply_mrope produce the identical angle to compute_rope_freqs/fast::rope, so VL decode reuses the inherited DecoderModel::forward(BatchKVCache&) unchanged. - BatchKVCache::from_single_sequence: adopt a prefilled single sequence's K/V with an explicitly-seeded decode RoPE offset. - KVCache::n_layers()/fetch(): expose per-layer K/V for the adoption. - multimodal_stream: prepare_multimodal_prefill / prefill_multimodal_batched / render_multimodal_prompt / greedy_generate_multimodal_batched, with generate_multimodal/generate_from_images refactored to share them. - Worker::admit_multimodal replaces handle_multimodal; register_rows extracted from admit and shared by both the text and multimodal admit paths. Golden-gated: batched decode reproduces the single-stream greedy stream token for token, and a vision row stays uncorrupted when batched next to a shorter text row (the per-row offset-decoupling / ragged-merge gate). Full suite green (196/196). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- CLAUDE.md | 21 ++++-- src/cache/batch_kv_cache.cpp | 15 ++++ src/cache/batch_kv_cache.h | 12 +++ src/cache/kv_cache.h | 10 +++ src/runtime/multimodal_stream.cpp | 118 ++++++++++++++++++++++++------ src/runtime/multimodal_stream.h | 62 +++++++++++++++- src/runtime/worker.cpp | 69 ++++++++++------- src/runtime/worker.h | 17 +++-- tests/model/qwen3_vl_test.cpp | 73 ++++++++++++++++++ 9 files changed, 335 insertions(+), 62 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index d2ea98d..efc16d1 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -153,10 +153,20 @@ reference/.venv/bin/python reference/dump_ref.py not 3D `(t,h,w)` positions). `Qwen3VLModel` hand-rolls a half-split rotation with a per-frequency t/h/w selector; text tokens have `t==h==w` so it reduces to ordinary 1D RoPE (and a generated/decode token is a scalar position one past the - prompt's max — it jumps over the image's spatial extent). Vision is served - **single-stream** (`runtime/multimodal_stream`); the continuous-batching worker - is still text-only (`BatchKVCache` can't yet carry per-row 3D positions). Every - vision stage is golden-gated against `mlx-vlm` (`reference/fixtures_qwen3_vl/`). + prompt's max — it jumps over the image's spatial extent). Vision serving is + **prefill-single, decode-batched** (like vLLM/omlx): the ViT + image-merge + + 3D-M-RoPE prefill runs single-stream (`runtime/multimodal_stream`), then the + prompt's K/V is adopted into a batch-1 `BatchKVCache` + (`BatchKVCache::from_single_sequence`) and **merged into the continuous-batching + decode pool** (`Worker::admit_multimodal`) — a generated VL token is pure text + (`t==h==w`), so it decodes through the ordinary batched forward alongside text + rows. 
No per-row 3D positions are needed in the cache: the batched mask works in + physical-slot space (`idx`/`left_padding`) while RoPE uses a *separate* per-row + `offset`, so a VL row just carries `offset = max(3D position)+1` (well below its + image-padded token count). The prefill itself stays single (the ViT can't batch + ragged grids). Every vision stage is golden-gated against `mlx-vlm` + (`reference/fixtures_qwen3_vl/`), and batched decode is gated equal to the + single-stream stream (`tests/model/qwen3_vl_test.cpp`). ## Where things live @@ -169,7 +179,8 @@ in `apps/` (`mlxforge` server, `mlxforge-cli`), tests mirror the module path und `src/model/qwen3_vl.{h,cpp}` is the fused model (image merge + interleaved M-RoPE + DeepStack + cached decode); `src/vision/` does image decode (`stb_image`) + preprocess (smart-resize/normalize/patchify); `runtime/multimodal_stream.{h,cpp}` is -the single-stream image→text path. Selected by `create_model` on a `vision_config`. +the image→text path (single-stream prefill + batched-decode packaging). Selected by +`create_model` on a `vision_config`. 
**Product-facing surface:** `src/capi/` is the stable `extern "C"` ABI (`mlxforge.h`) wrapping `runtime/engine` — the public surface — and `bindings/` diff --git a/src/cache/batch_kv_cache.cpp b/src/cache/batch_kv_cache.cpp index 1b2ea70..7561a9d 100644 --- a/src/cache/batch_kv_cache.cpp +++ b/src/cache/batch_kv_cache.cpp @@ -34,6 +34,21 @@ BatchKVCache::BatchKVCache(int n_layers, const std::vector& left_padding) left_padding_(mx::array(left_padding.data(), {static_cast(left_padding.size())}, mx::int32)) {} +BatchKVCache BatchKVCache::from_single_sequence( + std::vector> kv_per_layer, int seq, int decode_offset) { + const int n_layers = static_cast(kv_per_layer.size()); + BatchKVCache c(n_layers, std::vector{0}); // batch 1, no left padding + for (int l = 0; l < n_layers; ++l) { + c.keys_[l] = std::move(kv_per_layer[l].first); + c.values_[l] = std::move(kv_per_layer[l].second); + } + c.idx_ = seq; // physical sequence length (drives the attention mask) + int off = decode_offset; + c.offset_ = mx::array(&off, {1}, mx::int32); // decoupled RoPE position + mx::eval(c.offset_); + return c; +} + int BatchKVCache::s_cap() const { return keys_[0].has_value() ? keys_[0]->shape()[2] : 0; } diff --git a/src/cache/batch_kv_cache.h b/src/cache/batch_kv_cache.h index 12cfe72..619f3ba 100644 --- a/src/cache/batch_kv_cache.h +++ b/src/cache/batch_kv_cache.h @@ -31,6 +31,18 @@ class BatchKVCache { // One cache for all layers; `left_padding[i]` is the pad count of batch row i. BatchKVCache(int n_layers, const std::vector& left_padding); + // Build a batch-1 cache from one already-prefilled sequence's per-layer K/V + // (each (1, n_kv_heads, seq, head_dim)). `decode_offset` seeds the per-row RoPE + // position for the next (decode) token. 
For an interleaved-M-RoPE prompt + // (Qwen3-VL) this is one past the prompt's max 3D position, which sits BELOW its + // token count (the image collapses many tokens into few positions) — so offset + // is passed explicitly rather than derived from seq, and the standard batched + // decode (whose mask works in physical-slot space while RoPE uses offset) is + // numerically correct unchanged. left_padding is 0 (a fresh single-row prefill + // has no padding). Used to admit a vision prompt into the decode pool (merge()). + static BatchKVCache from_single_sequence( + std::vector> kv_per_layer, int seq, int decode_offset); + int batch_size() const { return batch_; } int idx() const { return idx_; } // populated sequence length (_idx) // Allocated capacity along the sequence axis (0 before the first write). diff --git a/src/cache/kv_cache.h b/src/cache/kv_cache.h index fa8c935..4f95a08 100644 --- a/src/cache/kv_cache.h +++ b/src/cache/kv_cache.h @@ -24,6 +24,16 @@ class KVCache { int offset() const { return offset_; } void advance(int n_tokens) { offset_ += n_tokens; } + int n_layers() const { return static_cast(keys_.size()); } + + // Stored K/V for a layer (each (1, n_kv_heads, offset, head_dim), no capacity + // padding). Valid only after the layer has been written. Lets a prefilled + // single sequence be handed to the batched cache for continuous-batching decode + // (BatchKVCache::from_single_sequence). + std::pair fetch(int layer) const { + return {*keys_[layer], *values_[layer]}; + } + // Append this layer's K/V (each (1, n_kv_heads, L, head_dim)) along the // sequence axis and return the full cached (keys, values) to attend over. 
std::pair update_and_fetch(int layer, const mx::array& k, diff --git a/src/runtime/multimodal_stream.cpp b/src/runtime/multimodal_stream.cpp index 36ff9cd..d6e2592 100644 --- a/src/runtime/multimodal_stream.cpp +++ b/src/runtime/multimodal_stream.cpp @@ -70,14 +70,12 @@ GenerateResult greedy_generate_multimodal(const Qwen3VLModel& model, return result; } -GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, - const std::vector& prompt_ids, - const std::vector& images_rgb, int max_tokens, - const std::vector& eos_ids, - const std::function& on_token, - const PreprocessConfig* pcfg) { +MultimodalPrefillInputs prepare_multimodal_prefill(const Qwen3VLModel& model, const VitEncoder& vit, + const std::vector& prompt_ids, + const std::vector& images_rgb, + const PreprocessConfig* pcfg) { const ModelConfig& cfg = model.config(); - if (images_rgb.empty()) throw std::runtime_error("generate_multimodal: no images"); + if (images_rgb.empty()) throw std::runtime_error("prepare_multimodal_prefill: no images"); const PreprocessConfig pc = pcfg ? 
*pcfg : PreprocessConfig::from(*cfg.vision); // Smart-resize + preprocess + ViT-encode each image; collect per-image merged @@ -117,23 +115,29 @@ GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& for (const auto& per_image : deepstack_parts) layer_parts.push_back(per_image[layer]); deepstack.push_back(cat(layer_parts)); } + return {features, std::move(deepstack), mrope_position_ids(prompt_ids, grids, cfg)}; +} - mx::array pos = mrope_position_ids(prompt_ids, grids, cfg); - return greedy_generate_multimodal(model, prompt_ids, features, deepstack, pos, max_tokens, eos_ids, - on_token); +GenerateResult generate_multimodal(const Qwen3VLModel& model, const VitEncoder& vit, + const std::vector& prompt_ids, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token, + const PreprocessConfig* pcfg) { + MultimodalPrefillInputs in = prepare_multimodal_prefill(model, vit, prompt_ids, images_rgb, pcfg); + return greedy_generate_multimodal(model, prompt_ids, in.features, in.deepstack, in.position_ids, + max_tokens, eos_ids, on_token); } -GenerateResult generate_from_images(const Qwen3VLModel& model, const VitEncoder& vit, - const Tokenizer& tokenizer, const std::string& user_text, - const std::vector& images_rgb, int max_tokens, - const std::vector& eos_ids, - const std::function& on_token, - const PreprocessConfig* pcfg) { - // Single-turn convenience: size each placeholder run from its image's - // dimensions (CPU math), render a one-user-message prompt with that many image - // blocks, then generate. The full chat history is handled by the caller building - // prompt_ids directly for generate_multimodal (the server path). - PreprocessConfig pc = pcfg ? 
*pcfg : PreprocessConfig::from(*model.config().vision); +std::vector render_multimodal_prompt(const Tokenizer& tokenizer, const Qwen3VLModel& model, + const std::string& user_text, + const std::vector& images_rgb, + const PreprocessConfig* pcfg) { + // Size each placeholder run from its image's dimensions (CPU math) and render a + // one-user-message prompt with that many image blocks. Must use the SAME + // preprocessing config the ViT prefill will use, or the placeholder count and + // the merged-patch count disagree. + const PreprocessConfig pc = pcfg ? *pcfg : PreprocessConfig::from(*model.config().vision); std::vector counts; counts.reserve(images_rgb.size()); for (const auto& rgb : images_rgb) @@ -143,10 +147,80 @@ GenerateResult generate_from_images(const Qwen3VLModel& model, const VitEncoder& msg.role = "user"; msg.content = user_text; msg.image_token_counts = counts; - std::vector ids = tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); + return tokenizer.apply_chat_template({msg}, /*add_generation_prompt=*/true); +} + +GenerateResult generate_from_images(const Qwen3VLModel& model, const VitEncoder& vit, + const Tokenizer& tokenizer, const std::string& user_text, + const std::vector& images_rgb, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token, + const PreprocessConfig* pcfg) { + // Single-turn convenience: template the prompt, then generate. The full chat + // history is handled by the caller building prompt_ids directly for + // generate_multimodal (the server path). + const PreprocessConfig pc = pcfg ? 
*pcfg : PreprocessConfig::from(*model.config().vision); + std::vector ids = render_multimodal_prompt(tokenizer, model, user_text, images_rgb, &pc); return generate_multimodal(model, vit, ids, images_rgb, max_tokens, eos_ids, on_token, &pc); } +MultimodalPrefill prefill_multimodal_batched(const Qwen3VLModel& model, + const std::vector& prompt_ids, + const mx::array& features, + const std::vector& deepstack, + const mx::array& position_ids) { + // Reuse the golden-gated single-stream prefill verbatim (it writes the prompt's + // K/V into a single-sequence cache), then adopt that K/V into a batch-1 + // BatchKVCache. The decode RoPE position continues one past the prompt's max + // M-RoPE position (the prompt's positions jump over the image's spatial extent), + // which is well below the token count — so it is seeded explicitly, decoupled + // from the cache's physical length. + KVCache kv(model.config().n_layers); + mx::array logits = model.prefill(prompt_ids, features, deepstack, position_ids, kv); + const int seq = static_cast(prompt_ids.size()); + const int vocab = logits.shape()[2]; + mx::array last = mx::reshape(mx::slice(logits, {0, seq - 1, 0}, {1, seq, vocab}), {1, vocab}); + const int decode_offset = static_cast(mx::max(position_ids).item()) + 1; + + std::vector> kv_per_layer; + kv_per_layer.reserve(kv.n_layers()); + for (int l = 0; l < kv.n_layers(); ++l) kv_per_layer.push_back(kv.fetch(l)); + BatchKVCache cache = BatchKVCache::from_single_sequence(std::move(kv_per_layer), seq, + decode_offset); + cache.eval_state(); // materialize K/V + bookkeeping (detach from the prefill graph) + mx::eval(last); + return {std::move(cache), last}; +} + +GenerateResult greedy_generate_multimodal_batched(const Qwen3VLModel& model, + const std::vector& prompt_ids, + const mx::array& features, + const std::vector& deepstack, + const mx::array& position_ids, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token) { + auto is_eos = [&](int id) { + return 
std::find(eos_ids.begin(), eos_ids.end(), id) != eos_ids.end(); + }; + MultimodalPrefill pf = + prefill_multimodal_batched(model, prompt_ids, features, deepstack, position_ids); + + GenerateResult result; + int next = greedy_row(pf.last_logits); // (1, vocab) + for (int i = 0; i < max_tokens; ++i) { + if (is_eos(next)) { + result.hit_eos = true; + break; + } + result.tokens.push_back(next); + if (on_token) on_token(next); + mx::array step(&next, {1, 1}, mx::int32); + mx::array logits = model.forward(step, pf.cache); // (1, 1, vocab), batched path + next = greedy_row(mx::reshape(logits, {1, logits.shape()[2]})); + } + return result; +} + GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& vit, const Tokenizer& tokenizer, const std::string& user_text, const mx::array& image_rgb, int max_tokens, diff --git a/src/runtime/multimodal_stream.h b/src/runtime/multimodal_stream.h index d1ef155..81654f4 100644 --- a/src/runtime/multimodal_stream.h +++ b/src/runtime/multimodal_stream.h @@ -1,11 +1,13 @@ -// Single-stream greedy generation for Qwen3-VL (image -> text). +// Greedy generation for Qwen3-VL (image -> text), single-stream and batched. // // The multimodal sibling of greedy_generate(): prefill the prompt — with the ViT // image features scattered into the image_pad rows, DeepStack injection, and 3D // interleaved M-RoPE positions — into a KV cache, then decode text tokens // incrementally (each a scalar M-RoPE position one past the prompt's max). The -// continuous-batching worker is still text-only; this is the single-stream path -// the CLI uses to run a vision-language prompt end to end. +// single-stream path is what the CLI uses; the lower half of this header packages +// the same prefill for the continuous-batching worker (prefill-single, +// decode-batched), where a VL prompt's K/V joins the shared decode batch and its +// pure-text generated tokens decode alongside text rows. 
#pragma once #include @@ -13,6 +15,7 @@ #include "mlx/array.h" +#include "cache/batch_kv_cache.h" #include "model/qwen3_vl.h" #include "model/vision/vit.h" #include "runtime/single_stream.h" // GenerateResult @@ -69,4 +72,57 @@ GenerateResult generate_from_image(const Qwen3VLModel& model, const VitEncoder& const std::function& on_token = {}, const PreprocessConfig* pcfg = nullptr); +// --- Prefill-single, decode-batched (continuous-batching) multimodal serving --- +// +// The vision prompt is prefilled single-stream (the ViT can't batch ragged image +// grids, and M-RoPE positions are 3D and per-prompt), then handed to the text +// continuous-batching decode pool: generated tokens are pure text (t==h==w), so +// decode runs through the ordinary batched forward, mixing freely with text rows. + +// ViT features + DeepStack + (3, seq) M-RoPE positions for a templated prompt. +// Factored out of generate_multimodal so the batched worker and the single-stream +// path build their prefill inputs the same way. +struct MultimodalPrefillInputs { + mx::array features; // (num_image_tokens, hidden) merged ViT features, in order + std::vector deepstack; // per-layer image features + mx::array position_ids; // (3, seq) int32 interleaved M-RoPE positions +}; +MultimodalPrefillInputs prepare_multimodal_prefill(const Qwen3VLModel& model, const VitEncoder& vit, + const std::vector& prompt_ids, + const std::vector& images_rgb, + const PreprocessConfig* pcfg = nullptr); + +// A vision prompt prefilled and packaged for admission into the decode pool: a +// batch-1 BatchKVCache (decode RoPE offset seeded one past the prompt's max +// M-RoPE position) plus the prompt's last-token logits (1, vocab) to sample the +// first generated token from. 
+struct MultimodalPrefill { + BatchKVCache cache; + mx::array last_logits; // (1, vocab) +}; +MultimodalPrefill prefill_multimodal_batched(const Qwen3VLModel& model, + const std::vector& prompt_ids, + const mx::array& features, + const std::vector& deepstack, + const mx::array& position_ids); + +// Render a single-user-turn ChatML prompt for `user_text` with one image-pad run +// per image (sized from each image's dimensions). The templating half of +// generate_from_images, shared with the batched worker path. +std::vector render_multimodal_prompt(const Tokenizer& tokenizer, const Qwen3VLModel& model, + const std::string& user_text, + const std::vector& images_rgb, + const PreprocessConfig* pcfg = nullptr); + +// Greedy generation through the BATCHED decode path for a single row: prefill the +// prompt, then decode via forward(BatchKVCache&). Proves the batched serving path +// is numerically identical to greedy_generate_multimodal (golden/QA seam). +GenerateResult greedy_generate_multimodal_batched(const Qwen3VLModel& model, + const std::vector& prompt_ids, + const mx::array& features, + const std::vector& deepstack, + const mx::array& position_ids, int max_tokens, + const std::vector& eos_ids, + const std::function& on_token = {}); + } // namespace mlxforge diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index de69152..143debe 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -66,7 +66,7 @@ void Worker::handle_embedding(Request& req) { req.tokens.close(); // unblock the waiting submitter } -void Worker::handle_multimodal(Request& req) { +void Worker::admit_multimodal(const std::shared_ptr& req) { try { auto* vl = dynamic_cast(model_.get()); if (vl == nullptr || !model_->config().has_vision_tower()) { @@ -77,33 +77,41 @@ void Worker::handle_multimodal(Request& req) { vit_ = std::make_unique(*model_->config().vision, model_->weights()); } std::vector images; - images.reserve(req.mm_images.size()); - for (const auto& bytes : 
req.mm_images) images.push_back(decode_image(bytes.data(), bytes.size())); - - int produced = 0; - auto on_token = [&](int id) { - if (produced == 0) req.first_token_time = Request::Clock::now(); - req.tokens.push(id); - ++produced; - }; + images.reserve(req->mm_images.size()); + for (const auto& bytes : req->mm_images) + images.push_back(decode_image(bytes.data(), bytes.size())); + // A caller that pre-rendered the full chat history (the server) supplies // prompt_ids with the image placeholders already expanded (any number of - // images); the simple path (C ABI / CLI) supplies just mm_text + one image. - GenerateResult r; - if (!req.prompt_ids.empty()) { - r = generate_multimodal(*vl, *vit_, req.prompt_ids, images, req.max_tokens, req.eos_ids, - on_token); - } else { + // images); the simple path (C ABI / CLI) supplies just mm_text + image(s). + std::vector prompt_ids = req->prompt_ids; + if (prompt_ids.empty()) { if (tok_ == nullptr) throw std::runtime_error("multimodal text prompt needs a tokenizer"); - r = generate_from_images(*vl, *vit_, *tok_, req.mm_text, images, req.max_tokens, req.eos_ids, - on_token); + prompt_ids = render_multimodal_prompt(*tok_, *vl, req->mm_text, images); + } + // The row decodes as ordinary text from here, so its history/metrics treat the + // expanded prompt (image placeholders included) as the prompt. + req->prompt_ids = prompt_ids; + + // Prefill single-stream (all fallible work — ViT, scatter, M-RoPE — happens + // before we touch the shared batch), then admit into the decode pool. + MultimodalPrefillInputs in = prepare_multimodal_prefill(*vl, *vit_, prompt_ids, images); + MultimodalPrefill pf = + prefill_multimodal_batched(*vl, prompt_ids, in.features, in.deepstack, in.position_ids); + + if (!cache_) { + cache_ = std::make_unique(std::move(pf.cache)); + } else { + cache_->merge(pf.cache); } - req.finish_reason = r.hit_eos ? 
"stop" : "length"; + register_rows({req}, pf.last_logits); + log::debug("worker: admitted multimodal request (prompt={}, batch now {})", prompt_ids.size(), + reqs_.size()); } catch (const std::exception& e) { - log::error("worker: multimodal error: {}", e.what()); - req.finish_reason = "error"; + log::error("worker: multimodal admit error: {}", e.what()); + req->finish_reason = "error"; + req->tokens.close(); } - req.tokens.close(); } void Worker::ensure_token_bytes(int vocab) { @@ -190,13 +198,15 @@ void Worker::run() { try { if (!incoming.empty()) { - // Embedding and multimodal requests are one-shot (handled inline, on this - // thread); only text-generation requests are admitted into the decode batch. + // Embedding requests are one-shot (handled inline). Multimodal requests are + // prefilled single-stream then admitted into the decode batch (one prefill + // each — the ViT can't batch ragged grids). Text-generation requests are + // prefilled together and admitted as a group. std::vector> gen; gen.reserve(incoming.size()); for (auto& r : incoming) { if (r->embedding) handle_embedding(*r); - else if (r->is_multimodal()) handle_multimodal(*r); + else if (r->is_multimodal()) admit_multimodal(r); else gen.push_back(std::move(r)); } if (!gen.empty()) admit(gen); @@ -233,9 +243,14 @@ void Worker::admit(const std::vector>& incoming) { } else { cache_->merge(pr.cache); } + register_rows(incoming, pr.last_logits); +} +void Worker::register_rows(const std::vector>& incoming, + const mx::array& last_logits) { // Register the new rows before sampling so sample_rows() can read their params, - // penalty history (seeded with the prompt) and RNG key. + // penalty history (seeded with the prompt) and RNG key. `last_logits` rows are + // aligned to the new tail of the batch (row i -> reqs_[base + i]). 
const int base = static_cast(reqs_.size()); for (size_t i = 0; i < incoming.size(); ++i) { reqs_.push_back(incoming[i]); @@ -256,7 +271,7 @@ void Worker::admit(const std::vector>& incoming) { } std::vector first = - read_ids(sample_rows(pr.last_logits, base, static_cast(incoming.size()))); + read_ids(sample_rows(last_logits, base, static_cast(incoming.size()))); for (size_t i = 0; i < incoming.size(); ++i) { const int b = base + static_cast(i); diff --git a/src/runtime/worker.h b/src/runtime/worker.h index 97882af..8b65355 100644 --- a/src/runtime/worker.h +++ b/src/runtime/worker.h @@ -62,6 +62,11 @@ class Worker { // Prefill `incoming` and merge it into the decode batch (emitting each row's // first token). void admit(const std::vector>& incoming); + // Register freshly-admitted rows into the decode-batch state (already merged + // into the cache) and sample each row's first token from `last_logits` (rows + // aligned to the new tail). Shared by the text and multimodal admit paths. + void register_rows(const std::vector>& incoming, + const mx::array& last_logits); // One decode step over the whole batch: forward -> sample -> async_eval -> // push each row's token, marking finished rows. void decode_step(); @@ -78,11 +83,13 @@ class Worker { // req.embedding_result, then close its token queue. Runs on the worker thread. void handle_embedding(Request& req); - // Handle a one-shot multimodal (Qwen3-VL) request: decode the image, run the - // ViT, render the chat prompt, and stream generated tokens into req.tokens - // single-stream. Runs on the worker thread (it owns all MLX state). Errors if - // the loaded model is not a vision-language model. 
- void handle_multimodal(Request& req); + // Admit a multimodal (Qwen3-VL) request into the decode batch: decode the + // image(s), run the ViT, render/validate the prompt, prefill single-stream, then + // merge the prompt's K/V into the shared BatchKVCache so its (pure-text) + // generated tokens decode batched alongside text rows (prefill-single, + // decode-batched). Runs on the worker thread (it owns all MLX state). On error + // (e.g. the loaded model is not a VL model) the request is failed, not batched. + void admit_multimodal(const std::shared_ptr& req); // Constrained decoding helpers. ensure_token_bytes builds the id->output-bytes // table once (from the tokenizer); grammar_mask returns an additive (1, vocab) diff --git a/tests/model/qwen3_vl_test.cpp b/tests/model/qwen3_vl_test.cpp index 478b9cd..3c4199c 100644 --- a/tests/model/qwen3_vl_test.cpp +++ b/tests/model/qwen3_vl_test.cpp @@ -12,11 +12,13 @@ #include "mlx/ops.h" #include "mlx/transforms.h" +#include "cache/batch_kv_cache.h" #include "cache/kv_cache.h" #include "core/config.h" #include "core/weights.h" #include "model/model_factory.h" #include "model/qwen3_vl.h" +#include "runtime/batching.h" #include "runtime/multimodal_stream.h" #include "runtime/worker.h" #include "scheduler/request.h" @@ -326,6 +328,77 @@ TEST_CASE("Qwen3-VL: worker serves a multimodal request from another thread") { CHECK(req->finish_reason == "length"); } +namespace { +// Greedy ids from a (B, vocab) logits batch. 
+std::vector argmax_rows(const mx::array& logits) { + mx::array a = mx::contiguous(mx::astype(mx::argmax(logits, /*axis=*/-1), mx::int32)); + mx::eval(a); + return std::vector(a.data(), a.data() + a.size()); +} +} // namespace + +TEST_CASE("Qwen3-VL: batched decode path reproduces the single-stream greedy tokens") { + if (!qwen3_vl_model_available()) { + MESSAGE("Qwen3-VL model not found in HF cache; skipping batched-decode equivalence test"); + return; + } + // The continuous-batching serving path (prefill-single, decode-batched): the + // prompt's K/V is adopted into a batch-1 BatchKVCache and decoded through the + // ordinary text batched forward. A generated VL token is pure text (t==h==w), + // so this MUST match the single-stream cached-decode stream token-for-token. + const Qwen3VLModel& m = shared_qwen3_vl_model(); + mx::array feats = load_qwen3_vl_npy("vit_out.npy"); + std::vector deepstack = {load_qwen3_vl_npy("deepstack_0.npy"), + load_qwen3_vl_npy("deepstack_1.npy"), + load_qwen3_vl_npy("deepstack_2.npy")}; + std::vector ids = load_qwen3_vl_token_ids("input_ids.npy"); + mx::array pos = mrope_position_ids(ids, grid_fixture(), qwen3_vl_config()); + + GenerateResult r = greedy_generate_multimodal_batched(m, ids, feats, deepstack, pos, + /*max_tokens=*/10, /*eos_ids=*/{}); + assert_tokens_equal(r.tokens, load_qwen3_vl_token_ids("greedy_tokens.npy")); +} + +TEST_CASE("Qwen3-VL: a vision row decodes correctly batched next to a text row") { + if (!qwen3_vl_model_available()) { + MESSAGE("Qwen3-VL model not found in HF cache; skipping mixed batched-decode test"); + return; + } + // The real cross-contamination gate: a vision row (M-RoPE offset decoupled from + // its long, image-padded length) and a SHORTER pure-text row share one decode + // batch. The merge right-justifies the text row (ragged left padding) while each + // row keeps its own RoPE offset. 
The vision row must still reproduce the + // single-stream greedy stream despite the heterogeneous neighbor. + const Qwen3VLModel& m = shared_qwen3_vl_model(); + mx::array feats = load_qwen3_vl_npy("vit_out.npy"); + std::vector deepstack = {load_qwen3_vl_npy("deepstack_0.npy"), + load_qwen3_vl_npy("deepstack_1.npy"), + load_qwen3_vl_npy("deepstack_2.npy")}; + std::vector ids = load_qwen3_vl_token_ids("input_ids.npy"); + mx::array pos = mrope_position_ids(ids, grid_fixture(), qwen3_vl_config()); + + // Vision row -> batch-1 cache; a short text prompt (the prompt's trailing text + // tokens, no image) -> its own batch-1 cache; merge into one batch-2 cache. + MultimodalPrefill vl = prefill_multimodal_batched(m, ids, feats, deepstack, pos); + std::vector text_prompt(ids.end() - 6, ids.end()); + PrefillResult text = prefill(m, {text_prompt}); + + BatchKVCache cache = std::move(vl.cache); + cache.merge(text.cache); // batch order: [vision, text] + const int vocab = vl.last_logits.shape()[1]; + + std::vector feed = {argmax_rows(vl.last_logits)[0], + argmax_rows(mx::reshape(text.last_logits, {1, vocab}))[0]}; + std::vector vision_tokens = {feed[0]}; + for (int i = 1; i < 10; ++i) { + mx::array inputs(feed.data(), {2, 1}, mx::int32); + mx::array logits = m.forward(inputs, cache); // (2, 1, vocab), batched path + feed = argmax_rows(mx::reshape(logits, {2, vocab})); + vision_tokens.push_back(feed[0]); + } + assert_tokens_equal(vision_tokens, load_qwen3_vl_token_ids("greedy_tokens.npy")); +} + TEST_CASE("Qwen3-VL: cached KV decode reproduces the greedy tokens") { if (!qwen3_vl_model_available()) { MESSAGE("Qwen3-VL model not found in HF cache; skipping cached-decode test");