21 changes: 16 additions & 5 deletions CLAUDE.md
@@ -153,10 +153,20 @@ reference/.venv/bin/python reference/dump_ref.py
not 3D `(t,h,w)` positions). `Qwen3VLModel` hand-rolls a half-split rotation
with a per-frequency t/h/w selector; text tokens have `t==h==w` so it reduces to
ordinary 1D RoPE (and a generated/decode token is a scalar position one past the
-prompt's max — it jumps over the image's spatial extent). Vision is served
-**single-stream** (`runtime/multimodal_stream`); the continuous-batching worker
-is still text-only (`BatchKVCache` can't yet carry per-row 3D positions). Every
-vision stage is golden-gated against `mlx-vlm` (`reference/fixtures_qwen3_vl/`).
+prompt's max — it jumps over the image's spatial extent). Vision serving is
+**prefill-single, decode-batched** (like vLLM/omlx): the ViT + image-merge +
+3D-M-RoPE prefill runs single-stream (`runtime/multimodal_stream`), then the
+prompt's K/V is adopted into a batch-1 `BatchKVCache`
+(`BatchKVCache::from_single_sequence`) and **merged into the continuous-batching
+decode pool** (`Worker::admit_multimodal`) — a generated VL token is pure text
+(`t==h==w`), so it decodes through the ordinary batched forward alongside text
+rows. No per-row 3D positions are needed in the cache: the batched mask works in
+physical-slot space (`idx`/`left_padding`) while RoPE uses a *separate* per-row
+`offset`, so a VL row just carries `offset = max(3D position)+1` (well below its
+image-padded token count). The prefill itself stays single (the ViT can't batch
+ragged grids). Every vision stage is golden-gated against `mlx-vlm`
+(`reference/fixtures_qwen3_vl/`), and batched decode is gated equal to the
+single-stream stream (`tests/model/qwen3_vl_test.cpp`).
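
The `offset = max(3D position)+1` rule above can be illustrated with a small sketch. It assumes a simplified Qwen-VL-style position layout: text tokens advance all three axes together, while an image with a merged `grid_h x grid_w` grid shares one `t` and spreads `h`/`w` over rows and columns. `Segment`, `build_positions`, and `decode_offset` are hypothetical names for illustration, not identifiers from this repository.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <vector>

// One (t, h, w) position per token. Illustrative type, not the repo's.
using Pos3 = std::array<int, 3>;

// A prompt segment: either n_text text tokens, or one image whose merged
// patch grid is grid_h x grid_w tokens (hypothetical description).
struct Segment {
  bool is_image;
  int n_text;
  int grid_h;
  int grid_w;
};

// Assumed layout: text uses t == h == w; an image starts at the running
// position and consumes max(grid_h, grid_w) positions in total.
std::vector<Pos3> build_positions(const std::vector<Segment>& segs) {
  std::vector<Pos3> out;
  int next = 0;
  for (const Segment& s : segs) {
    if (!s.is_image) {
      for (int i = 0; i < s.n_text; ++i) {
        out.push_back(Pos3{next, next, next});
        ++next;
      }
    } else {
      for (int r = 0; r < s.grid_h; ++r)
        for (int c = 0; c < s.grid_w; ++c)
          out.push_back(Pos3{next, next + r, next + c});
      next += std::max(s.grid_h, s.grid_w);
    }
  }
  return out;
}

// Scalar RoPE position of the first decode token: one past the max position.
int decode_offset(const std::vector<Pos3>& pos) {
  int max_pos = -1;
  for (const Pos3& p : pos) max_pos = std::max({max_pos, p[0], p[1], p[2]});
  return max_pos + 1;
}
```

Under these assumptions, a prompt of 3 text tokens, a 4x4 image, and 2 more text tokens occupies 21 physical slots but its highest position is 8, so the decode offset is 9. That gap is why the offset has to be carried separately from the token count.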

## Where things live

@@ -169,7 +179,8 @@ in `apps/` (`mlxforge` server, `mlxforge-cli`), tests mirror the module path und
`src/model/qwen3_vl.{h,cpp}` is the fused model (image merge + interleaved M-RoPE +
DeepStack + cached decode); `src/vision/` does image decode (`stb_image`) +
preprocess (smart-resize/normalize/patchify); `runtime/multimodal_stream.{h,cpp}` is
-the single-stream image→text path. Selected by `create_model` on a `vision_config`.
+the image→text path (single-stream prefill + batched-decode packaging). Selected by
+`create_model` on a `vision_config`.

**Product-facing surface:** `src/capi/` is the stable `extern "C"` ABI
(`mlxforge.h`) wrapping `runtime/engine` — the public surface — and `bindings/`
6 changes: 6 additions & 0 deletions bindings/node/index.d.ts
@@ -64,6 +64,12 @@ export class Engine {
* be a vision-language checkpoint (e.g. Qwen3-VL).
*/
image(prompt: string, image: Uint8Array, sampling?: SamplingOptions): Stream;
/**
* Stream a vision-language completion over several images (raw encoded bytes
* each), expanded into the prompt in order. The model must be a vision-language
* checkpoint (e.g. Qwen3-VL).
*/
images(prompt: string, images: Uint8Array[], sampling?: SamplingOptions): Stream;
/** Run a chat to completion and return the full string. */
complete(messages: ChatMessage[], sampling?: SamplingOptions): Promise<string>;
/**
9 changes: 9 additions & 0 deletions bindings/node/index.js
@@ -87,6 +87,15 @@ class Engine {
return streamRequest(this._h.submitImage(prompt, imageBuffer, sampling));
}

/**
* Stream a vision-language completion over several images (an array of
* Buffers/Uint8Arrays of raw encoded bytes), expanded into the prompt in
* order. The loaded model must be a vision-language checkpoint (e.g. Qwen3-VL).
*/
images(prompt, imageBuffers, sampling = {}) {
return streamRequest(this._h.submitImages(prompt, imageBuffers, sampling));
}

/** Run a chat to completion and return the full string. */
complete(messages, sampling = {}) {
return collect(this.chat(messages, sampling));
29 changes: 29 additions & 0 deletions bindings/node/src/addon.cc
@@ -220,6 +220,7 @@ class EngineWrap : public Napi::ObjectWrap<EngineWrap> {
InstanceMethod("submitChat", &EngineWrap::SubmitChat),
InstanceMethod("submitText", &EngineWrap::SubmitText),
InstanceMethod("submitImage", &EngineWrap::SubmitImage),
InstanceMethod("submitImages", &EngineWrap::SubmitImages),
InstanceMethod("embed", &EngineWrap::Embed),
InstanceMethod("dispose", &EngineWrap::Dispose),
});
@@ -344,6 +345,34 @@ class EngineWrap : public Napi::ObjectWrap<EngineWrap> {
return finish_submit(env, req, err);
}

Napi::Value SubmitImages(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (!eng_) throw Napi::Error::New(env, "engine is disposed");
if (info.Length() < 2 || !info[0].IsString() || !info[1].IsArray())
throw Napi::TypeError::New(env, "submitImages(prompt, imageBuffers[], sampling?)");
std::string prompt = info[0].As<Napi::String>().Utf8Value();
Napi::Array arr = info[1].As<Napi::Array>();
// Pointers into the JS Buffers; valid for the synchronous submit call (the
// engine copies the bytes before returning). The Buffers stay alive via `arr`.
std::vector<mlxforge_image> images;
images.reserve(arr.Length());
for (uint32_t i = 0; i < arr.Length(); ++i) {
Napi::Value v = arr.Get(i);
if (!v.IsBuffer()) throw Napi::TypeError::New(env, "each image must be a Buffer");
Napi::Buffer<uint8_t> buf = v.As<Napi::Buffer<uint8_t>>();
images.push_back(mlxforge_image{buf.Data(), buf.Length()});
}
std::string schema; // kept alive across the submit call
mlxforge_sampling s = (info.Length() >= 3 && info[2].IsObject())
? parse_sampling(info[2].As<Napi::Object>(), schema)
: mlxforge_sampling{};
if (!schema.empty()) s.json_schema = schema.c_str();
char* err = nullptr;
mlxforge_request* req =
mlxforge_submit_images(eng_, prompt.c_str(), images.data(), images.size(), &s, &err);
return finish_submit(env, req, err);
}
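
The borrowed-pointer contract in the comment above (views stay valid only for the synchronous submit call because the callee copies the bytes) can be sketched on its own. `ImageView` is a hypothetical stand-in that mirrors the `{data, len}` shape of `mlxforge_image`, and `copy_images` is an assumed model of what the engine does on submit, not its actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of a {data, len} image view; it does not own the bytes.
struct ImageView {
  const std::uint8_t* data;
  std::size_t len;
};

// Caller side: build non-owning views over buffers that outlive the call.
std::vector<ImageView> make_views(const std::vector<std::vector<std::uint8_t>>& images) {
  std::vector<ImageView> views;
  views.reserve(images.size());
  for (const auto& img : images) views.push_back(ImageView{img.data(), img.size()});
  return views;
}

// Callee side (assumed behaviour): copy every image before returning, so the
// caller may release or mutate its buffers right after the call.
std::vector<std::vector<std::uint8_t>> copy_images(const ImageView* views, std::size_t n) {
  std::vector<std::vector<std::uint8_t>> owned;
  owned.reserve(n);
  for (std::size_t i = 0; i < n; ++i)
    owned.emplace_back(views[i].data, views[i].data + views[i].len);
  return owned;
}
```

Because the copy happens inside the call, later changes to the caller's buffers do not affect what was submitted. The same reasoning is what lets the bindings pass raw pointers without extending buffer lifetimes.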

Napi::Value Embed(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
auto deferred = Napi::Promise::Deferred::New(env);
46 changes: 46 additions & 0 deletions bindings/rust/src/lib.rs
@@ -30,6 +30,12 @@ struct Msg {
content: *const c_char,
}

#[repr(C)]
struct CImage {
data: *const u8,
len: usize,
}

#[repr(C)]
struct CSampling {
temperature: c_float,
@@ -106,6 +112,14 @@ extern "C" {
sampling: *const CSampling,
err: *mut *mut c_char,
) -> *mut mlxforge_request;
fn mlxforge_submit_images(
e: *mut mlxforge_engine,
prompt: *const c_char,
images: *const CImage,
n_images: usize,
sampling: *const CSampling,
err: *mut *mut c_char,
) -> *mut mlxforge_request;
fn mlxforge_request_next(r: *mut mlxforge_request, text: *mut *mut c_char) -> c_int;
fn mlxforge_request_finish_reason(r: *mut mlxforge_request) -> *const c_char;
fn mlxforge_request_free(r: *mut mlxforge_request);
@@ -291,6 +305,38 @@ impl Engine {
Ok(drain(req))
}

/// Run a vision-language completion over several images: a text `prompt`
/// about `images` (each raw encoded bytes), expanded into the prompt in
/// order. The loaded model must be a vision-language checkpoint (Qwen3-VL).
pub fn images(&self, prompt: &str, images: &[&[u8]], sampling: &Sampling)
-> Result<String, String>
{
let cprompt = CString::new(prompt).map_err(|_| "prompt contains NUL".to_string())?;
// Raw pointers into the borrowed slices; valid for the duration of the
// submit call (the engine copies the bytes synchronously).
let cimgs: Vec<CImage> = images
.iter()
.map(|b| CImage { data: b.as_ptr(), len: b.len() })
.collect();
let mut schema_keep: Option<CString> = None;
let cs = Self::c_sampling(sampling, &mut schema_keep);
let mut err: *mut c_char = ptr::null_mut();
let req = unsafe {
mlxforge_submit_images(
self.handle,
cprompt.as_ptr(),
cimgs.as_ptr(),
cimgs.len(),
&cs,
&mut err,
)
};
if req.is_null() {
return Err(unsafe { take_string(err) }.unwrap_or_else(|| "submit failed".into()));
}
Ok(drain(req))
}

/// Embed text into a unit-normalized vector. `pooling`: 0 = mean, 1 = last.
/// Simple form (no EOS/instruction); for Qwen3-Embedding conventions or to
/// let the model pick its defaults, use [`Engine::embed_with`].
33 changes: 33 additions & 0 deletions bindings/swift/Sources/MLXForge/MLXForge.swift
@@ -126,6 +126,39 @@ public final class Engine {
return Engine.stream(req)
}

/// Stream a vision-language completion over several images (each raw encoded
/// bytes), expanded into the prompt in order. The loaded model must be a
/// vision-language checkpoint (e.g. Qwen3-VL).
public func images(_ prompt: String, _ imagesBytes: [[UInt8]], sampling: Sampling = .greedy)
throws -> AsyncThrowingStream<String, Error>
{
var s = sampling.c
let schemaC = sampling.jsonSchema.map { strdup($0) } ?? nil
if let p = schemaC { s.json_schema = UnsafePointer(p) }
defer { if let p = schemaC { free(p) } }
var err: UnsafeMutablePointer<CChar>?
// Flatten into one contiguous buffer so each mlxforge_image points into stable
// storage for the call (the engine copies the bytes synchronously).
let flat = imagesBytes.flatMap { $0 }
let reqOpt = flat.withUnsafeBufferPointer { (fb) -> OpaquePointer? in
var cImages: [mlxforge_image] = []
var offset = 0
for img in imagesBytes {
cImages.append(mlxforge_image(data: fb.baseAddress.map { $0 + offset }, len: img.count))
offset += img.count
}
return cImages.withUnsafeBufferPointer { cb in
mlxforge_submit_images(handle, prompt, cb.baseAddress, cb.count, &s, &err)
}
}
guard let req = reqOpt else {
let message = err.map { String(cString: $0) } ?? "submit failed"
mlxforge_string_free(err)
throw MLXForgeError(message: message)
}
return Engine.stream(req)
}

/// Run a chat to completion and return the full string.
public func complete(_ messages: [ChatMessage], sampling: Sampling = .greedy) async throws
-> String
30 changes: 30 additions & 0 deletions bindings/swift/Tests/MLXForgeTests/MLXForgeTests.swift
@@ -60,4 +60,34 @@ final class MLXForgeTests: XCTestCase {
XCTAssertEqual(dot(a, a), 1.0, accuracy: 0.02) // unit-normalized
XCTAssertGreaterThan(dot(a, b), dot(a, c)) // semantic ordering
}

// Vision-language: one image via image(), and several via images(). Gated on a
// dedicated env var pointing at a Qwen3-VL checkpoint; reads the committed test
// image from the repo's fixtures.
func testImagesIfVLModelPresent() async throws {
guard let dir = ProcessInfo.processInfo.environment["MLXFORGE_VL_MODEL"], !dir.isEmpty else {
throw XCTSkip("MLXFORGE_VL_MODEL not set; skipping vision test")
}
let engine = try await Engine.load(dir)

var repo = URL(fileURLWithPath: #filePath) // …/bindings/swift/Tests/MLXForgeTests/<file>
for _ in 0..<5 { repo.deleteLastPathComponent() }
let imgURL = repo.appendingPathComponent("reference/fixtures_qwen3_vl/image.png")
let bytes = [UInt8](try Data(contentsOf: imgURL))

var s = Sampling.greedy
s.maxTokens = 8

var single = ""
for try await chunk in try engine.image("What is in this image?", bytes, sampling: s) {
single += chunk
}
XCTAssertFalse(single.isEmpty)

var multi = ""
for try await chunk in try engine.images("Compare these images.", [bytes, bytes], sampling: s) {
multi += chunk
}
XCTAssertFalse(multi.isEmpty)
}
}
1 change: 1 addition & 0 deletions cmake/abi-baseline.txt
@@ -13,5 +13,6 @@ mlxforge_request_next
mlxforge_string_free
mlxforge_submit_chat
mlxforge_submit_image
mlxforge_submit_images
mlxforge_submit_text
mlxforge_version
15 changes: 15 additions & 0 deletions src/cache/batch_kv_cache.cpp
@@ -34,6 +34,21 @@ BatchKVCache::BatchKVCache(int n_layers, const std::vector<int>& left_padding)
left_padding_(mx::array(left_padding.data(), {static_cast<int>(left_padding.size())},
mx::int32)) {}

BatchKVCache BatchKVCache::from_single_sequence(
std::vector<std::pair<mx::array, mx::array>> kv_per_layer, int seq, int decode_offset) {
const int n_layers = static_cast<int>(kv_per_layer.size());
BatchKVCache c(n_layers, std::vector<int>{0}); // batch 1, no left padding
for (int l = 0; l < n_layers; ++l) {
c.keys_[l] = std::move(kv_per_layer[l].first);
c.values_[l] = std::move(kv_per_layer[l].second);
}
c.idx_ = seq; // physical sequence length (drives the attention mask)
int off = decode_offset;
c.offset_ = mx::array(&off, {1}, mx::int32); // decoupled RoPE position
mx::eval(c.offset_);
return c;
}

int BatchKVCache::s_cap() const {
return keys_[0].has_value() ? keys_[0]->shape()[2] : 0;
}
12 changes: 12 additions & 0 deletions src/cache/batch_kv_cache.h
@@ -31,6 +31,18 @@ class BatchKVCache {
// One cache for all layers; `left_padding[i]` is the pad count of batch row i.
BatchKVCache(int n_layers, const std::vector<int>& left_padding);

// Build a batch-1 cache from one already-prefilled sequence's per-layer K/V
// (each (1, n_kv_heads, seq, head_dim)). `decode_offset` seeds the per-row RoPE
// position for the next (decode) token. For an interleaved-M-RoPE prompt
// (Qwen3-VL) this is one past the prompt's max 3D position, which sits BELOW its
// token count (the image collapses many tokens into few positions) — so offset
// is passed explicitly rather than derived from seq, and the standard batched
// decode (whose mask works in physical-slot space while RoPE uses offset) is
// numerically correct unchanged. left_padding is 0 (a fresh single-row prefill
// has no padding). Used to admit a vision prompt into the decode pool (merge()).
static BatchKVCache from_single_sequence(
std::vector<std::pair<mx::array, mx::array>> kv_per_layer, int seq, int decode_offset);
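
The decoupling described in the comment above (the mask is driven by physical slots, RoPE by a per-row offset) can be modelled without any tensors. The sketch assumes that merging two caches right-aligns rows by adding left padding to the shorter one; the real `merge()` may differ, and `ToyRow`, `ToyBatch`, and the `toy_*` functions are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <initializer_list>
#include <vector>

// Per-row bookkeeping only; no K/V tensors in this toy model.
struct ToyRow {
  int left_padding;  // masked slots at the front of the row
  int offset;        // RoPE position of the next decode token
};

struct ToyBatch {
  int idx = 0;  // shared physical sequence length
  std::vector<ToyRow> rows;
};

// Batch-1 cache from one prefilled sequence: offset is passed explicitly.
ToyBatch toy_from_single(int seq, int decode_offset) {
  ToyBatch b;
  b.idx = seq;
  b.rows.push_back(ToyRow{0, decode_offset});
  return b;
}

// Assumed merge rule: right-align rows, padding the shorter batch on the left.
ToyBatch toy_merge(const ToyBatch& a, const ToyBatch& b) {
  ToyBatch out;
  out.idx = std::max(a.idx, b.idx);
  for (const ToyBatch* src : {&a, &b}) {
    const int shift = out.idx - src->idx;
    for (const ToyRow& r : src->rows)
      out.rows.push_back(ToyRow{r.left_padding + shift, r.offset});
  }
  return out;
}

// One decode step appends a slot and advances every row's RoPE position.
void toy_step(ToyBatch& b) {
  ++b.idx;
  for (ToyRow& r : b.rows) ++r.offset;
}

// Number of slots a row may attend to: [left_padding, idx).
int visible_slots(const ToyBatch& b, int row) {
  return b.idx - b.rows[row].left_padding;
}
```

With a 10-token text row and a vision row of 21 physical tokens whose decode offset is 9, the merged batch has `idx == 21`; the text row gains 11 slots of padding, and both rows keep their own offsets. The vision row attends over all 21 slots while rotating its next token at position 9.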

int batch_size() const { return batch_; }
int idx() const { return idx_; } // populated sequence length (_idx)
// Allocated capacity along the sequence axis (0 before the first write).
10 changes: 10 additions & 0 deletions src/cache/kv_cache.h
@@ -24,6 +24,16 @@ class KVCache {
int offset() const { return offset_; }
void advance(int n_tokens) { offset_ += n_tokens; }

int n_layers() const { return static_cast<int>(keys_.size()); }

// Stored K/V for a layer (each (1, n_kv_heads, offset, head_dim), no capacity
// padding). Valid only after the layer has been written. Lets a prefilled
// single sequence be handed to the batched cache for continuous-batching decode
// (BatchKVCache::from_single_sequence).
std::pair<mx::array, mx::array> fetch(int layer) const {
return {*keys_[layer], *values_[layer]};
}

// Append this layer's K/V (each (1, n_kv_heads, L, head_dim)) along the
// sequence axis and return the full cached (keys, values) to attend over.
std::pair<mx::array, mx::array> update_and_fetch(int layer, const mx::array& k,