From 16d6003d5ff8736c140771fc11b15b993e610e28 Mon Sep 17 00:00:00 2001 From: Helder Vasconcelos Date: Wed, 10 Jun 2026 14:03:11 +0100 Subject: [PATCH] Add per-token logprobs output (OpenAI logprobs/top_logprobs) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose per-token log-probabilities — each chosen token's logprob plus its top-N alternatives — through the OpenAI HTTP server, the CLI, and the C ABI + Node/Swift/Rust bindings. The engine already computed the chosen-token logprob and discarded it; this surfaces it with full OpenAI parity. Engine: - Sampler computes logprobs from log_softmax(adj) (penalized, pre-temperature, pre-filter): the chosen-token logprob plus, when requested, the descending top-K (id, logprob). Pure graph ops, no extra eval. - A TokenLogprob travels worker->consumer on a parallel SPSC LogprobQueue, pushed in lockstep with each emitted token (EOS, never emitted, has none). The worker collects per-row logprob arrays only for requesting rows and folds them into the single async_eval per decode step. Surfaces: - OpenAI server: parse logprobs/top_logprobs; serialize choices[].logprobs.content in both blocking and streaming responses. Vision requests opt out. - CLI: generate --logprobs [N] prints per-token logprobs to stderr. - C ABI v4: mlxforge_sampling.logprobs + mlxforge_request_logprobs (append-only, baseline updated). Node/Rust/Swift gain post-completion logprobs accessors. Tests: sampler ordering/normalization, single-stream alignment, server parse + JSON shape, and a C-ABI round-trip. All 200 ctest pass; check-abi.sh green. Co-Authored-By: Claude Opus 4.8 --- apps/mlxforge_cli.cpp | 65 ++++++++++++-- bindings/node/index.d.ts | 23 +++++ bindings/node/index.js | 11 +++ bindings/node/src/addon.cc | 17 ++++ bindings/rust/src/lib.rs | 64 +++++++++++++ .../swift/Sources/MLXForge/MLXForge.swift | 36 +++++++- cmake/abi-baseline.txt | 1 + src/capi/mlxforge.cpp | 50 +++++++++++ src/capi/mlxforge.h | 24 ++++- src/runtime/single_stream.cpp | 52 +++++++++-- src/runtime/single_stream.h | 10 ++- src/runtime/worker.cpp | 81 ++++++++++++++--- src/runtime/worker.h | 22 ++++- src/sample/sampler.cpp | 61 ++++++++++--- src/sample/sampler.h | 10 +++ src/scheduler/request.h | 46 +++++++--- src/server/http_server.cpp | 36 +++++++- src/server/openai.cpp | 54 +++++++++-- src/server/openai.h | 25 +++++- tests/capi/capi_test.cpp | 63 +++++++++++++ tests/runtime/single_stream_test.cpp | 30 +++++++ tests/sample/sampler_test.cpp | 61 +++++++++++++ tests/server/openai_test.cpp | 90 +++++++++++++++++++ 23 files changed, 860 insertions(+), 72 deletions(-) diff --git a/apps/mlxforge_cli.cpp b/apps/mlxforge_cli.cpp index 08d7d5a..87694d4 100644 --- a/apps/mlxforge_cli.cpp +++ b/apps/mlxforge_cli.cpp @@ -6,10 +6,11 @@ // mlxforge-cli dump-weights // - Loads a model's weights from the supplied directory, prints key/shape/dtype for each tensor, // asserts that all tensors are fp16, and reports the peak resident memory used. -// mlxforge-cli generate [max_tokens] +// mlxforge-cli generate [max_tokens] [--logprobs [N]] // - Runs greedy single-stream generation: pre-fills the prompt (as raw text using the chat template // or as a pre-tokenized .npy of ids), then streams the detokenized text to stdout until EOS or -// max_tokens. +// max_tokens. With --logprobs [N], each emitted token's log-prob (and its N most-likely +// alternatives) is printed to stderr after generation; stdout stays the generated text. // mlxforge-cli bench [max_tokens] [runs] // - Repeatable throughput benchmark over a fixed prompt: one discarded warmup run, then `runs` // timed runs (defaults: max_tokens=128, runs=3) reporting time-to-first-token and decode tok/s. @@ -22,6 +23,7 @@ // which will be downloaded on first use. #include +#include #include #include #include @@ -162,8 +164,24 @@ int run_dump_weights(const std::string& spec) { return non_fp16 == 0 ? 0 : 1; } +// Render a token's text for the logprobs dump: quoted, with the common control +// characters escaped so whitespace tokens stay legible on one line. +std::string show_token(const std::string& s) { + std::string out = "\""; + for (char c : s) { + if (c == '\n') out += "\\n"; + else if (c == '\t') out += "\\t"; + else if (c == '\r') out += "\\r"; + else out += c; + } + return out + "\""; +} + // Performs generation using a loaded model, with either raw text or pre-tokenized prompts. -int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens) { +// `top_logprobs` mirrors the engine knob: -1 = off; 0 = each token's own log-prob; +// N > 0 = also its N most-likely alternatives (printed to stderr after generation). +int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens, + int top_logprobs = -1) { // Resolve and load the model (GGUF file or safetensors dir; downloads if needed) LoadedModel lm = load_for_inference(spec); mlxforge::DecoderModel& model = *lm.model; @@ -194,13 +212,29 @@ int run_generate(const std::string& spec, const std::string& prompt_arg, int max std::string piece = detok.add(id); std::fwrite(piece.data(), 1, piece.size(), stdout); std::fflush(stdout); - }); + }, top_logprobs); // Output any final detokenized tail remaining in the streaming detokenizer std::string tail = detok.finish(); std::fwrite(tail.data(), 1, tail.size(), stdout); std::fputc('\n', stdout); + // Per-token log-probs go to stderr (stdout stays the generated text). One line + // per emitted token: its text + log-prob, then any requested alternatives. + if (top_logprobs >= 0) { + std::fprintf(stderr, "logprobs (%zu tokens):\n", r.token_logprobs.size()); + for (const mlxforge::TokenLogprob& lp : r.token_logprobs) { + std::fprintf(stderr, " %-16s logprob=%8.4f", show_token(tok.decode({lp.id})).c_str(), + lp.logprob); + if (!lp.top.empty()) { + std::fprintf(stderr, " top:"); + for (const auto& alt : lp.top) + std::fprintf(stderr, " %s=%.4f", show_token(tok.decode({alt.first})).c_str(), alt.second); + } + std::fputc('\n', stderr); + } + } + // Log some generation statistics mlxforge::log::info("generated {} tokens{}", r.tokens.size(), r.hit_eos ? " (stopped at EOS)" : ""); @@ -334,12 +368,27 @@ int main(int argc, char** argv) { if (cmd == "generate") { // Print usage if not enough arguments; otherwise run generation logic if (argc < 4) { - std::fprintf(stderr, "usage: mlxforge-cli generate [max_tokens]\n"); + std::fprintf(stderr, + "usage: mlxforge-cli generate [max_tokens] " + "[--logprobs [N]]\n"); return 2; } - // Parse max_tokens if provided, otherwise default to 64 - const int max_tokens = argc >= 5 ? std::stoi(argv[4]) : 64; - return run_generate(argv[2], argv[3], max_tokens); + // Positional [max_tokens] (default 64) plus an optional --logprobs [N] flag (N + // alternatives, default 0 = the chosen token's own log-prob only). + int max_tokens = 64; + int top_logprobs = -1; + for (int i = 4; i < argc; ++i) { + const std::string a = argv[i]; + if (a == "--logprobs") { + if (i + 1 < argc && std::isdigit(static_cast(argv[i + 1][0]))) + top_logprobs = std::stoi(argv[++i]); + else + top_logprobs = 0; + } else { + max_tokens = std::stoi(a); + } + } + return run_generate(argv[2], argv[3], max_tokens, top_logprobs); } if (cmd == "image") { // Vision-language generation: describe / answer about an image. diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts index fe8982f..f35567d 100644 --- a/bindings/node/index.d.ts +++ b/bindings/node/index.d.ts @@ -24,6 +24,12 @@ export interface SamplingOptions { seed?: number; /** Max new tokens (default 64). */ maxTokens?: number; + /** + * OpenAI logprobs. 0 (default) => off. N > 0 => report each emitted token's + * own log-prob plus its (N - 1) most-likely alternatives (so 1 = chosen-only). + * Retrieve them from the stream's {@link Stream.logprobs} after consumption. + */ + logprobs?: number; /** * Constrained decoding. "json" forces any valid JSON value; otherwise a * JSON-Schema string (supported subset: a top-level object with ordered, @@ -39,12 +45,29 @@ export interface ChatMessage { content: string; } +/** One token's log-probability (OpenAI logprobs `content` entry). */ +export interface TokenLogprob { + /** The token's decoded text. */ + token: string; + /** Natural-log probability of the token (<= 0). */ + logprob: number; + /** The token's raw UTF-8 bytes. */ + bytes: number[]; + /** The most-likely alternatives at this position (may be empty). */ + top_logprobs: { token: string; logprob: number; bytes: number[] }[]; +} + /** An async-iterable stream of decoded text chunks for one request. */ export interface Stream extends AsyncIterable { /** Ask the engine to stop generating this request. */ cancel(): void; /** "stop" | "length" | "cancel" | "" (while running). */ readonly finishReason: string; + /** + * Per-token log-probs, available after the stream has been fully consumed. + * Returns null when `logprobs` was not requested (or none were produced). + */ + logprobs(): TokenLogprob[] | null; } export class Engine { diff --git a/bindings/node/index.js b/bindings/node/index.js index 87c3b67..02954f5 100644 --- a/bindings/node/index.js +++ b/bindings/node/index.js @@ -8,6 +8,9 @@ const native = require('./build/Release/mlxforge_node.node'); // Wrap a native Request as an async-iterable of decoded text chunks. The // request is disposed when iteration finishes (or the consumer breaks out). function streamRequest(req) { + // Captured once the stream is fully drained (the native handle is disposed + // right after, so we read logprobs while it is still valid). + let logprobs = null; return { // The underlying native handle, in case callers want cancel()/finishReason(). request: req, @@ -17,10 +20,18 @@ function streamRequest(req) { get finishReason() { return req.finishReason(); }, + /** + * Per-token log-probs (OpenAI logprobs `content` shape), or null when none + * were requested. Available after the stream has been fully consumed. + */ + logprobs() { + return logprobs; + }, async *[Symbol.asyncIterator]() { try { let chunk; while ((chunk = await req.next()) !== null) yield chunk; + logprobs = req.logprobs(); // capture before dispose() invalidates the handle } finally { req.dispose(); } diff --git a/bindings/node/src/addon.cc b/bindings/node/src/addon.cc index f6f498c..f127e67 100644 --- a/bindings/node/src/addon.cc +++ b/bindings/node/src/addon.cc @@ -32,6 +32,9 @@ mlxforge_sampling parse_sampling(const Napi::Object& o, std::string& schema_out) s.presence_penalty = static_cast(num("presencePenalty", 0.0)); s.seed = static_cast(num("seed", 0.0)); s.max_tokens = static_cast(num("maxTokens", 0.0)); + // OpenAI logprobs: 0 = off; N > 0 = the chosen token's logprob plus (N - 1) + // alternatives (so 1 = chosen-only). Retrieved via Request.logprobs(). + s.logprobs = static_cast(num("logprobs", 0.0)); if (o.Has("jsonSchema") && o.Get("jsonSchema").IsString()) schema_out = o.Get("jsonSchema").As().Utf8Value(); else if (o.Has("responseFormat") && o.Get("responseFormat").IsString()) @@ -53,6 +56,7 @@ class RequestWrap : public Napi::ObjectWrap { InstanceMethod("next", &RequestWrap::Next), InstanceMethod("cancel", &RequestWrap::Cancel), InstanceMethod("finishReason", &RequestWrap::FinishReason), + InstanceMethod("logprobs", &RequestWrap::Logprobs), InstanceMethod("dispose", &RequestWrap::Dispose), }); g_request_ctor = Napi::Persistent(f); @@ -93,6 +97,19 @@ class RequestWrap : public Napi::ObjectWrap { return Napi::String::New(info.Env(), req_ ? mlxforge_request_finish_reason(req_) : ""); } + // The accumulated per-token log-probs, parsed from the C ABI's OpenAI-shaped + // JSON (or null when none / not requested). Call after the stream is drained, + // before dispose(). + Napi::Value Logprobs(const Napi::CallbackInfo& info) { + Napi::Env env = info.Env(); + char* j = req_ ? mlxforge_request_logprobs(req_) : nullptr; + if (!j) return env.Null(); + std::string s(j); + mlxforge_string_free(j); + Napi::Object json = env.Global().Get("JSON").As(); + return json.Get("parse").As().Call(json, {Napi::String::New(env, s)}); + } + void Dispose(const Napi::CallbackInfo&) { free_req(); } }; diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 53e5e1f..73c80cb 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -48,6 +48,7 @@ struct CSampling { seed: u64, max_tokens: c_int, json_schema: *const c_char, + logprobs: c_int, } // Mirrors mlxforge_embed_opts. pooling/add_eos are tri-state: -1 = model default. @@ -122,6 +123,7 @@ extern "C" { ) -> *mut mlxforge_request; fn mlxforge_request_next(r: *mut mlxforge_request, text: *mut *mut c_char) -> c_int; fn mlxforge_request_finish_reason(r: *mut mlxforge_request) -> *const c_char; + fn mlxforge_request_logprobs(r: *mut mlxforge_request) -> *mut c_char; fn mlxforge_request_free(r: *mut mlxforge_request); } @@ -159,6 +161,9 @@ pub struct Sampling { pub max_tokens: i32, /// Constrained decoding: "json" or a JSON-Schema string (see the C ABI docs). pub json_schema: Option, + /// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1) + /// alternatives (so 1 = chosen-only). Retrieve via [`Engine::chat_with_logprobs`]. + pub logprobs: i32, } impl Sampling { @@ -234,6 +239,7 @@ impl Engine { seed: s.seed, max_tokens: s.max_tokens, json_schema, + logprobs: s.logprobs, } } @@ -268,6 +274,42 @@ impl Engine { Ok(drain(req)) } + /// Run a chat completion and return both the text and the per-token log-probs + /// (the C ABI's OpenAI-shaped JSON `content` array, or `None` when `logprobs` + /// was not set on `sampling`). + pub fn chat_with_logprobs( + &self, + messages: &[(&str, &str)], + sampling: &Sampling, + ) -> Result<(String, Option), String> { + let owned: Vec<(CString, CString)> = messages + .iter() + .map(|(r, c)| { + ( + CString::new(*r).unwrap_or_default(), + CString::new(*c).unwrap_or_default(), + ) + }) + .collect(); + let msgs: Vec = owned + .iter() + .map(|(r, c)| Msg { + role: r.as_ptr(), + content: c.as_ptr(), + }) + .collect(); + + let mut schema_keep: Option = None; + let cs = Self::c_sampling(sampling, &mut schema_keep); + let mut err: *mut c_char = ptr::null_mut(); + let req = + unsafe { mlxforge_submit_chat(self.handle, msgs.as_ptr(), msgs.len(), &cs, &mut err) }; + if req.is_null() { + return Err(unsafe { take_string(err) }.unwrap_or_else(|| "submit failed".into())); + } + Ok(drain_with_logprobs(req)) + } + /// Run a raw-text completion (no chat template) to completion. pub fn text(&self, prompt: &str, sampling: &Sampling) -> Result { let cprompt = CString::new(prompt).map_err(|_| "prompt contains NUL".to_string())?; @@ -436,3 +478,25 @@ fn drain(req: *mut mlxforge_request) -> String { unsafe { mlxforge_request_free(req) }; out } + +// Drain a request to text and then read its accumulated per-token log-probs +// (the OpenAI-shaped JSON `content` array, or None when none were produced), +// freeing the request. +fn drain_with_logprobs(req: *mut mlxforge_request) -> (String, Option) { + let mut out = String::new(); + loop { + let mut text: *mut c_char = ptr::null_mut(); + let rc = unsafe { mlxforge_request_next(req, &mut text) }; + if rc == 0 { + if let Some(s) = unsafe { take_string(text) } { + out.push_str(&s); + } + } else { + break; + } + } + // Read logprobs while the request is still alive, then free it. + let logprobs = unsafe { take_string(mlxforge_request_logprobs(req)) }; + unsafe { mlxforge_request_free(req) }; + (out, logprobs) +} diff --git a/bindings/swift/Sources/MLXForge/MLXForge.swift b/bindings/swift/Sources/MLXForge/MLXForge.swift index a09eeea..2f4830e 100644 --- a/bindings/swift/Sources/MLXForge/MLXForge.swift +++ b/bindings/swift/Sources/MLXForge/MLXForge.swift @@ -23,6 +23,9 @@ public struct Sampling { /// JSON-Schema string (supported subset: a top-level object with ordered, /// required, scalar-typed properties). Output is masked to be well-formed JSON. public var jsonSchema: String? = nil + /// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1) + /// alternatives (so 1 = chosen-only). Retrieved via `Engine.completeWithLogprobs`. + public var logprobs: Int32 = 0 public init() {} public static var greedy: Sampling { Sampling() } @@ -34,7 +37,7 @@ public struct Sampling { temperature: temperature, top_k: topK, top_p: topP, min_p: minP, repetition_penalty: repetitionPenalty, frequency_penalty: frequencyPenalty, presence_penalty: presencePenalty, seed: seed, max_tokens: maxTokens, - json_schema: nil) + json_schema: nil, logprobs: logprobs) } } @@ -168,6 +171,37 @@ public final class Engine { return out } + /// Run a chat to completion and return both the text and the per-token log-probs + /// (the C ABI's OpenAI-shaped JSON `content` string, or nil when `logprobs` was + /// not set on `sampling`). Logprobs are read once the stream is fully drained. + public func completeWithLogprobs(_ messages: [ChatMessage], sampling: Sampling = .greedy) + async throws -> (text: String, logprobs: String?) + { + let req = try submitChat(messages, sampling) + return try await withCheckedThrowingContinuation { cont in + DispatchQueue.global(qos: .userInitiated).async { + var out = "" + while true { + var text: UnsafeMutablePointer? + let rc = mlxforge_request_next(req, &text) + if rc == 0, let t = text { + out += String(cString: t) + mlxforge_string_free(t) + } else { + break + } + } + var logprobs: String? = nil + if let lp = mlxforge_request_logprobs(req) { + logprobs = String(cString: lp) + mlxforge_string_free(lp) + } + mlxforge_request_free(req) + cont.resume(returning: (out, logprobs)) + } + } + } + /// Embed text into a (by default unit-normalized) vector. /// /// With all defaults the model self-selects its convention: a Qwen3-Embedding diff --git a/cmake/abi-baseline.txt b/cmake/abi-baseline.txt index aa934c5..cc2cf02 100644 --- a/cmake/abi-baseline.txt +++ b/cmake/abi-baseline.txt @@ -9,6 +9,7 @@ mlxforge_floats_free mlxforge_request_cancel mlxforge_request_finish_reason mlxforge_request_free +mlxforge_request_logprobs mlxforge_request_next mlxforge_string_free mlxforge_submit_chat diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp index f96f53e..712a026 100644 --- a/src/capi/mlxforge.cpp +++ b/src/capi/mlxforge.cpp @@ -15,6 +15,8 @@ #include #include +#include + #include "core/config.h" #include "runtime/engine.h" #include "scheduler/request.h" @@ -61,6 +63,9 @@ mlxforge::SamplingParams to_params(const mlxforge_sampling* s) { p.frequency_penalty = std::isfinite(s->frequency_penalty) ? s->frequency_penalty : 0.0f; p.presence_penalty = std::isfinite(s->presence_penalty) ? s->presence_penalty : 0.0f; p.seed = s->seed; + // logprobs: 0 => off (-1); N > 0 => the chosen token's logprob plus (N - 1) + // alternatives, so the engine's top_logprobs (alternatives count) is N - 1. + p.top_logprobs = s->logprobs > 0 ? s->logprobs - 1 : -1; return p; } @@ -68,6 +73,14 @@ int sampling_max_tokens(const mlxforge_sampling* s) { return (s && s->max_tokens > 0) ? s->max_tokens : 64; } +// One {token, logprob, bytes} entry: decode `id` to its text and raw UTF-8 bytes. +nlohmann::json lp_entry(int id, float logprob, const mlxforge::Tokenizer& tok) { + const std::string text = tok.decode({id}); + nlohmann::json bytes = nlohmann::json::array(); + for (unsigned char c : text) bytes.push_back(static_cast(c)); + return {{"token", text}, {"logprob", logprob}, {"bytes", std::move(bytes)}}; +} + } // namespace // Opaque handles. The engine wrapper owns the C++ Engine. The request wrapper @@ -83,6 +96,12 @@ struct mlxforge_request { std::shared_ptr req; std::unique_ptr detok; bool done = false; + // OpenAI logprobs: when `want_logprobs`, each token drained by + // mlxforge_request_next pops its log-prob (in lockstep) into `logprobs`; + // mlxforge_request_logprobs serializes the accumulated list using `tok`. + bool want_logprobs = false; + const mlxforge::Tokenizer* tok = nullptr; + std::vector logprobs; }; extern "C" { @@ -231,6 +250,8 @@ mlxforge_request* submit_ids(mlxforge_engine* engine, std::vector prompt_id handle->req = req; handle->detok = std::make_unique(engine->engine->tokenizer()); + handle->tok = &engine->engine->tokenizer(); + handle->want_logprobs = req->params.top_logprobs >= 0; return handle.release(); } @@ -370,6 +391,12 @@ int mlxforge_request_next(mlxforge_request* req, char** text) { for (;;) { int tok = 0; if (req->req->tokens.pop(tok)) { + // Pop this token's log-prob in lockstep (the worker pushes one per emitted + // token) and accumulate it for mlxforge_request_logprobs. + if (req->want_logprobs) { + mlxforge::TokenLogprob lp; + if (req->req->logprobs.pop(lp)) req->logprobs.push_back(std::move(lp)); + } std::string piece = req->detok->add(tok); if (piece.empty()) continue; if (text) *text = dup_cstr(piece); @@ -399,6 +426,23 @@ const char* mlxforge_request_finish_reason(mlxforge_request* req) { return (req && req->req) ? req->req->finish_reason.c_str() : ""; } +char* mlxforge_request_logprobs(mlxforge_request* req) { + if (!req || !req->tok || req->logprobs.empty()) return nullptr; + try { + nlohmann::json content = nlohmann::json::array(); + for (const mlxforge::TokenLogprob& lp : req->logprobs) { + nlohmann::json entry = lp_entry(lp.id, lp.logprob, *req->tok); + nlohmann::json top = nlohmann::json::array(); + for (const auto& alt : lp.top) top.push_back(lp_entry(alt.first, alt.second, *req->tok)); + entry["top_logprobs"] = std::move(top); + content.push_back(std::move(entry)); + } + return dup_cstr(content.dump()); + } catch (...) { + return nullptr; + } +} + void mlxforge_request_free(mlxforge_request* req) { if (!req) return; try { @@ -408,6 +452,12 @@ void mlxforge_request_free(mlxforge_request* req) { req->req->cancelled.store(true); int tok = 0; while (req->req->tokens.pop(tok)) { + // Drain log-probs in lockstep too, so the worker's producer never blocks + // on a full, abandoned log-prob queue. + if (req->want_logprobs) { + mlxforge::TokenLogprob lp; + req->req->logprobs.pop(lp); + } } } } catch (...) { diff --git a/src/capi/mlxforge.h b/src/capi/mlxforge.h index 4bc14dd..d09ccbf 100644 --- a/src/capi/mlxforge.h +++ b/src/capi/mlxforge.h @@ -37,8 +37,12 @@ extern "C" { * last-token pooling, trailing EOS, instruction prefix). * v3: mlxforge_submit_image (Qwen3-VL vision-language: a prompt + one image, * served single-stream). - * v4: mlxforge_image + mlxforge_submit_images (a prompt + N images). */ -#define MLXFORGE_ABI_VERSION 4 + * v4: mlxforge_image + mlxforge_submit_images (a prompt + N images). + * v5: mlxforge_sampling.logprobs + mlxforge_request_logprobs (OpenAI per-token + * log-probabilities: the chosen token's logprob and its top-N alternatives, + * accumulated as the request is drained and returned as an OpenAI-shaped + * JSON array). */ +#define MLXFORGE_ABI_VERSION 5 typedef struct mlxforge_engine mlxforge_engine; typedef struct mlxforge_request mlxforge_request; @@ -72,6 +76,12 @@ typedef struct { * object with ordered, required, scalar-typed properties). The output is * masked so it can only be well-formed JSON. */ const char* json_schema; + /* Per-token log-probabilities (OpenAI logprobs). 0 (the zero-init default) => + * off. N > 0 => report each emitted token's own log-prob plus its (N - 1) most + * likely alternatives, so 1 = the chosen token's log-prob only. Retrieve them + * with mlxforge_request_logprobs() once the request is drained. Not supported + * on vision (mlxforge_submit_image) requests, which ignore it. */ + int logprobs; } mlxforge_sampling; /* ---- Library info ---------------------------------------------------------*/ @@ -217,6 +227,16 @@ void mlxforge_request_cancel(mlxforge_request* req); * 1. Owned by the request. */ const char* mlxforge_request_finish_reason(mlxforge_request* req); +/* Per-token log-probabilities for the tokens emitted so far, as a newly-allocated + * JSON array in the OpenAI logprobs `content` shape: + * [{ "token": "...", "logprob": , "bytes": [...], + * "top_logprobs": [{ "token", "logprob", "bytes" }, ...] }, ...] + * Log-probs accumulate as mlxforge_request_next() drains the stream, so call this + * once the request is done (next() returned 1). Returns NULL when logprobs were + * not requested (sampling.logprobs == 0) or none were produced. The caller owns + * the string and frees it with mlxforge_string_free(). */ +char* mlxforge_request_logprobs(mlxforge_request* req); + /* Destroy a request. If it is still running it is cancelled and drained first * (so the worker never blocks on a full token queue). NULL is ignored. */ void mlxforge_request_free(mlxforge_request* req); diff --git a/src/runtime/single_stream.cpp b/src/runtime/single_stream.cpp index 14c35fe..eab2205 100644 --- a/src/runtime/single_stream.cpp +++ b/src/runtime/single_stream.cpp @@ -7,25 +7,55 @@ #include "sample/sampler.h" #include "mlx/ops.h" +#include "mlx/random.h" #include "mlx/transforms.h" namespace mlxforge { namespace { -// Greedy token id from the last position of (1, L, vocab) logits. -int greedy_last(const mx::array& logits) { +std::vector to_ints(const mx::array& a) { + mx::array c = mx::contiguous(mx::astype(a, mx::int32)); + mx::eval(c); + return std::vector(c.data(), c.data() + c.size()); +} +std::vector to_floats(const mx::array& a) { + mx::array c = mx::contiguous(mx::astype(a, mx::float32)); + mx::eval(c); + return std::vector(c.data(), c.data() + c.size()); +} + +// Greedy token id from the last position of (1, L, vocab) logits. When `lp` is +// non-null, also fill it with the chosen token's log-prob and (per `top_logprobs`) +// its top-K alternatives, reusing the sampler's greedy + logprob path. +int greedy_last(const mx::array& logits, int top_logprobs, TokenLogprob* lp) { const int L = logits.shape()[1]; const int V = logits.shape()[2]; mx::array last = mx::reshape(mx::slice(logits, {0, L - 1, 0}, {1, L, V}), {1, V}); - mx::array tok = Sampler::greedy(last); - mx::eval(tok); - return tok.item(); + if (!lp) { + mx::array tok = Sampler::greedy(last); + mx::eval(tok); + return tok.item(); + } + SamplingParams p; + p.temperature = 0.0f; // greedy + p.top_logprobs = top_logprobs; + SampleResult res = Sampler::sample(last, p, mx::random::key(0)); + // Read through to_ints/to_floats (which astype first): the chosen logprob is in + // the model's fp16, so a raw item() would misread the 2-byte value. + const int id = to_ints(res.tokens)[0]; + lp->id = id; + lp->logprob = to_floats(res.logprobs)[0]; + lp->top.clear(); + const std::vector tids = to_ints(res.top_tokens); + const std::vector tlp = to_floats(res.top_logprobs); + for (size_t k = 0; k < tids.size(); ++k) lp->top.emplace_back(tids[k], tlp[k]); + return id; } } // namespace GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids, int max_tokens, const std::vector& eos_ids, - const std::function& on_token) { + const std::function& on_token, int top_logprobs) { auto is_eos = [&](int id) { return std::find(eos_ids.begin(), eos_ids.end(), id) != eos_ids.end(); }; @@ -39,10 +69,15 @@ GenerateResult greedy_generate(const DecoderModel& model, const std::vector KVCache cache(model.config().n_layers); mx::array prompt(prompt_ids.data(), {1, static_cast(prompt_ids.size())}, mx::int32); + // `next_lp` carries the log-prob record for `next`, pushed when (and only when) + // `next` is actually emitted (EOS is excluded, like its token id). + const bool want_lp = top_logprobs >= 0; + TokenLogprob next_lp; + // Prefill + first sample: time to first token. greedy_last() eval()s, so the // elapsed wall time covers the actual GPU work, not just graph construction. const auto t_start = Clock::now(); - int next = greedy_last(model.forward(prompt, &cache)); + int next = greedy_last(model.forward(prompt, &cache), top_logprobs, want_lp ? &next_lp : nullptr); result.ttft_ms = ms_since(t_start); const auto t_decode_start = Clock::now(); @@ -52,10 +87,11 @@ GenerateResult greedy_generate(const DecoderModel& model, const std::vector break; } result.tokens.push_back(next); + if (want_lp) result.token_logprobs.push_back(next_lp); if (on_token) on_token(next); mx::array step(&next, {1, 1}, mx::int32); - next = greedy_last(model.forward(step, &cache)); + next = greedy_last(model.forward(step, &cache), top_logprobs, want_lp ? &next_lp : nullptr); } // The first token came from prefill; throughput here is the steady-state // decode rate over the tokens generated after it. diff --git a/src/runtime/single_stream.h b/src/runtime/single_stream.h index 571c06b..3000255 100644 --- a/src/runtime/single_stream.h +++ b/src/runtime/single_stream.h @@ -8,11 +8,15 @@ #include #include "model/decoder_model.h" +#include "scheduler/request.h" // TokenLogprob namespace mlxforge { struct GenerateResult { std::vector tokens; // generated token ids (EOS excluded) + // Per-token log-probs (OpenAI logprobs), aligned with `tokens`; populated only + // when greedy_generate is asked for them (top_logprobs >= 0), else empty. + std::vector token_logprobs; bool hit_eos = false; // stopped because an EOS token was produced double ttft_ms = 0.0; // prefill + first sample (time to first token) double decode_ms = 0.0; // wall time spent generating tokens after the first @@ -26,8 +30,12 @@ struct GenerateResult { // Greedy (argmax) single-stream generation. Calls `on_token(id)` for each // emitted token. Stops when an EOS token would be produced or after max_tokens. +// `top_logprobs` mirrors SamplingParams: -1 = off (no logprob work); 0 = record +// each emitted token's own log-prob into result.token_logprobs; N > 0 = also +// record its N most-likely alternatives. GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids, int max_tokens, const std::vector& eos_ids, - const std::function& on_token = {}); + const std::function& on_token = {}, + int top_logprobs = -1); } // namespace mlxforge diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp index 143debe..8340d39 100644 --- a/src/runtime/worker.cpp +++ b/src/runtime/worker.cpp @@ -28,18 +28,27 @@ std::vector read_ids(const mx::array& a) { return std::vector(c.data(), c.data() + c.size()); } +std::vector read_floats(const mx::array& a) { + mx::array c = mx::contiguous(mx::astype(a, mx::float32)); + mx::eval(c); + return std::vector(c.data(), c.data() + c.size()); +} + bool is_eos(const Request& req, int id) { return std::find(req.eos_ids.begin(), req.eos_ids.end(), id) != req.eos_ids.end(); } // Emit token `id` for a request, returning true if the request is now finished. -bool consume(Request& req, int& produced, int id) { +// `lp` (when non-null) is the token's log-prob record, pushed in lockstep with +// the token; EOS is never emitted, so it carries no logprob. +bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) { if (is_eos(req, id)) { req.finish_reason = "stop"; return true; } if (produced == 0) req.first_token_time = Request::Clock::now(); // TTFT marker req.tokens.push(id); + if (lp) req.logprobs.push(*lp); if (++produced >= req.max_tokens) { req.finish_reason = "length"; return true; @@ -64,6 +73,7 @@ void Worker::handle_embedding(Request& req) { req.finish_reason = "error"; } req.tokens.close(); // unblock the waiting submitter + req.logprobs.close(); } void Worker::admit_multimodal(const std::shared_ptr& req) { @@ -111,6 +121,7 @@ void Worker::admit_multimodal(const std::shared_ptr& req) { log::error("worker: multimodal admit error: {}", e.what()); req->finish_reason = "error"; req->tokens.close(); + req->logprobs.close(); } } @@ -270,27 +281,31 @@ void Worker::register_rows(const std::vector>& incoming } } - std::vector first = - read_ids(sample_rows(last_logits, base, static_cast(incoming.size()))); + const int count = static_cast(incoming.size()); + std::vector first; + std::vector row_lp; + std::vector has_lp; + finalize_sample(sample_rows(last_logits, base, count), count, first, row_lp, has_lp); - for (size_t i = 0; i < incoming.size(); ++i) { - const int b = base + static_cast(i); + for (int i = 0; i < count; ++i) { + const int b = base + i; feed_[b] = first[i]; // feed the first token next step history_[b].push_back(first[i]); // and let later penalties see it advance_grammar(*reqs_[b], first[i]); if (reqs_[b]->cancelled.load()) { reqs_[b]->finish_reason = "cancel"; finished_[b] = true; - } else if (consume(*reqs_[b], produced_[b], first[i])) { + } else if (consume(*reqs_[b], produced_[b], first[i], has_lp[i] ? &row_lp[i] : nullptr)) { finished_[b] = true; } } } -mx::array Worker::sample_rows(const mx::array& logits, int row_offset, int count) { +Worker::SampledRows Worker::sample_rows(const mx::array& logits, int row_offset, int count) { const int vocab = logits.shape()[1]; std::vector tokens; tokens.reserve(count); + SampledRows out{mx::zeros({0}, mx::int32), {}, {}, {}, {}}; for (int i = 0; i < count; ++i) { const int r = row_offset + i; mx::array row = mx::slice(logits, {i, 0}, {i + 1, vocab}); // (1, vocab) @@ -313,8 +328,45 @@ mx::array Worker::sample_rows(const mx::array& logits, int row_offset, int count return Sampler::sample(row, p, ks.second, history); }(); tokens.push_back(res.tokens); // (1,) + // Collect the log-prob arrays only for rows that asked, so the logprob + // subgraph stays dead (MLX prunes it) for everyone else. + if (p.top_logprobs >= 0) { + out.lp_rows.push_back(i); + out.lp_chosen.push_back(res.logprobs); // (1,) + out.lp_top_ids.push_back(res.top_tokens); // (1, K) + out.lp_top_lp.push_back(res.top_logprobs); // (1, K) + } + } + out.tokens = mx::concatenate(tokens, /*axis=*/0); // (count,) + return out; +} + +void Worker::finalize_sample(const SampledRows& s, int count, std::vector& ids, + std::vector& row_lp, std::vector& has_lp) { + // The ONE async_eval per step, over the whole batch: the chosen tokens plus + // every logprob-enabled row's log-prob arrays, so they ride the same eval. + std::vector to_eval; + to_eval.reserve(1 + s.lp_chosen.size() + s.lp_top_ids.size() + s.lp_top_lp.size()); + to_eval.push_back(s.tokens); + for (const auto& a : s.lp_chosen) to_eval.push_back(a); + for (const auto& a : s.lp_top_ids) to_eval.push_back(a); + for (const auto& a : s.lp_top_lp) to_eval.push_back(a); + mx::async_eval(to_eval); + + ids = read_ids(s.tokens); + row_lp.assign(count, TokenLogprob{}); + has_lp.assign(count, 0); + for (size_t j = 0; j < s.lp_rows.size(); ++j) { + const int i = s.lp_rows[j]; + has_lp[i] = 1; + TokenLogprob& lp = row_lp[i]; + lp.id = ids[i]; + lp.logprob = read_floats(s.lp_chosen[j])[0]; + const std::vector tids = read_ids(s.lp_top_ids[j]); + const std::vector tlp = read_floats(s.lp_top_lp[j]); + lp.top.reserve(tids.size()); + for (size_t k = 0; k < tids.size(); ++k) lp.top.emplace_back(tids[k], tlp[k]); } - return mx::concatenate(tokens, /*axis=*/0); // (count,) } void Worker::decode_step() { @@ -334,10 +386,13 @@ void Worker::decode_step() { mx::array inputs(fed.data(), {bucket, 1}, mx::int32); mx::array logits = model_->forward(inputs, *cache_); // (bucket, 1, vocab) // Sample only the B real rows (dummy rows are excluded from the graph). - mx::array next = sample_rows(mx::reshape(logits, {bucket, logits.shape()[2]}), 0, B); + SampledRows sampled = sample_rows(mx::reshape(logits, {bucket, logits.shape()[2]}), 0, B); - mx::async_eval(next); // the ONE eval per decode step, over the whole batch - std::vector ids = read_ids(next); + // One async_eval over the whole batch (tokens + any logprob arrays). + std::vector ids; + std::vector row_lp; + std::vector has_lp; + finalize_sample(sampled, B, ids, row_lp, has_lp); for (int b = 0; b < B; ++b) { if (finished_[b]) continue; @@ -349,7 +404,8 @@ void Worker::decode_step() { feed_[b] = ids[b]; history_[b].push_back(ids[b]); // penalties see the full sequence so far advance_grammar(*reqs_[b], ids[b]); - if (consume(*reqs_[b], produced_[b], ids[b])) finished_[b] = true; + if (consume(*reqs_[b], produced_[b], ids[b], has_lp[b] ? &row_lp[b] : nullptr)) + finished_[b] = true; } // Drop the dummy rows so the cache holds only real rows again. @@ -387,6 +443,7 @@ void Worker::evict_finished() { request_us_sum_ += static_cast(us(now - r.enqueue_time).count()); reqs_[b]->tokens.close(); // signal the consumer + reqs_[b]->logprobs.close(); } else { keep.push_back(b); } diff --git a/src/runtime/worker.h b/src/runtime/worker.h index 8b65355..9e9a213 100644 --- a/src/runtime/worker.h +++ b/src/runtime/worker.h @@ -19,6 +19,7 @@ #include "cache/batch_kv_cache.h" #include "model/decoder_model.h" #include "runtime/metrics.h" +#include "scheduler/request.h" // Request, TokenLogprob #include "scheduler/scheduler.h" #include "mlx/array.h" @@ -70,11 +71,28 @@ class Worker { // One decode step over the whole batch: forward -> sample -> async_eval -> // push each row's token, marking finished rows. void decode_step(); + + // Result of sampling the active batch in one graph: the chosen tokens, plus — + // for the rows that requested log-probs (params.top_logprobs >= 0) — their + // per-row log-prob arrays. All are built into the same graph so one async_eval + // covers them; the parallel vectors are aligned with `lp_rows`. + struct SampledRows { + mx::array tokens; // (count,) int32 + std::vector lp_rows; // row-local indices (0..count-1) wanting logprobs + std::vector lp_chosen; // aligned: (1,) chosen-token log-prob + std::vector lp_top_ids; // aligned: (1,K) int32 alternatives + std::vector lp_top_lp; // aligned: (1,K) fp32 alternatives + }; // Sample one token for `count` rows of `logits`, where logits row i belongs to // request reqs_[row_offset + i]. Applies each request's SamplingParams + penalty // history and advances its RNG key. Builds one graph (no eval) so the caller - // keeps the single-async_eval-per-step invariant. Returns a (count,) int32 array. - mx::array sample_rows(const mx::array& logits, int row_offset, int count); + // keeps the single-async_eval-per-step invariant. + SampledRows sample_rows(const mx::array& logits, int row_offset, int count); + // Async-eval a sampled batch (the ONE eval per step) and read it back: fills + // `ids` with the chosen token id per local row, and — for logprob-enabled rows — + // `row_lp`/`has_lp` (both sized `count`) with the reconstructed TokenLogprob. + void finalize_sample(const SampledRows& s, int count, std::vector& ids, + std::vector& row_lp, std::vector& has_lp); // Drop rows marked finished (filter the cache, compact the row vectors, // close their token queues). void evict_finished(); diff --git a/src/sample/sampler.cpp b/src/sample/sampler.cpp index 33fa58b..a05983a 100644 --- a/src/sample/sampler.cpp +++ b/src/sample/sampler.cpp @@ -11,13 +11,49 @@ namespace mlxforge { namespace { constexpr float kNegInf = -std::numeric_limits::infinity(); -// log-prob of each chosen token under softmax(logits): (B, vocab), (B,) -> (B,). -mx::array gather_logprob(const mx::array& logits, const mx::array& tokens) { +// Log-prob distribution over the vocab: log(softmax(logits)), (B, vocab). +mx::array log_probs(const mx::array& logits) { + return mx::log(mx::softmax(logits, /*axis=*/-1, /*precise=*/true)); +} + +// log-prob of each chosen token, given a precomputed (B, vocab) log-prob array +// and the (B,) chosen ids -> (B,). +mx::array gather_logprob(const mx::array& logp, const mx::array& tokens) { const int batch = tokens.shape()[0]; - mx::array logp = mx::log(mx::softmax(logits, /*axis=*/-1, /*precise=*/true)); mx::array idx = mx::reshape(tokens, {batch, 1}); return mx::reshape(mx::take_along_axis(logp, idx, /*axis=*/-1), {batch}); } + +// Top-k (id, log-prob) per row in descending log-prob order, from a (B, vocab) +// log-prob array. Returns ((B, k) int32 ids, (B, k) fp32 log-probs). Pure graph +// ops — no eval, so the caller keeps the single-async_eval-per-step invariant. +std::pair top_k_logprobs(const mx::array& logp, int k) { + const int batch = logp.shape()[0]; + // argsort of the negated log-probs is ascending of -logp = descending of logp; + // the first k columns are the k most-likely ids, already in order. + mx::array order = mx::argsort(mx::negative(logp), /*axis=*/-1); // (B, vocab) + mx::array top_ids = mx::slice(order, {0, 0}, {batch, k}); // (B, k) + mx::array top_lp = mx::take_along_axis(logp, top_ids, /*axis=*/-1); + return {mx::astype(top_ids, mx::int32), top_lp}; +} + +// Empty (B, 0) placeholders for the top-k fields when no alternatives were asked. +SampleResult with_no_top(const mx::array& tokens, const mx::array& logprobs) { + const int batch = tokens.shape()[0]; + return {tokens, logprobs, mx::zeros({batch, 0}, mx::int32), + mx::zeros({batch, 0}, mx::float32)}; +} + +// Attach the chosen-token logprob and (when params.top_logprobs > 0) the top-k +// alternatives, both derived from `dist` (a coherent (B, vocab) log-prob array). +SampleResult attach_logprobs(const mx::array& tokens, const mx::array& dist, + const SamplingParams& params, int vocab) { + mx::array chosen = gather_logprob(dist, tokens); + const int k = std::min(params.top_logprobs, vocab); + if (k <= 0) return with_no_top(tokens, chosen); + std::pair top = top_k_logprobs(dist, k); + return {tokens, chosen, top.first, top.second}; +} } // namespace mx::array Sampler::greedy(const mx::array& logits) { @@ -106,22 +142,27 @@ SampleResult Sampler::sample(const mx::array& logits, const SamplingParams& para const mx::array& key, const mx::array& history) { // Penalties reshape the logit landscape and so apply to greedy too. mx::array adj = history.shape().back() > 0 ? penalize(logits, params, history) : logits; + const int vocab = logits.shape().back(); + + // Reported log-probs come from `adj` — the penalized, pre-temperature, + // pre-filter distribution. That is a coherent softmax over the whole vocab (the + // temperature-scaled/top-k-masked logits are not), so the chosen-token logprob + // and its top-k alternatives are computed from the same array. + auto finish = [&](const mx::array& tokens) { + if (params.top_logprobs < 0) return with_no_top(tokens, gather_logprob(log_probs(adj), tokens)); + return attach_logprobs(tokens, log_probs(adj), params, vocab); + }; - if (params.temperature <= 0.0f) { - mx::array tokens = greedy(adj); - return {tokens, gather_logprob(adj, tokens)}; - } + if (params.temperature <= 0.0f) return finish(greedy(adj)); mx::array scaled = mx::divide(adj, mx::array(params.temperature, adj.dtype())); // Clamp top_k to the vocab size: mx::topk throws for k > axis length, and a // caller-supplied k larger than the vocab is just "keep everything" anyway. - const int vocab = logits.shape().back(); if (params.top_k > 0) scaled = apply_top_k(scaled, std::min(params.top_k, vocab)); if (params.top_p < 1.0f) scaled = apply_top_p(scaled, params.top_p); if (params.min_p > 0.0f) scaled = apply_min_p(scaled, params.min_p); - mx::array tokens = mx::astype(mx::random::categorical(scaled, /*axis=*/-1, key), mx::int32); - return {tokens, gather_logprob(scaled, tokens)}; + return finish(mx::astype(mx::random::categorical(scaled, /*axis=*/-1, key), mx::int32)); } } // namespace mlxforge diff --git a/src/sample/sampler.h b/src/sample/sampler.h index dd480ef..0a589ef 100644 --- a/src/sample/sampler.h +++ b/src/sample/sampler.h @@ -22,6 +22,12 @@ struct SamplingParams { float presence_penalty = 0.0f; // 0 => disabled (subtracts penalty if seen) uint64_t seed = 0; + // Per-token log-probability reporting (OpenAI logprobs/top_logprobs). -1 = off + // (no logprob work); 0 = report the chosen token's logprob only; N > 0 = also + // report the N most-likely alternatives. Off by default so the hot path is + // untouched unless a consumer asks for it. + int top_logprobs = -1; + // Whether any penalty is active (skips the per-row history machinery when not). bool has_penalties() const { return repetition_penalty != 1.0f || frequency_penalty != 0.0f || @@ -32,6 +38,10 @@ struct SamplingParams { struct SampleResult { mx::array tokens; // (B,) int32 mx::array logprobs; // (B,) fp32 — log-prob of each chosen token + // Top-N alternatives, when params.top_logprobs > 0 (else (B, 0) placeholders). + // Both are (B, K), aligned column-wise and ordered by descending log-prob. + mx::array top_tokens; // (B, K) int32 — the K most-likely token ids + mx::array top_logprobs; // (B, K) fp32 — their log-probs }; class Sampler { diff --git a/src/scheduler/request.h b/src/scheduler/request.h index 89ea733..b0ff0f4 100644 --- a/src/scheduler/request.h +++ b/src/scheduler/request.h @@ -14,28 +14,33 @@ #include #include +#include + #include "sample/json_grammar.h" #include "sample/sampler.h" namespace mlxforge { // Bounded, blocking, single-producer (worker) / single-consumer (request thread) -// token queue. push() applies backpressure when full (slow SSE consumers); the -// consumer pop()s until the producer close()s and the queue drains. -class TokenQueue { +// queue. push() applies backpressure when full (slow SSE consumers); the consumer +// pop()s until the producer close()s and the queue drains. Used for both the +// generated token ids (TokenQueue) and their optional per-token log-probs +// (LogprobQueue), which the worker pushes in lockstep. +template +class SpscQueue { public: - explicit TokenQueue(std::size_t capacity = 1024) : capacity_(capacity) {} + explicit SpscQueue(std::size_t capacity = 1024) : capacity_(capacity) {} - // Producer: append a token, blocking while full unless closed. - void push(int token) { + // Producer: append an item, blocking while full unless closed. + void push(T item) { std::unique_lock lk(m_); not_full_.wait(lk, [&] { return q_.size() < capacity_ || closed_; }); if (closed_) return; - q_.push(token); + q_.push(std::move(item)); not_empty_.notify_one(); } - // Producer: signal that no more tokens will be pushed. + // Producer: signal that no more items will be pushed. void close() { { std::lock_guard lk(m_); @@ -45,12 +50,12 @@ class TokenQueue { not_full_.notify_all(); } - // Consumer: pop the next token; returns false once closed and drained. - bool pop(int& out) { + // Consumer: pop the next item; returns false once closed and drained. + bool pop(T& out) { std::unique_lock lk(m_); not_empty_.wait(lk, [&] { return !q_.empty() || closed_; }); if (q_.empty()) return false; - out = q_.front(); + out = std::move(q_.front()); q_.pop(); not_full_.notify_one(); return true; @@ -58,13 +63,27 @@ class TokenQueue { private: std::size_t capacity_; - std::queue q_; + std::queue q_; bool closed_ = false; std::mutex m_; std::condition_variable not_empty_; std::condition_variable not_full_; }; +using TokenQueue = SpscQueue; + +// One emitted token's log-probability data (OpenAI logprobs). `id`/`logprob` are +// the chosen token and its log-prob; `top` holds the requested alternatives as +// (id, log-prob) pairs in descending order (empty when only the chosen logprob +// was requested). Carried on Request::logprobs in lockstep with Request::tokens. +struct TokenLogprob { + int id = 0; + float logprob = 0.0f; + std::vector> top; +}; + +using LogprobQueue = SpscQueue; + struct Request { std::vector prompt_ids; SamplingParams params; @@ -103,6 +122,9 @@ struct Request { std::atomic cancelled{false}; TokenQueue tokens; // worker pushes generated ids, then close()s + // Per-token log-probs, pushed in lockstep with `tokens` when + // params.top_logprobs >= 0 (otherwise never touched). Closed alongside `tokens`. + LogprobQueue logprobs; std::string finish_reason; // "stop" | "length" | "cancel" | "embed" // Metrics: enqueue stamped on submit, first_token/finish stamped by the worker. diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp index bfffc56..9b06f4c 100644 --- a/src/server/http_server.cpp +++ b/src/server/http_server.cpp @@ -50,6 +50,10 @@ std::string HttpServer::next_id(const char* prefix) { std::shared_ptr HttpServer::make_request(const ChatRequest& cr) const { auto req = std::make_shared(); req->params = cr.params; + // OpenAI logprobs -> engine top_logprobs: off (-1) unless `logprobs` is set, + // then the requested alternatives count (0 = chosen-token logprob only). + // Unsupported on the single-stream vision path, so off when an image is present. + req->params.top_logprobs = (cr.logprobs && !cr.has_images()) ? cr.top_logprobs : -1; req->max_tokens = cr.max_tokens; req->eos_ids = cfg_.eos_token_ids; @@ -94,9 +98,18 @@ nlohmann::json HttpServer::run_blocking(const std::shared_ptr& req, const ChatRequest& cr) { const int prompt_tokens = static_cast(req->prompt_ids.size()); + // Drain tokens and (when enabled) their log-probs in lockstep. + const bool want_lp = req->params.top_logprobs >= 0; std::vector out; + std::vector out_lp; int tok = 0; - while (req->tokens.pop(tok)) out.push_back(tok); + while (req->tokens.pop(tok)) { + out.push_back(tok); + if (want_lp) { + TokenLogprob lp; + if (req->logprobs.pop(lp)) out_lp.push_back(std::move(lp)); + } + } const std::string content = tok_->decode(out); const std::string finish = finish_reason_of(req->finish_reason); @@ -110,8 +123,11 @@ nlohmann::json HttpServer::run_blocking(const std::shared_ptr& req, return make_chat_completion_tools(next_id("chatcmpl-"), created, model_name_, calls, prompt_tokens, completion_tokens); } + // Pass an array (possibly empty) when logprobs were requested, else null. + const json logprobs_content = + want_lp ? make_logprobs_content(out_lp, *tok_) : json(nullptr); return make_chat_completion(next_id("chatcmpl-"), created, model_name_, content, finish, - prompt_tokens, completion_tokens); + prompt_tokens, completion_tokens, logprobs_content); } // Legacy text completion shape. return {{"id", next_id("cmpl-")}, @@ -138,9 +154,15 @@ void HttpServer::stream_chat(const std::shared_ptr& req, httplib::Respo req->cancelled.store(true); // client disconnected -> worker evicts return false; }; + // Per-token log-probs, popped in lockstep with the token ids and attached + // (as choices[0].logprobs.content) to the next content delta we emit. + const bool want_lp = req->params.top_logprobs >= 0; + std::vector pending_lp; auto send_content = [&](const std::string& s) { - return s.empty() || send(sse_frame(make_chat_chunk(id, created, model, - {{"content", s}}, nullptr))); + if (s.empty()) return true; + json lp = want_lp ? make_logprobs_content(pending_lp, *tok) : json(nullptr); + pending_lp.clear(); + return send(sse_frame(make_chat_chunk(id, created, model, {{"content", s}}, nullptr, lp))); }; // First chunk announces the assistant role. @@ -155,6 +177,12 @@ void HttpServer::stream_chat(const std::shared_ptr& req, httplib::Respo bool buffering = allow_tools; int t = 0; while (req->tokens.pop(t)) { + // Pop this token's log-prob in lockstep (the worker pushes one per + // emitted token) so it stays aligned with the detokenized text. + if (want_lp) { + TokenLogprob lp; + if (req->logprobs.pop(lp)) pending_lp.push_back(std::move(lp)); + } std::string piece = detok.add(t); if (piece.empty()) continue; if (buffering) { diff --git a/src/server/openai.cpp b/src/server/openai.cpp index 0ee4095..503fbb1 100644 --- a/src/server/openai.cpp +++ b/src/server/openai.cpp @@ -148,6 +148,14 @@ void parse_common(const json& body, ChatRequest& r) { if (body.contains("seed") && !body["seed"].is_null()) r.params.seed = body["seed"].get(); + // OpenAI logprobs: `logprobs` (bool) turns on per-token reporting; with it, + // `top_logprobs` (0–20) is the alternatives count. We carry both on the request + // for serialization and let the HTTP layer fold them into params.top_logprobs. + r.logprobs = body.value("logprobs", false); + r.top_logprobs = body.value("top_logprobs", 0); + if (r.top_logprobs < 0 || r.top_logprobs > 20) + throw std::runtime_error("'top_logprobs' must be in [0, 20]"); + r.stream = body.value("stream", false); r.n = body.value("n", 1); r.stop = parse_stop(body); @@ -233,18 +241,45 @@ nlohmann::json make_embeddings_response(const std::string& model, }; } +namespace { +// One {token, logprob, bytes} entry: decode `id` to its text and raw UTF-8 bytes. +json logprob_entry(int id, float logprob, const Tokenizer& tok) { + const std::string text = tok.decode({id}); + json bytes = json::array(); + for (unsigned char c : text) bytes.push_back(static_cast(c)); + return {{"token", text}, {"logprob", logprob}, {"bytes", std::move(bytes)}}; +} +} // namespace + +nlohmann::json make_logprobs_content(const std::vector& logprobs, + const Tokenizer& tok) { + json content = json::array(); + for (const TokenLogprob& lp : logprobs) { + json entry = logprob_entry(lp.id, lp.logprob, tok); + json top = json::array(); + for (const auto& alt : lp.top) top.push_back(logprob_entry(alt.first, alt.second, tok)); + entry["top_logprobs"] = std::move(top); + content.push_back(std::move(entry)); + } + return content; +} + nlohmann::json make_chat_completion(const std::string& id, long created, const std::string& model, const std::string& content, const std::string& finish_reason, - int prompt_tokens, int completion_tokens) { + int prompt_tokens, int completion_tokens, + const nlohmann::json& logprobs_content) { + json choice = {{"index", 0}, + {"message", {{"role", "assistant"}, {"content", content}}}, + {"finish_reason", finish_reason}}; + // null => logprobs disabled (omit the field); an array (even empty) => enabled. + choice["logprobs"] = + logprobs_content.is_null() ? json(nullptr) : json{{"content", logprobs_content}}; return { {"id", id}, {"object", "chat.completion"}, {"created", created}, {"model", model}, - {"choices", - json::array({{{"index", 0}, - {"message", {{"role", "assistant"}, {"content", content}}}, - {"finish_reason", finish_reason}}})}, + {"choices", json::array({std::move(choice)})}, {"usage", make_usage(prompt_tokens, completion_tokens)}, }; } @@ -327,13 +362,16 @@ nlohmann::json make_chat_completion_tools(const std::string& id, long created, } nlohmann::json make_chat_chunk(const std::string& id, long created, const std::string& model, - const nlohmann::json& delta, const nlohmann::json& finish_reason) { + const nlohmann::json& delta, const nlohmann::json& finish_reason, + const nlohmann::json& logprobs_content) { + json choice = {{"index", 0}, {"delta", delta}, {"finish_reason", finish_reason}}; + // Attach logprobs for the tokens in this delta when present (null => omit). + if (!logprobs_content.is_null()) choice["logprobs"] = json{{"content", logprobs_content}}; return {{"id", id}, {"object", "chat.completion.chunk"}, {"created", created}, {"model", model}, - {"choices", - json::array({{{"index", 0}, {"delta", delta}, {"finish_reason", finish_reason}}})}}; + {"choices", json::array({std::move(choice)})}}; } std::string sse_frame(const nlohmann::json& payload) { return "data: " + payload.dump() + "\n\n"; } diff --git a/src/server/openai.h b/src/server/openai.h index fbd5940..bc50b09 100644 --- a/src/server/openai.h +++ b/src/server/openai.h @@ -9,6 +9,7 @@ #include #include "sample/sampler.h" +#include "scheduler/request.h" // TokenLogprob #include "tokenizer/tokenizer.h" namespace mlxforge { @@ -31,6 +32,11 @@ struct ChatRequest { } SamplingParams params; int max_tokens = 128; + // OpenAI logprobs: `logprobs` enables per-token log-prob reporting; when set, + // `top_logprobs` (0–20) is how many alternatives to include per token. Mapped + // onto params.top_logprobs (-1 when off) by the HTTP layer. + bool logprobs = false; + int top_logprobs = 0; bool stream = false; std::vector stop; int n = 1; @@ -73,10 +79,19 @@ EmbeddingsRequest parse_embeddings_request(const nlohmann::json& body); // chat.completion and text_completion response shapes. nlohmann::json make_usage(int prompt_tokens, int completion_tokens); -// Serialize a finished completion into the OpenAI chat.completion shape. +// Build the OpenAI logprobs `content` array from a request's per-token log-probs: +// one entry {token, logprob, bytes, top_logprobs:[{token, logprob, bytes}]} per +// token, where the token text and `bytes` come from decoding the id with `tok`. +nlohmann::json make_logprobs_content(const std::vector& logprobs, + const Tokenizer& tok); + +// Serialize a finished completion into the OpenAI chat.completion shape. When +// `logprobs_content` is non-null it is attached as choices[0].logprobs.content +// (pass an empty array for an enabled-but-empty result; null/omitted = disabled). nlohmann::json make_chat_completion(const std::string& id, long created, const std::string& model, const std::string& content, const std::string& finish_reason, - int prompt_tokens, int completion_tokens); + int prompt_tokens, int completion_tokens, + const nlohmann::json& logprobs_content = nlohmann::json()); // Detect a Llama-3.2 tool call in the model's decoded output. Returns the parsed // calls, or an empty vector when the text is not a tool call (treat as content). @@ -105,9 +120,11 @@ nlohmann::json make_models_list(const std::string& model); // One streaming chunk object. `delta` is the partial message delta (e.g. // {{"content","..."}} or {{"role","assistant"}}); `finish_reason` is the JSON -// finish reason (null until the final chunk). +// finish reason (null until the final chunk). When `logprobs_content` is non-null +// it is attached as choices[0].logprobs.content for the tokens in this delta. nlohmann::json make_chat_chunk(const std::string& id, long created, const std::string& model, - const nlohmann::json& delta, const nlohmann::json& finish_reason); + const nlohmann::json& delta, const nlohmann::json& finish_reason, + const nlohmann::json& logprobs_content = nlohmann::json()); // Wrap a JSON payload as an SSE frame: "data: \n\n". std::string sse_frame(const nlohmann::json& payload); diff --git a/tests/capi/capi_test.cpp b/tests/capi/capi_test.cpp index 549cbde..e5c9089 100644 --- a/tests/capi/capi_test.cpp +++ b/tests/capi/capi_test.cpp @@ -134,6 +134,69 @@ TEST_CASE("C ABI generates text and batches concurrent requests deterministicall mlxforge_engine_free(eng); } +TEST_CASE("C ABI reports per-token logprobs aligned with the generated tokens") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + char* err = nullptr; + mlxforge_engine* eng = mlxforge_engine_create(model_dir().c_str(), nullptr, &err); + REQUIRE_MESSAGE(eng != nullptr, (err ? err : "engine_create failed")); + while (!mlxforge_engine_ready(eng)) std::this_thread::sleep_for(std::chrono::milliseconds(10)); + + mlxforge_sampling s = {}; // greedy + s.max_tokens = 12; + s.logprobs = 4; // the chosen token's logprob + 3 alternatives + + mlxforge_msg msg = {"user", "What is the capital of France?"}; + mlxforge_request* r = mlxforge_submit_chat(eng, &msg, 1, &s, &err); + REQUIRE_MESSAGE(r != nullptr, (err ? err : "submit_chat failed")); + + drain(r); // logprobs accumulate as the stream is drained + const std::string reason = mlxforge_request_finish_reason(r); + + char* lj = mlxforge_request_logprobs(r); + REQUIRE_MESSAGE(lj != nullptr, "expected a logprobs JSON payload"); + const std::string lp_json = lj; + mlxforge_string_free(lj); + mlxforge_request_free(r); + + auto content = nlohmann::json::parse(lp_json, nullptr, /*allow_exceptions=*/false); + REQUIRE_FALSE(content.is_discarded()); + REQUIRE(content.is_array()); + CHECK(content.size() >= 1); + // One entry per emitted token: a "length" stop means exactly max_tokens tokens. + if (reason == "length") CHECK(content.size() == 12); + CHECK(content.size() <= 12); + + for (const auto& e : content) { + CHECK(e.contains("token")); + CHECK(e["logprob"].is_number()); + CHECK(e["logprob"].get() <= 0.0); // a log-probability + CHECK(e["bytes"].is_array()); + REQUIRE(e["top_logprobs"].is_array()); + CHECK(e["top_logprobs"].size() == 3); // logprobs=4 => 3 alternatives + // Greedy: the chosen token is the most likely, so it equals the top entry. + CHECK(e["top_logprobs"][0]["token"] == e["token"]); + CHECK(e["top_logprobs"][0]["logprob"].get() == doctest::Approx(e["logprob"].get())); + // Alternatives are in descending log-prob order. + const auto& top = e["top_logprobs"]; + for (size_t i = 1; i < top.size(); ++i) + CHECK(top[i - 1]["logprob"].get() >= top[i]["logprob"].get()); + } + + // A request without logprobs returns no payload. + mlxforge_sampling plain = {}; + plain.max_tokens = 4; + mlxforge_request* r2 = mlxforge_submit_chat(eng, &msg, 1, &plain, nullptr); + REQUIRE(r2 != nullptr); + drain(r2); + CHECK(mlxforge_request_logprobs(r2) == nullptr); + mlxforge_request_free(r2); + + mlxforge_engine_free(eng); +} + TEST_CASE("C ABI embeddings are unit-normalized, deterministic, and semantic") { if (!model_available()) { MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); diff --git a/tests/runtime/single_stream_test.cpp b/tests/runtime/single_stream_test.cpp index ef0cb9c..d73fab8 100644 --- a/tests/runtime/single_stream_test.cpp +++ b/tests/runtime/single_stream_test.cpp @@ -42,6 +42,36 @@ TEST_CASE("loop terminates on max_tokens") { CHECK_FALSE(r.hit_eos); } +TEST_CASE("greedy_generate reports per-token logprobs aligned with the tokens") { + if (!model_available()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); + return; + } + mlxforge::LlamaModel& model = shared_model(); + std::vector prompt = load_token_ids("prompt_0_ids.npy"); + + mlxforge::GenerateResult r = mlxforge::greedy_generate( + model, prompt, /*max_tokens=*/8, model.config().eos_token_ids, {}, /*top_logprobs=*/3); + + // One logprob record per emitted token, each aligned with its id. + REQUIRE(r.token_logprobs.size() == r.tokens.size()); + for (size_t i = 0; i < r.tokens.size(); ++i) { + const mlxforge::TokenLogprob& lp = r.token_logprobs[i]; + CHECK(lp.id == r.tokens[i]); + CHECK(lp.logprob <= 0.0f); + REQUIRE(lp.top.size() == 3); + // Greedy: the chosen token is the most likely, so it heads the alternatives. + CHECK(lp.top.front().first == lp.id); + CHECK(lp.top.front().second == doctest::Approx(lp.logprob)); + for (size_t k = 1; k < lp.top.size(); ++k) CHECK(lp.top[k - 1].second >= lp.top[k].second); + } + + // Off by default: no logprob work, no records. + mlxforge::GenerateResult plain = + mlxforge::greedy_generate(model, prompt, /*max_tokens=*/4, model.config().eos_token_ids); + CHECK(plain.token_logprobs.empty()); +} + TEST_CASE("loop terminates on EOS") { if (!model_available()) { MESSAGE("MLXFORGE_MODEL_DIR not present; skipping"); diff --git a/tests/sample/sampler_test.cpp b/tests/sample/sampler_test.cpp index c580c43..3160b0b 100644 --- a/tests/sample/sampler_test.cpp +++ b/tests/sample/sampler_test.cpp @@ -140,3 +140,64 @@ TEST_CASE("sample returns batched tokens and logprobs") { CHECK(r.logprobs.shape() == mx::Shape{2}); for (float lp : floats(r.logprobs)) CHECK(lp <= 0.0f); // log-probabilities } + +TEST_CASE("top_logprobs off leaves the top arrays empty") { + mx::array logits = logits2d({{1.0f, 2.0f, 3.0f}}); + SampleResult r = Sampler::sample(logits, SamplingParams{}, mx::random::key(0)); // default -1 + CHECK(r.top_tokens.shape() == mx::Shape{1, 0}); + CHECK(r.top_logprobs.shape() == mx::Shape{1, 0}); +} + +TEST_CASE("top_logprobs == 0 reports the chosen logprob but no alternatives") { + mx::array logits = logits2d({{1.0f, 2.0f, 3.0f}}); + SamplingParams p; + p.temperature = 0.0f; // greedy + p.top_logprobs = 0; + SampleResult r = Sampler::sample(logits, p, mx::random::key(0)); + CHECK(ints(r.tokens) == std::vector{2}); // argmax + CHECK(floats(r.logprobs)[0] <= 0.0f); + CHECK(r.top_tokens.shape() == mx::Shape{1, 0}); + CHECK(r.top_logprobs.shape() == mx::Shape{1, 0}); +} + +TEST_CASE("top_logprobs > 0 returns the top-k ids in descending logprob order") { + // logits 1,4,2,3,0 -> ranked ids 1(4), 3(3), 2(2), 0(1), 4(0). + mx::array logits = logits2d({{1.0f, 4.0f, 2.0f, 3.0f, 0.0f}}); + SamplingParams p; + p.temperature = 0.0f; + p.top_logprobs = 3; + SampleResult r = Sampler::sample(logits, p, mx::random::key(0)); + CHECK(r.top_tokens.shape() == mx::Shape{1, 3}); + CHECK(r.top_logprobs.shape() == mx::Shape{1, 3}); + CHECK(ints(r.top_tokens) == std::vector{1, 3, 2}); + + std::vector lp = floats(r.top_logprobs); + CHECK(lp[0] >= lp[1]); + CHECK(lp[1] >= lp[2]); + for (float x : lp) CHECK(x <= 0.0f); + // The chosen token is the argmax and its logprob equals the top alternative. + CHECK(ints(r.tokens) == std::vector{1}); + CHECK(floats(r.logprobs)[0] == doctest::Approx(lp[0])); +} + +TEST_CASE("a full top-k recovers a normalized distribution") { + mx::array logits = logits2d({{0.5f, 1.5f, -0.5f, 2.0f}}); + SamplingParams p; + p.temperature = 0.0f; + p.top_logprobs = 4; // == vocab + SampleResult r = Sampler::sample(logits, p, mx::random::key(0)); + float sum = 0.0f; + for (float x : floats(r.top_logprobs)) sum += std::exp(x); + CHECK(sum == doctest::Approx(1.0f).epsilon(0.01)); +} + +TEST_CASE("top_logprobs is reported under temperature sampling too") { + mx::array logits = logits2d({{1.0f, 4.0f, 2.0f, 3.0f}}); + SamplingParams p; + p.temperature = 0.8f; + p.top_logprobs = 2; + SampleResult r = Sampler::sample(logits, p, mx::random::key(3)); + CHECK(r.top_tokens.shape() == mx::Shape{1, 2}); + // Alternatives come from the pre-temperature distribution: highest-logit first. + CHECK(ints(r.top_tokens) == std::vector{1, 3}); +} diff --git a/tests/server/openai_test.cpp b/tests/server/openai_test.cpp index 7b1b310..eb1451c 100644 --- a/tests/server/openai_test.cpp +++ b/tests/server/openai_test.cpp @@ -1,7 +1,10 @@ // OpenAI request parsing + response serialization (pure, no GPU). #include +#include + #include "server/openai.h" +#include "tokenizer/tokenizer.h" using namespace mlxforge; using nlohmann::json; @@ -78,6 +81,93 @@ TEST_CASE("make_chat_completion emits the OpenAI chat.completion shape") { CHECK(c["usage"]["total_tokens"] == 13); } +TEST_CASE("parse_chat_request reads logprobs and top_logprobs") { + const std::string msgs = R"("messages":[{"role":"user","content":"x"}])"; + ChatRequest off = parse_chat_request(json::parse("{" + msgs + "}")); + CHECK_FALSE(off.logprobs); + CHECK(off.top_logprobs == 0); + + ChatRequest on = + parse_chat_request(json::parse("{" + msgs + R"(,"logprobs":true,"top_logprobs":5})")); + CHECK(on.logprobs); + CHECK(on.top_logprobs == 5); + + // top_logprobs is validated to the OpenAI [0, 20] range. + CHECK_THROWS_AS(parse_chat_request(json::parse("{" + msgs + R"(,"top_logprobs":21})")), + std::runtime_error); + CHECK_THROWS_AS(parse_chat_request(json::parse("{" + msgs + R"(,"top_logprobs":-1})")), + std::runtime_error); +} + +TEST_CASE("make_chat_completion attaches a logprobs block when content is given") { + // A hand-built content array (the shape make_logprobs_content produces). + json content = json::array( + {{{"token", "Paris"}, {"logprob", -0.1}, {"bytes", json::array({80, 97, 114, 105, 115})}, + {"top_logprobs", json::array({{{"token", "Paris"}, {"logprob", -0.1}, {"bytes", json::array()}}, + {{"token", "Lyon"}, {"logprob", -2.0}, {"bytes", json::array()}}})}}}); + json c = make_chat_completion("chatcmpl-1", 1, "mlxforge", "Paris", "stop", 4, 1, content); + REQUIRE(c["choices"][0]["logprobs"].is_object()); + const json& lc = c["choices"][0]["logprobs"]["content"]; + REQUIRE(lc.is_array()); + CHECK(lc.size() == 1); + CHECK(lc[0]["token"] == "Paris"); + CHECK(lc[0]["logprob"] == doctest::Approx(-0.1)); + CHECK(lc[0]["bytes"].size() == 5); + CHECK(lc[0]["top_logprobs"].size() == 2); + + // Omitted (null default) => logprobs is present but null (OpenAI convention). + json off = make_chat_completion("chatcmpl-1", 1, "mlxforge", "Paris", "stop", 4, 1); + CHECK(off["choices"][0]["logprobs"].is_null()); +} + +TEST_CASE("make_chat_chunk attaches logprobs to the streamed choice") { + json content = json::array({{{"token", " a"}, {"logprob", -0.5}, {"bytes", json::array({32, 97})}, + {"top_logprobs", json::array()}}}); + json chunk = make_chat_chunk("chatcmpl-1", 1, "mlxforge", {{"content", " a"}}, nullptr, content); + REQUIRE(chunk["choices"][0]["logprobs"].is_object()); + CHECK(chunk["choices"][0]["logprobs"]["content"][0]["token"] == " a"); + + // Without logprobs the field is omitted (back-compat with existing chunks). + json plain = make_chat_chunk("chatcmpl-1", 1, "mlxforge", {{"content", " a"}}, nullptr); + CHECK_FALSE(plain["choices"][0].contains("logprobs")); +} + +TEST_CASE("make_logprobs_content decodes tokens, bytes, and alternatives") { + // Needs a real tokenizer to decode ids; uses the cached model's tokenizer. + const std::string dir = MLXFORGE_MODEL_DIR; + if (dir.empty() || !std::ifstream(dir + "/tokenizer.json").good()) { + MESSAGE("MLXFORGE_MODEL_DIR not present; skipping logprobs content test"); + return; + } + Tokenizer tok = Tokenizer::from_file(dir + "/tokenizer.json"); + + // Two tokens; the first carries two alternatives, the second none. + std::vector lps; + TokenLogprob a; + a.id = tok.encode("Paris").back(); + a.logprob = -0.25f; + a.top = {{a.id, -0.25f}, {tok.encode("Lyon").back(), -1.5f}}; + TokenLogprob b; + b.id = tok.encode(" France").back(); + b.logprob = -0.5f; + lps = {a, b}; + + json content = make_logprobs_content(lps, tok); + REQUIRE(content.is_array()); + REQUIRE(content.size() == 2); + + // Entry 0: token text matches a decode, bytes are the UTF-8 of that text, and + // top_logprobs has both alternatives. + CHECK(content[0]["token"] == tok.decode({a.id})); + CHECK(content[0]["logprob"] == doctest::Approx(-0.25)); + CHECK(content[0]["bytes"].size() == tok.decode({a.id}).size()); + REQUIRE(content[0]["top_logprobs"].is_array()); + CHECK(content[0]["top_logprobs"].size() == 2); + CHECK(content[0]["top_logprobs"][0]["logprob"] == doctest::Approx(-0.25)); + // Entry 1: no alternatives. + CHECK(content[1]["top_logprobs"].empty()); +} + TEST_CASE("parse_completion_request reads a prompt string") { ChatRequest r = parse_completion_request(json::parse(R"({"prompt":"once upon","max_tokens":8})")); CHECK_FALSE(r.is_chat);