diff --git a/apps/mlxforge_cli.cpp b/apps/mlxforge_cli.cpp
index 08d7d5a..87694d4 100644
--- a/apps/mlxforge_cli.cpp
+++ b/apps/mlxforge_cli.cpp
@@ -6,10 +6,11 @@
// mlxforge-cli dump-weights
// - Loads a model's weights from the supplied directory, prints key/shape/dtype for each tensor,
// asserts that all tensors are fp16, and reports the peak resident memory used.
-// mlxforge-cli generate [max_tokens]
+// mlxforge-cli generate [max_tokens] [--logprobs [N]]
// - Runs greedy single-stream generation: pre-fills the prompt (as raw text using the chat template
// or as a pre-tokenized .npy of ids), then streams the detokenized text to stdout until EOS or
-// max_tokens.
+// max_tokens. With --logprobs [N], each emitted token's log-prob (and its N most-likely
+// alternatives) is printed to stderr after generation; stdout stays the generated text.
// mlxforge-cli bench [max_tokens] [runs]
// - Repeatable throughput benchmark over a fixed prompt: one discarded warmup run, then `runs`
// timed runs (defaults: max_tokens=128, runs=3) reporting time-to-first-token and decode tok/s.
@@ -22,6 +23,7 @@
// which will be downloaded on first use.
#include
+#include
#include
#include
#include
@@ -162,8 +164,24 @@ int run_dump_weights(const std::string& spec) {
return non_fp16 == 0 ? 0 : 1;
}
+// Render a token's text for the logprobs dump: quoted, with the common control
+// characters escaped so whitespace tokens stay legible on one line.
+std::string show_token(const std::string& s) {
+ std::string out = "\"";
+ for (char c : s) {
+ if (c == '\n') out += "\\n";
+ else if (c == '\t') out += "\\t";
+ else if (c == '\r') out += "\\r";
+ else out += c;
+ }
+ return out + "\"";
+}
+
// Performs generation using a loaded model, with either raw text or pre-tokenized prompts.
-int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens) {
+// `top_logprobs` mirrors the engine knob: -1 = off; 0 = each token's own log-prob;
+// N > 0 = also its N most-likely alternatives (printed to stderr after generation).
+int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens,
+ int top_logprobs = -1) {
// Resolve and load the model (GGUF file or safetensors dir; downloads if needed)
LoadedModel lm = load_for_inference(spec);
mlxforge::DecoderModel& model = *lm.model;
@@ -194,13 +212,29 @@ int run_generate(const std::string& spec, const std::string& prompt_arg, int max
std::string piece = detok.add(id);
std::fwrite(piece.data(), 1, piece.size(), stdout);
std::fflush(stdout);
- });
+ }, top_logprobs);
// Output any final detokenized tail remaining in the streaming detokenizer
std::string tail = detok.finish();
std::fwrite(tail.data(), 1, tail.size(), stdout);
std::fputc('\n', stdout);
+ // Per-token log-probs go to stderr (stdout stays the generated text). One line
+ // per emitted token: its text + log-prob, then any requested alternatives.
+ if (top_logprobs >= 0) {
+ std::fprintf(stderr, "logprobs (%zu tokens):\n", r.token_logprobs.size());
+ for (const mlxforge::TokenLogprob& lp : r.token_logprobs) {
+ std::fprintf(stderr, " %-16s logprob=%8.4f", show_token(tok.decode({lp.id})).c_str(),
+ lp.logprob);
+ if (!lp.top.empty()) {
+ std::fprintf(stderr, " top:");
+ for (const auto& alt : lp.top)
+ std::fprintf(stderr, " %s=%.4f", show_token(tok.decode({alt.first})).c_str(), alt.second);
+ }
+ std::fputc('\n', stderr);
+ }
+ }
+
// Log some generation statistics
mlxforge::log::info("generated {} tokens{}", r.tokens.size(),
r.hit_eos ? " (stopped at EOS)" : "");
@@ -334,12 +368,27 @@ int main(int argc, char** argv) {
if (cmd == "generate") {
// Print usage if not enough arguments; otherwise run generation logic
if (argc < 4) {
- std::fprintf(stderr, "usage: mlxforge-cli generate [max_tokens]\n");
+ std::fprintf(stderr,
+ "usage: mlxforge-cli generate [max_tokens] "
+ "[--logprobs [N]]\n");
return 2;
}
- // Parse max_tokens if provided, otherwise default to 64
- const int max_tokens = argc >= 5 ? std::stoi(argv[4]) : 64;
- return run_generate(argv[2], argv[3], max_tokens);
+ // Positional [max_tokens] (default 64) plus an optional --logprobs [N] flag (N
+ // alternatives, default 0 = the chosen token's own log-prob only).
+ int max_tokens = 64;
+ int top_logprobs = -1;
+ for (int i = 4; i < argc; ++i) {
+ const std::string a = argv[i];
+ if (a == "--logprobs") {
+ if (i + 1 < argc && std::isdigit(static_cast(argv[i + 1][0])))
+ top_logprobs = std::stoi(argv[++i]);
+ else
+ top_logprobs = 0;
+ } else {
+ max_tokens = std::stoi(a);
+ }
+ }
+ return run_generate(argv[2], argv[3], max_tokens, top_logprobs);
}
if (cmd == "image") {
// Vision-language generation: describe / answer about an image.
diff --git a/bindings/node/index.d.ts b/bindings/node/index.d.ts
index fe8982f..f35567d 100644
--- a/bindings/node/index.d.ts
+++ b/bindings/node/index.d.ts
@@ -24,6 +24,12 @@ export interface SamplingOptions {
seed?: number;
/** Max new tokens (default 64). */
maxTokens?: number;
+ /**
+ * OpenAI logprobs. 0 (default) => off. N > 0 => report each emitted token's
+ * own log-prob plus its (N - 1) most-likely alternatives (so 1 = chosen-only).
+ * Retrieve them from the stream's {@link Stream.logprobs} after consumption.
+ */
+ logprobs?: number;
/**
* Constrained decoding. "json" forces any valid JSON value; otherwise a
* JSON-Schema string (supported subset: a top-level object with ordered,
@@ -39,12 +45,29 @@ export interface ChatMessage {
content: string;
}
+/** One token's log-probability (OpenAI logprobs `content` entry). */
+export interface TokenLogprob {
+ /** The token's decoded text. */
+ token: string;
+ /** Natural-log probability of the token (<= 0). */
+ logprob: number;
+ /** The token's raw UTF-8 bytes. */
+ bytes: number[];
+ /** The most-likely alternatives at this position (may be empty). */
+ top_logprobs: { token: string; logprob: number; bytes: number[] }[];
+}
+
/** An async-iterable stream of decoded text chunks for one request. */
export interface Stream extends AsyncIterable {
/** Ask the engine to stop generating this request. */
cancel(): void;
/** "stop" | "length" | "cancel" | "" (while running). */
readonly finishReason: string;
+ /**
+ * Per-token log-probs, available after the stream has been fully consumed.
+ * Returns null when `logprobs` was not requested (or none were produced).
+ */
+ logprobs(): TokenLogprob[] | null;
}
export class Engine {
diff --git a/bindings/node/index.js b/bindings/node/index.js
index 87c3b67..02954f5 100644
--- a/bindings/node/index.js
+++ b/bindings/node/index.js
@@ -8,6 +8,9 @@ const native = require('./build/Release/mlxforge_node.node');
// Wrap a native Request as an async-iterable of decoded text chunks. The
// request is disposed when iteration finishes (or the consumer breaks out).
function streamRequest(req) {
+ // Captured once the stream is fully drained (the native handle is disposed
+ // right after, so we read logprobs while it is still valid).
+ let logprobs = null;
return {
// The underlying native handle, in case callers want cancel()/finishReason().
request: req,
@@ -17,10 +20,18 @@ function streamRequest(req) {
get finishReason() {
return req.finishReason();
},
+ /**
+ * Per-token log-probs (OpenAI logprobs `content` shape), or null when none
+ * were requested. Available after the stream has been fully consumed.
+ */
+ logprobs() {
+ return logprobs;
+ },
async *[Symbol.asyncIterator]() {
try {
let chunk;
while ((chunk = await req.next()) !== null) yield chunk;
+ logprobs = req.logprobs(); // capture before dispose() invalidates the handle
} finally {
req.dispose();
}
diff --git a/bindings/node/src/addon.cc b/bindings/node/src/addon.cc
index f6f498c..f127e67 100644
--- a/bindings/node/src/addon.cc
+++ b/bindings/node/src/addon.cc
@@ -32,6 +32,9 @@ mlxforge_sampling parse_sampling(const Napi::Object& o, std::string& schema_out)
s.presence_penalty = static_cast(num("presencePenalty", 0.0));
s.seed = static_cast(num("seed", 0.0));
s.max_tokens = static_cast(num("maxTokens", 0.0));
+ // OpenAI logprobs: 0 = off; N > 0 = the chosen token's logprob plus (N - 1)
+ // alternatives (so 1 = chosen-only). Retrieved via Request.logprobs().
+ s.logprobs = static_cast(num("logprobs", 0.0));
if (o.Has("jsonSchema") && o.Get("jsonSchema").IsString())
schema_out = o.Get("jsonSchema").As().Utf8Value();
else if (o.Has("responseFormat") && o.Get("responseFormat").IsString())
@@ -53,6 +56,7 @@ class RequestWrap : public Napi::ObjectWrap {
InstanceMethod("next", &RequestWrap::Next),
InstanceMethod("cancel", &RequestWrap::Cancel),
InstanceMethod("finishReason", &RequestWrap::FinishReason),
+ InstanceMethod("logprobs", &RequestWrap::Logprobs),
InstanceMethod("dispose", &RequestWrap::Dispose),
});
g_request_ctor = Napi::Persistent(f);
@@ -93,6 +97,19 @@ class RequestWrap : public Napi::ObjectWrap {
return Napi::String::New(info.Env(), req_ ? mlxforge_request_finish_reason(req_) : "");
}
+ // The accumulated per-token log-probs, parsed from the C ABI's OpenAI-shaped
+ // JSON (or null when none / not requested). Call after the stream is drained,
+ // before dispose().
+ Napi::Value Logprobs(const Napi::CallbackInfo& info) {
+ Napi::Env env = info.Env();
+ char* j = req_ ? mlxforge_request_logprobs(req_) : nullptr;
+ if (!j) return env.Null();
+ std::string s(j);
+ mlxforge_string_free(j);
+ Napi::Object json = env.Global().Get("JSON").As();
+ return json.Get("parse").As().Call(json, {Napi::String::New(env, s)});
+ }
+
void Dispose(const Napi::CallbackInfo&) { free_req(); }
};
diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs
index 53e5e1f..73c80cb 100644
--- a/bindings/rust/src/lib.rs
+++ b/bindings/rust/src/lib.rs
@@ -48,6 +48,7 @@ struct CSampling {
seed: u64,
max_tokens: c_int,
json_schema: *const c_char,
+ logprobs: c_int,
}
// Mirrors mlxforge_embed_opts. pooling/add_eos are tri-state: -1 = model default.
@@ -122,6 +123,7 @@ extern "C" {
) -> *mut mlxforge_request;
fn mlxforge_request_next(r: *mut mlxforge_request, text: *mut *mut c_char) -> c_int;
fn mlxforge_request_finish_reason(r: *mut mlxforge_request) -> *const c_char;
+ fn mlxforge_request_logprobs(r: *mut mlxforge_request) -> *mut c_char;
fn mlxforge_request_free(r: *mut mlxforge_request);
}
@@ -159,6 +161,9 @@ pub struct Sampling {
pub max_tokens: i32,
/// Constrained decoding: "json" or a JSON-Schema string (see the C ABI docs).
pub json_schema: Option,
+ /// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1)
+ /// alternatives (so 1 = chosen-only). Retrieve via [`Engine::chat_with_logprobs`].
+ pub logprobs: i32,
}
impl Sampling {
@@ -234,6 +239,7 @@ impl Engine {
seed: s.seed,
max_tokens: s.max_tokens,
json_schema,
+ logprobs: s.logprobs,
}
}
@@ -268,6 +274,42 @@ impl Engine {
Ok(drain(req))
}
+ /// Run a chat completion and return both the text and the per-token log-probs
+ /// (the C ABI's OpenAI-shaped JSON `content` array, or `None` when `logprobs`
+ /// was not set on `sampling`).
+ pub fn chat_with_logprobs(
+ &self,
+ messages: &[(&str, &str)],
+ sampling: &Sampling,
+ ) -> Result<(String, Option), String> {
+ let owned: Vec<(CString, CString)> = messages
+ .iter()
+ .map(|(r, c)| {
+ (
+ CString::new(*r).unwrap_or_default(),
+ CString::new(*c).unwrap_or_default(),
+ )
+ })
+ .collect();
+ let msgs: Vec = owned
+ .iter()
+ .map(|(r, c)| Msg {
+ role: r.as_ptr(),
+ content: c.as_ptr(),
+ })
+ .collect();
+
+ let mut schema_keep: Option = None;
+ let cs = Self::c_sampling(sampling, &mut schema_keep);
+ let mut err: *mut c_char = ptr::null_mut();
+ let req =
+ unsafe { mlxforge_submit_chat(self.handle, msgs.as_ptr(), msgs.len(), &cs, &mut err) };
+ if req.is_null() {
+ return Err(unsafe { take_string(err) }.unwrap_or_else(|| "submit failed".into()));
+ }
+ Ok(drain_with_logprobs(req))
+ }
+
/// Run a raw-text completion (no chat template) to completion.
pub fn text(&self, prompt: &str, sampling: &Sampling) -> Result {
let cprompt = CString::new(prompt).map_err(|_| "prompt contains NUL".to_string())?;
@@ -436,3 +478,25 @@ fn drain(req: *mut mlxforge_request) -> String {
unsafe { mlxforge_request_free(req) };
out
}
+
+// Drain a request to text and then read its accumulated per-token log-probs
+// (the OpenAI-shaped JSON `content` array, or None when none were produced),
+// freeing the request.
+fn drain_with_logprobs(req: *mut mlxforge_request) -> (String, Option) {
+ let mut out = String::new();
+ loop {
+ let mut text: *mut c_char = ptr::null_mut();
+ let rc = unsafe { mlxforge_request_next(req, &mut text) };
+ if rc == 0 {
+ if let Some(s) = unsafe { take_string(text) } {
+ out.push_str(&s);
+ }
+ } else {
+ break;
+ }
+ }
+ // Read logprobs while the request is still alive, then free it.
+ let logprobs = unsafe { take_string(mlxforge_request_logprobs(req)) };
+ unsafe { mlxforge_request_free(req) };
+ (out, logprobs)
+}
diff --git a/bindings/swift/Sources/MLXForge/MLXForge.swift b/bindings/swift/Sources/MLXForge/MLXForge.swift
index a09eeea..2f4830e 100644
--- a/bindings/swift/Sources/MLXForge/MLXForge.swift
+++ b/bindings/swift/Sources/MLXForge/MLXForge.swift
@@ -23,6 +23,9 @@ public struct Sampling {
/// JSON-Schema string (supported subset: a top-level object with ordered,
/// required, scalar-typed properties). Output is masked to be well-formed JSON.
public var jsonSchema: String? = nil
+ /// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1)
+ /// alternatives (so 1 = chosen-only). Retrieved via `Engine.completeWithLogprobs`.
+ public var logprobs: Int32 = 0
public init() {}
public static var greedy: Sampling { Sampling() }
@@ -34,7 +37,7 @@ public struct Sampling {
temperature: temperature, top_k: topK, top_p: topP, min_p: minP,
repetition_penalty: repetitionPenalty, frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty, seed: seed, max_tokens: maxTokens,
- json_schema: nil)
+ json_schema: nil, logprobs: logprobs)
}
}
@@ -168,6 +171,37 @@ public final class Engine {
return out
}
+ /// Run a chat to completion and return both the text and the per-token log-probs
+ /// (the C ABI's OpenAI-shaped JSON `content` string, or nil when `logprobs` was
+ /// not set on `sampling`). Logprobs are read once the stream is fully drained.
+ public func completeWithLogprobs(_ messages: [ChatMessage], sampling: Sampling = .greedy)
+ async throws -> (text: String, logprobs: String?)
+ {
+ let req = try submitChat(messages, sampling)
+ return try await withCheckedThrowingContinuation { cont in
+ DispatchQueue.global(qos: .userInitiated).async {
+ var out = ""
+ while true {
+ var text: UnsafeMutablePointer?
+ let rc = mlxforge_request_next(req, &text)
+ if rc == 0, let t = text {
+ out += String(cString: t)
+ mlxforge_string_free(t)
+ } else {
+ break
+ }
+ }
+ var logprobs: String? = nil
+ if let lp = mlxforge_request_logprobs(req) {
+ logprobs = String(cString: lp)
+ mlxforge_string_free(lp)
+ }
+ mlxforge_request_free(req)
+ cont.resume(returning: (out, logprobs))
+ }
+ }
+ }
+
/// Embed text into a (by default unit-normalized) vector.
///
/// With all defaults the model self-selects its convention: a Qwen3-Embedding
diff --git a/cmake/abi-baseline.txt b/cmake/abi-baseline.txt
index aa934c5..cc2cf02 100644
--- a/cmake/abi-baseline.txt
+++ b/cmake/abi-baseline.txt
@@ -9,6 +9,7 @@ mlxforge_floats_free
mlxforge_request_cancel
mlxforge_request_finish_reason
mlxforge_request_free
+mlxforge_request_logprobs
mlxforge_request_next
mlxforge_string_free
mlxforge_submit_chat
diff --git a/src/capi/mlxforge.cpp b/src/capi/mlxforge.cpp
index f96f53e..712a026 100644
--- a/src/capi/mlxforge.cpp
+++ b/src/capi/mlxforge.cpp
@@ -15,6 +15,8 @@
#include
#include
+#include
+
#include "core/config.h"
#include "runtime/engine.h"
#include "scheduler/request.h"
@@ -61,6 +63,9 @@ mlxforge::SamplingParams to_params(const mlxforge_sampling* s) {
p.frequency_penalty = std::isfinite(s->frequency_penalty) ? s->frequency_penalty : 0.0f;
p.presence_penalty = std::isfinite(s->presence_penalty) ? s->presence_penalty : 0.0f;
p.seed = s->seed;
+ // logprobs: 0 => off (-1); N > 0 => the chosen token's logprob plus (N - 1)
+ // alternatives, so the engine's top_logprobs (alternatives count) is N - 1.
+ p.top_logprobs = s->logprobs > 0 ? s->logprobs - 1 : -1;
return p;
}
@@ -68,6 +73,14 @@ int sampling_max_tokens(const mlxforge_sampling* s) {
return (s && s->max_tokens > 0) ? s->max_tokens : 64;
}
+// One {token, logprob, bytes} entry: decode `id` to its text and raw UTF-8 bytes.
+nlohmann::json lp_entry(int id, float logprob, const mlxforge::Tokenizer& tok) {
+ const std::string text = tok.decode({id});
+ nlohmann::json bytes = nlohmann::json::array();
+ for (unsigned char c : text) bytes.push_back(static_cast(c));
+ return {{"token", text}, {"logprob", logprob}, {"bytes", std::move(bytes)}};
+}
+
} // namespace
// Opaque handles. The engine wrapper owns the C++ Engine. The request wrapper
@@ -83,6 +96,12 @@ struct mlxforge_request {
std::shared_ptr req;
std::unique_ptr detok;
bool done = false;
+ // OpenAI logprobs: when `want_logprobs`, each token drained by
+ // mlxforge_request_next pops its log-prob (in lockstep) into `logprobs`;
+ // mlxforge_request_logprobs serializes the accumulated list using `tok`.
+ bool want_logprobs = false;
+ const mlxforge::Tokenizer* tok = nullptr;
+ std::vector logprobs;
};
extern "C" {
@@ -231,6 +250,8 @@ mlxforge_request* submit_ids(mlxforge_engine* engine, std::vector prompt_id
handle->req = req;
handle->detok =
std::make_unique(engine->engine->tokenizer());
+ handle->tok = &engine->engine->tokenizer();
+ handle->want_logprobs = req->params.top_logprobs >= 0;
return handle.release();
}
@@ -370,6 +391,12 @@ int mlxforge_request_next(mlxforge_request* req, char** text) {
for (;;) {
int tok = 0;
if (req->req->tokens.pop(tok)) {
+ // Pop this token's log-prob in lockstep (the worker pushes one per emitted
+ // token) and accumulate it for mlxforge_request_logprobs.
+ if (req->want_logprobs) {
+ mlxforge::TokenLogprob lp;
+ if (req->req->logprobs.pop(lp)) req->logprobs.push_back(std::move(lp));
+ }
std::string piece = req->detok->add(tok);
if (piece.empty()) continue;
if (text) *text = dup_cstr(piece);
@@ -399,6 +426,23 @@ const char* mlxforge_request_finish_reason(mlxforge_request* req) {
return (req && req->req) ? req->req->finish_reason.c_str() : "";
}
+char* mlxforge_request_logprobs(mlxforge_request* req) {
+ if (!req || !req->tok || req->logprobs.empty()) return nullptr;
+ try {
+ nlohmann::json content = nlohmann::json::array();
+ for (const mlxforge::TokenLogprob& lp : req->logprobs) {
+ nlohmann::json entry = lp_entry(lp.id, lp.logprob, *req->tok);
+ nlohmann::json top = nlohmann::json::array();
+ for (const auto& alt : lp.top) top.push_back(lp_entry(alt.first, alt.second, *req->tok));
+ entry["top_logprobs"] = std::move(top);
+ content.push_back(std::move(entry));
+ }
+ return dup_cstr(content.dump());
+ } catch (...) {
+ return nullptr;
+ }
+}
+
void mlxforge_request_free(mlxforge_request* req) {
if (!req) return;
try {
@@ -408,6 +452,12 @@ void mlxforge_request_free(mlxforge_request* req) {
req->req->cancelled.store(true);
int tok = 0;
while (req->req->tokens.pop(tok)) {
+ // Drain log-probs in lockstep too, so the worker's producer never blocks
+ // on a full, abandoned log-prob queue.
+ if (req->want_logprobs) {
+ mlxforge::TokenLogprob lp;
+ req->req->logprobs.pop(lp);
+ }
}
}
} catch (...) {
diff --git a/src/capi/mlxforge.h b/src/capi/mlxforge.h
index 4bc14dd..d09ccbf 100644
--- a/src/capi/mlxforge.h
+++ b/src/capi/mlxforge.h
@@ -37,8 +37,12 @@ extern "C" {
* last-token pooling, trailing EOS, instruction prefix).
* v3: mlxforge_submit_image (Qwen3-VL vision-language: a prompt + one image,
* served single-stream).
- * v4: mlxforge_image + mlxforge_submit_images (a prompt + N images). */
-#define MLXFORGE_ABI_VERSION 4
+ * v4: mlxforge_image + mlxforge_submit_images (a prompt + N images).
+ * v5: mlxforge_sampling.logprobs + mlxforge_request_logprobs (OpenAI per-token
+ * log-probabilities: the chosen token's logprob and its top-N alternatives,
+ * accumulated as the request is drained and returned as an OpenAI-shaped
+ * JSON array). */
+#define MLXFORGE_ABI_VERSION 5
typedef struct mlxforge_engine mlxforge_engine;
typedef struct mlxforge_request mlxforge_request;
@@ -72,6 +76,12 @@ typedef struct {
* object with ordered, required, scalar-typed properties). The output is
* masked so it can only be well-formed JSON. */
const char* json_schema;
+ /* Per-token log-probabilities (OpenAI logprobs). 0 (the zero-init default) =>
+ * off. N > 0 => report each emitted token's own log-prob plus its (N - 1) most
+ * likely alternatives, so 1 = the chosen token's log-prob only. Retrieve them
+ * with mlxforge_request_logprobs() once the request is drained. Not supported
+ * on vision (mlxforge_submit_image) requests, which ignore it. */
+ int logprobs;
} mlxforge_sampling;
/* ---- Library info ---------------------------------------------------------*/
@@ -217,6 +227,16 @@ void mlxforge_request_cancel(mlxforge_request* req);
* 1. Owned by the request. */
const char* mlxforge_request_finish_reason(mlxforge_request* req);
+/* Per-token log-probabilities for the tokens emitted so far, as a newly-allocated
+ * JSON array in the OpenAI logprobs `content` shape:
+ * [{ "token": "...", "logprob": , "bytes": [...],
+ * "top_logprobs": [{ "token", "logprob", "bytes" }, ...] }, ...]
+ * Log-probs accumulate as mlxforge_request_next() drains the stream, so call this
+ * once the request is done (next() returned 1). Returns NULL when logprobs were
+ * not requested (sampling.logprobs == 0) or none were produced. The caller owns
+ * the string and frees it with mlxforge_string_free(). */
+char* mlxforge_request_logprobs(mlxforge_request* req);
+
/* Destroy a request. If it is still running it is cancelled and drained first
* (so the worker never blocks on a full token queue). NULL is ignored. */
void mlxforge_request_free(mlxforge_request* req);
diff --git a/src/runtime/single_stream.cpp b/src/runtime/single_stream.cpp
index 14c35fe..eab2205 100644
--- a/src/runtime/single_stream.cpp
+++ b/src/runtime/single_stream.cpp
@@ -7,25 +7,55 @@
#include "sample/sampler.h"
#include "mlx/ops.h"
+#include "mlx/random.h"
#include "mlx/transforms.h"
namespace mlxforge {
namespace {
-// Greedy token id from the last position of (1, L, vocab) logits.
-int greedy_last(const mx::array& logits) {
+std::vector to_ints(const mx::array& a) {
+ mx::array c = mx::contiguous(mx::astype(a, mx::int32));
+ mx::eval(c);
+ return std::vector(c.data(), c.data() + c.size());
+}
+std::vector to_floats(const mx::array& a) {
+ mx::array c = mx::contiguous(mx::astype(a, mx::float32));
+ mx::eval(c);
+ return std::vector(c.data(), c.data() + c.size());
+}
+
+// Greedy token id from the last position of (1, L, vocab) logits. When `lp` is
+// non-null, also fill it with the chosen token's log-prob and (per `top_logprobs`)
+// its top-K alternatives, reusing the sampler's greedy + logprob path.
+int greedy_last(const mx::array& logits, int top_logprobs, TokenLogprob* lp) {
const int L = logits.shape()[1];
const int V = logits.shape()[2];
mx::array last = mx::reshape(mx::slice(logits, {0, L - 1, 0}, {1, L, V}), {1, V});
- mx::array tok = Sampler::greedy(last);
- mx::eval(tok);
- return tok.item();
+ if (!lp) {
+ mx::array tok = Sampler::greedy(last);
+ mx::eval(tok);
+ return tok.item();
+ }
+ SamplingParams p;
+ p.temperature = 0.0f; // greedy
+ p.top_logprobs = top_logprobs;
+ SampleResult res = Sampler::sample(last, p, mx::random::key(0));
+ // Read through to_ints/to_floats (which astype first): the chosen logprob is in
+ // the model's fp16, so a raw item() would misread the 2-byte value.
+ const int id = to_ints(res.tokens)[0];
+ lp->id = id;
+ lp->logprob = to_floats(res.logprobs)[0];
+ lp->top.clear();
+ const std::vector tids = to_ints(res.top_tokens);
+ const std::vector tlp = to_floats(res.top_logprobs);
+ for (size_t k = 0; k < tids.size(); ++k) lp->top.emplace_back(tids[k], tlp[k]);
+ return id;
}
} // namespace
GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids,
int max_tokens, const std::vector& eos_ids,
- const std::function& on_token) {
+ const std::function& on_token, int top_logprobs) {
auto is_eos = [&](int id) {
return std::find(eos_ids.begin(), eos_ids.end(), id) != eos_ids.end();
};
@@ -39,10 +69,15 @@ GenerateResult greedy_generate(const DecoderModel& model, const std::vector
KVCache cache(model.config().n_layers);
mx::array prompt(prompt_ids.data(), {1, static_cast(prompt_ids.size())}, mx::int32);
+ // `next_lp` carries the log-prob record for `next`, pushed when (and only when)
+ // `next` is actually emitted (EOS is excluded, like its token id).
+ const bool want_lp = top_logprobs >= 0;
+ TokenLogprob next_lp;
+
// Prefill + first sample: time to first token. greedy_last() eval()s, so the
// elapsed wall time covers the actual GPU work, not just graph construction.
const auto t_start = Clock::now();
- int next = greedy_last(model.forward(prompt, &cache));
+ int next = greedy_last(model.forward(prompt, &cache), top_logprobs, want_lp ? &next_lp : nullptr);
result.ttft_ms = ms_since(t_start);
const auto t_decode_start = Clock::now();
@@ -52,10 +87,11 @@ GenerateResult greedy_generate(const DecoderModel& model, const std::vector
break;
}
result.tokens.push_back(next);
+ if (want_lp) result.token_logprobs.push_back(next_lp);
if (on_token) on_token(next);
mx::array step(&next, {1, 1}, mx::int32);
- next = greedy_last(model.forward(step, &cache));
+ next = greedy_last(model.forward(step, &cache), top_logprobs, want_lp ? &next_lp : nullptr);
}
// The first token came from prefill; throughput here is the steady-state
// decode rate over the tokens generated after it.
diff --git a/src/runtime/single_stream.h b/src/runtime/single_stream.h
index 571c06b..3000255 100644
--- a/src/runtime/single_stream.h
+++ b/src/runtime/single_stream.h
@@ -8,11 +8,15 @@
#include
#include "model/decoder_model.h"
+#include "scheduler/request.h" // TokenLogprob
namespace mlxforge {
struct GenerateResult {
std::vector tokens; // generated token ids (EOS excluded)
+ // Per-token log-probs (OpenAI logprobs), aligned with `tokens`; populated only
+ // when greedy_generate is asked for them (top_logprobs >= 0), else empty.
+ std::vector token_logprobs;
bool hit_eos = false; // stopped because an EOS token was produced
double ttft_ms = 0.0; // prefill + first sample (time to first token)
double decode_ms = 0.0; // wall time spent generating tokens after the first
@@ -26,8 +30,12 @@ struct GenerateResult {
// Greedy (argmax) single-stream generation. Calls `on_token(id)` for each
// emitted token. Stops when an EOS token would be produced or after max_tokens.
+// `top_logprobs` mirrors SamplingParams: -1 = off (no logprob work); 0 = record
+// each emitted token's own log-prob into result.token_logprobs; N > 0 = also
+// record its N most-likely alternatives.
GenerateResult greedy_generate(const DecoderModel& model, const std::vector& prompt_ids,
int max_tokens, const std::vector& eos_ids,
- const std::function& on_token = {});
+ const std::function& on_token = {},
+ int top_logprobs = -1);
} // namespace mlxforge
diff --git a/src/runtime/worker.cpp b/src/runtime/worker.cpp
index 143debe..8340d39 100644
--- a/src/runtime/worker.cpp
+++ b/src/runtime/worker.cpp
@@ -28,18 +28,27 @@ std::vector read_ids(const mx::array& a) {
return std::vector(c.data(), c.data() + c.size());
}
+std::vector read_floats(const mx::array& a) {
+ mx::array c = mx::contiguous(mx::astype(a, mx::float32));
+ mx::eval(c);
+ return std::vector(c.data(), c.data() + c.size());
+}
+
bool is_eos(const Request& req, int id) {
return std::find(req.eos_ids.begin(), req.eos_ids.end(), id) != req.eos_ids.end();
}
// Emit token `id` for a request, returning true if the request is now finished.
-bool consume(Request& req, int& produced, int id) {
+// `lp` (when non-null) is the token's log-prob record, pushed in lockstep with
+// the token; EOS is never emitted, so it carries no logprob.
+bool consume(Request& req, int& produced, int id, const TokenLogprob* lp) {
if (is_eos(req, id)) {
req.finish_reason = "stop";
return true;
}
if (produced == 0) req.first_token_time = Request::Clock::now(); // TTFT marker
req.tokens.push(id);
+ if (lp) req.logprobs.push(*lp);
if (++produced >= req.max_tokens) {
req.finish_reason = "length";
return true;
@@ -64,6 +73,7 @@ void Worker::handle_embedding(Request& req) {
req.finish_reason = "error";
}
req.tokens.close(); // unblock the waiting submitter
+ req.logprobs.close();
}
void Worker::admit_multimodal(const std::shared_ptr& req) {
@@ -111,6 +121,7 @@ void Worker::admit_multimodal(const std::shared_ptr& req) {
log::error("worker: multimodal admit error: {}", e.what());
req->finish_reason = "error";
req->tokens.close();
+ req->logprobs.close();
}
}
@@ -270,27 +281,31 @@ void Worker::register_rows(const std::vector>& incoming
}
}
- std::vector first =
- read_ids(sample_rows(last_logits, base, static_cast(incoming.size())));
+ const int count = static_cast(incoming.size());
+ std::vector first;
+ std::vector row_lp;
+ std::vector has_lp;
+ finalize_sample(sample_rows(last_logits, base, count), count, first, row_lp, has_lp);
- for (size_t i = 0; i < incoming.size(); ++i) {
- const int b = base + static_cast(i);
+ for (int i = 0; i < count; ++i) {
+ const int b = base + i;
feed_[b] = first[i]; // feed the first token next step
history_[b].push_back(first[i]); // and let later penalties see it
advance_grammar(*reqs_[b], first[i]);
if (reqs_[b]->cancelled.load()) {
reqs_[b]->finish_reason = "cancel";
finished_[b] = true;
- } else if (consume(*reqs_[b], produced_[b], first[i])) {
+ } else if (consume(*reqs_[b], produced_[b], first[i], has_lp[i] ? &row_lp[i] : nullptr)) {
finished_[b] = true;
}
}
}
-mx::array Worker::sample_rows(const mx::array& logits, int row_offset, int count) {
+Worker::SampledRows Worker::sample_rows(const mx::array& logits, int row_offset, int count) {
const int vocab = logits.shape()[1];
std::vector tokens;
tokens.reserve(count);
+ SampledRows out{mx::zeros({0}, mx::int32), {}, {}, {}, {}};
for (int i = 0; i < count; ++i) {
const int r = row_offset + i;
mx::array row = mx::slice(logits, {i, 0}, {i + 1, vocab}); // (1, vocab)
@@ -313,8 +328,45 @@ mx::array Worker::sample_rows(const mx::array& logits, int row_offset, int count
return Sampler::sample(row, p, ks.second, history);
}();
tokens.push_back(res.tokens); // (1,)
+ // Collect the log-prob arrays only for rows that asked, so the logprob
+ // subgraph stays dead (MLX prunes it) for everyone else.
+ if (p.top_logprobs >= 0) {
+ out.lp_rows.push_back(i);
+ out.lp_chosen.push_back(res.logprobs); // (1,)
+ out.lp_top_ids.push_back(res.top_tokens); // (1, K)
+ out.lp_top_lp.push_back(res.top_logprobs); // (1, K)
+ }
+ }
+ out.tokens = mx::concatenate(tokens, /*axis=*/0); // (count,)
+ return out;
+}
+
+void Worker::finalize_sample(const SampledRows& s, int count, std::vector& ids,
+ std::vector& row_lp, std::vector& has_lp) {
+ // The ONE async_eval per step, over the whole batch: the chosen tokens plus
+ // every logprob-enabled row's log-prob arrays, so they ride the same eval.
+ std::vector to_eval;
+ to_eval.reserve(1 + s.lp_chosen.size() + s.lp_top_ids.size() + s.lp_top_lp.size());
+ to_eval.push_back(s.tokens);
+ for (const auto& a : s.lp_chosen) to_eval.push_back(a);
+ for (const auto& a : s.lp_top_ids) to_eval.push_back(a);
+ for (const auto& a : s.lp_top_lp) to_eval.push_back(a);
+ mx::async_eval(to_eval);
+
+ ids = read_ids(s.tokens);
+ row_lp.assign(count, TokenLogprob{});
+ has_lp.assign(count, 0);
+ for (size_t j = 0; j < s.lp_rows.size(); ++j) {
+ const int i = s.lp_rows[j];
+ has_lp[i] = 1;
+ TokenLogprob& lp = row_lp[i];
+ lp.id = ids[i];
+ lp.logprob = read_floats(s.lp_chosen[j])[0];
+ const std::vector tids = read_ids(s.lp_top_ids[j]);
+ const std::vector tlp = read_floats(s.lp_top_lp[j]);
+ lp.top.reserve(tids.size());
+ for (size_t k = 0; k < tids.size(); ++k) lp.top.emplace_back(tids[k], tlp[k]);
}
- return mx::concatenate(tokens, /*axis=*/0); // (count,)
}
void Worker::decode_step() {
@@ -334,10 +386,13 @@ void Worker::decode_step() {
mx::array inputs(fed.data(), {bucket, 1}, mx::int32);
mx::array logits = model_->forward(inputs, *cache_); // (bucket, 1, vocab)
// Sample only the B real rows (dummy rows are excluded from the graph).
- mx::array next = sample_rows(mx::reshape(logits, {bucket, logits.shape()[2]}), 0, B);
+ SampledRows sampled = sample_rows(mx::reshape(logits, {bucket, logits.shape()[2]}), 0, B);
- mx::async_eval(next); // the ONE eval per decode step, over the whole batch
- std::vector ids = read_ids(next);
+ // One async_eval over the whole batch (tokens + any logprob arrays).
+ std::vector ids;
+ std::vector row_lp;
+ std::vector has_lp;
+ finalize_sample(sampled, B, ids, row_lp, has_lp);
for (int b = 0; b < B; ++b) {
if (finished_[b]) continue;
@@ -349,7 +404,8 @@ void Worker::decode_step() {
feed_[b] = ids[b];
history_[b].push_back(ids[b]); // penalties see the full sequence so far
advance_grammar(*reqs_[b], ids[b]);
- if (consume(*reqs_[b], produced_[b], ids[b])) finished_[b] = true;
+ if (consume(*reqs_[b], produced_[b], ids[b], has_lp[b] ? &row_lp[b] : nullptr))
+ finished_[b] = true;
}
// Drop the dummy rows so the cache holds only real rows again.
@@ -387,6 +443,7 @@ void Worker::evict_finished() {
request_us_sum_ += static_cast(us(now - r.enqueue_time).count());
reqs_[b]->tokens.close(); // signal the consumer
+ reqs_[b]->logprobs.close();
} else {
keep.push_back(b);
}
diff --git a/src/runtime/worker.h b/src/runtime/worker.h
index 8b65355..9e9a213 100644
--- a/src/runtime/worker.h
+++ b/src/runtime/worker.h
@@ -19,6 +19,7 @@
#include "cache/batch_kv_cache.h"
#include "model/decoder_model.h"
#include "runtime/metrics.h"
+#include "scheduler/request.h" // Request, TokenLogprob
#include "scheduler/scheduler.h"
#include "mlx/array.h"
@@ -70,11 +71,28 @@ class Worker {
// One decode step over the whole batch: forward -> sample -> async_eval ->
// push each row's token, marking finished rows.
void decode_step();
+
+ // Result of sampling the active batch in one graph: the chosen tokens, plus —
+ // for the rows that requested log-probs (params.top_logprobs >= 0) — their
+ // per-row log-prob arrays. All are built into the same graph so one async_eval
+ // covers them; the parallel vectors are aligned with `lp_rows`.
+ struct SampledRows {
+ mx::array tokens; // (count,) int32
+ std::vector lp_rows; // row-local indices (0..count-1) wanting logprobs
+ std::vector lp_chosen; // aligned: (1,) chosen-token log-prob
+ std::vector lp_top_ids; // aligned: (1,K) int32 alternatives
+ std::vector lp_top_lp; // aligned: (1,K) fp32 alternatives
+ };
// Sample one token for `count` rows of `logits`, where logits row i belongs to
// request reqs_[row_offset + i]. Applies each request's SamplingParams + penalty
// history and advances its RNG key. Builds one graph (no eval) so the caller
- // keeps the single-async_eval-per-step invariant. Returns a (count,) int32 array.
- mx::array sample_rows(const mx::array& logits, int row_offset, int count);
+ // keeps the single-async_eval-per-step invariant.
+ SampledRows sample_rows(const mx::array& logits, int row_offset, int count);
+ // Async-eval a sampled batch (the ONE eval per step) and read it back: fills
+ // `ids` with the chosen token id per local row, and — for logprob-enabled rows —
+ // `row_lp`/`has_lp` (both sized `count`) with the reconstructed TokenLogprob.
+ void finalize_sample(const SampledRows& s, int count, std::vector& ids,
+ std::vector& row_lp, std::vector& has_lp);
// Drop rows marked finished (filter the cache, compact the row vectors,
// close their token queues).
void evict_finished();
diff --git a/src/sample/sampler.cpp b/src/sample/sampler.cpp
index 33fa58b..a05983a 100644
--- a/src/sample/sampler.cpp
+++ b/src/sample/sampler.cpp
@@ -11,13 +11,49 @@ namespace mlxforge {
namespace {
constexpr float kNegInf = -std::numeric_limits::infinity();
-// log-prob of each chosen token under softmax(logits): (B, vocab), (B,) -> (B,).
-mx::array gather_logprob(const mx::array& logits, const mx::array& tokens) {
+// Log-prob distribution over the vocab: log(softmax(logits)), (B, vocab).
+mx::array log_probs(const mx::array& logits) {
+ return mx::log(mx::softmax(logits, /*axis=*/-1, /*precise=*/true));
+}
+
+// log-prob of each chosen token, given a precomputed (B, vocab) log-prob array
+// and the (B,) chosen ids -> (B,).
+mx::array gather_logprob(const mx::array& logp, const mx::array& tokens) {
const int batch = tokens.shape()[0];
- mx::array logp = mx::log(mx::softmax(logits, /*axis=*/-1, /*precise=*/true));
mx::array idx = mx::reshape(tokens, {batch, 1});
return mx::reshape(mx::take_along_axis(logp, idx, /*axis=*/-1), {batch});
}
+
+// Top-k (id, log-prob) per row in descending log-prob order, from a (B, vocab)
+// log-prob array. Returns ((B, k) int32 ids, (B, k) fp32 log-probs). Pure graph
+// ops — no eval, so the caller keeps the single-async_eval-per-step invariant.
+std::pair top_k_logprobs(const mx::array& logp, int k) {
+ const int batch = logp.shape()[0];
+ // argsort of the negated log-probs is ascending of -logp = descending of logp;
+ // the first k columns are the k most-likely ids, already in order.
+ mx::array order = mx::argsort(mx::negative(logp), /*axis=*/-1); // (B, vocab)
+ mx::array top_ids = mx::slice(order, {0, 0}, {batch, k}); // (B, k)
+ mx::array top_lp = mx::take_along_axis(logp, top_ids, /*axis=*/-1);
+ return {mx::astype(top_ids, mx::int32), top_lp};
+}
+
+// Empty (B, 0) placeholders for the top-k fields when no alternatives were asked.
+SampleResult with_no_top(const mx::array& tokens, const mx::array& logprobs) {
+ const int batch = tokens.shape()[0];
+ return {tokens, logprobs, mx::zeros({batch, 0}, mx::int32),
+ mx::zeros({batch, 0}, mx::float32)};
+}
+
+// Attach the chosen-token logprob and (when params.top_logprobs > 0) the top-k
+// alternatives, both derived from `dist` (a coherent (B, vocab) log-prob array).
+SampleResult attach_logprobs(const mx::array& tokens, const mx::array& dist,
+ const SamplingParams& params, int vocab) {
+ mx::array chosen = gather_logprob(dist, tokens);
+ const int k = std::min(params.top_logprobs, vocab);
+ if (k <= 0) return with_no_top(tokens, chosen);
+ std::pair top = top_k_logprobs(dist, k);
+ return {tokens, chosen, top.first, top.second};
+}
} // namespace
mx::array Sampler::greedy(const mx::array& logits) {
@@ -106,22 +142,27 @@ SampleResult Sampler::sample(const mx::array& logits, const SamplingParams& para
const mx::array& key, const mx::array& history) {
// Penalties reshape the logit landscape and so apply to greedy too.
mx::array adj = history.shape().back() > 0 ? penalize(logits, params, history) : logits;
+ const int vocab = logits.shape().back();
+
+ // Reported log-probs come from `adj` — the penalized, pre-temperature,
+ // pre-filter distribution. That is a coherent softmax over the whole vocab (the
+ // temperature-scaled/top-k-masked logits are not), so the chosen-token logprob
+ // and its top-k alternatives are computed from the same array.
+ auto finish = [&](const mx::array& tokens) {
+ if (params.top_logprobs < 0) return with_no_top(tokens, gather_logprob(log_probs(adj), tokens));
+ return attach_logprobs(tokens, log_probs(adj), params, vocab);
+ };
- if (params.temperature <= 0.0f) {
- mx::array tokens = greedy(adj);
- return {tokens, gather_logprob(adj, tokens)};
- }
+ if (params.temperature <= 0.0f) return finish(greedy(adj));
mx::array scaled = mx::divide(adj, mx::array(params.temperature, adj.dtype()));
// Clamp top_k to the vocab size: mx::topk throws for k > axis length, and a
// caller-supplied k larger than the vocab is just "keep everything" anyway.
- const int vocab = logits.shape().back();
if (params.top_k > 0) scaled = apply_top_k(scaled, std::min(params.top_k, vocab));
if (params.top_p < 1.0f) scaled = apply_top_p(scaled, params.top_p);
if (params.min_p > 0.0f) scaled = apply_min_p(scaled, params.min_p);
- mx::array tokens = mx::astype(mx::random::categorical(scaled, /*axis=*/-1, key), mx::int32);
- return {tokens, gather_logprob(scaled, tokens)};
+ return finish(mx::astype(mx::random::categorical(scaled, /*axis=*/-1, key), mx::int32));
}
} // namespace mlxforge
diff --git a/src/sample/sampler.h b/src/sample/sampler.h
index dd480ef..0a589ef 100644
--- a/src/sample/sampler.h
+++ b/src/sample/sampler.h
@@ -22,6 +22,12 @@ struct SamplingParams {
float presence_penalty = 0.0f; // 0 => disabled (subtracts penalty if seen)
uint64_t seed = 0;
+ // Per-token log-probability reporting (OpenAI logprobs/top_logprobs). -1 = off
+ // (no logprob work); 0 = report the chosen token's logprob only; N > 0 = also
+ // report the N most-likely alternatives. Off by default so the hot path is
+ // untouched unless a consumer asks for it.
+ int top_logprobs = -1;
+
// Whether any penalty is active (skips the per-row history machinery when not).
bool has_penalties() const {
return repetition_penalty != 1.0f || frequency_penalty != 0.0f ||
@@ -32,6 +38,10 @@ struct SamplingParams {
struct SampleResult {
mx::array tokens; // (B,) int32
mx::array logprobs; // (B,) fp32 — log-prob of each chosen token
+ // Top-N alternatives, when params.top_logprobs > 0 (else (B, 0) placeholders).
+ // Both are (B, K), aligned column-wise and ordered by descending log-prob.
+ mx::array top_tokens; // (B, K) int32 — the K most-likely token ids
+ mx::array top_logprobs; // (B, K) fp32 — their log-probs
};
class Sampler {
diff --git a/src/scheduler/request.h b/src/scheduler/request.h
index 89ea733..b0ff0f4 100644
--- a/src/scheduler/request.h
+++ b/src/scheduler/request.h
@@ -14,28 +14,33 @@
#include
#include
+#include
+
#include "sample/json_grammar.h"
#include "sample/sampler.h"
namespace mlxforge {
// Bounded, blocking, single-producer (worker) / single-consumer (request thread)
-// token queue. push() applies backpressure when full (slow SSE consumers); the
-// consumer pop()s until the producer close()s and the queue drains.
-class TokenQueue {
+// queue. push() applies backpressure when full (slow SSE consumers); the consumer
+// pop()s until the producer close()s and the queue drains. Used for both the
+// generated token ids (TokenQueue) and their optional per-token log-probs
+// (LogprobQueue), which the worker pushes in lockstep.
+template
+class SpscQueue {
public:
- explicit TokenQueue(std::size_t capacity = 1024) : capacity_(capacity) {}
+ explicit SpscQueue(std::size_t capacity = 1024) : capacity_(capacity) {}
- // Producer: append a token, blocking while full unless closed.
- void push(int token) {
+ // Producer: append an item, blocking while full unless closed.
+ void push(T item) {
std::unique_lock lk(m_);
not_full_.wait(lk, [&] { return q_.size() < capacity_ || closed_; });
if (closed_) return;
- q_.push(token);
+ q_.push(std::move(item));
not_empty_.notify_one();
}
- // Producer: signal that no more tokens will be pushed.
+ // Producer: signal that no more items will be pushed.
void close() {
{
std::lock_guard lk(m_);
@@ -45,12 +50,12 @@ class TokenQueue {
not_full_.notify_all();
}
- // Consumer: pop the next token; returns false once closed and drained.
- bool pop(int& out) {
+ // Consumer: pop the next item; returns false once closed and drained.
+ bool pop(T& out) {
std::unique_lock lk(m_);
not_empty_.wait(lk, [&] { return !q_.empty() || closed_; });
if (q_.empty()) return false;
- out = q_.front();
+ out = std::move(q_.front());
q_.pop();
not_full_.notify_one();
return true;
@@ -58,13 +63,27 @@ class TokenQueue {
private:
std::size_t capacity_;
- std::queue q_;
+ std::queue q_;
bool closed_ = false;
std::mutex m_;
std::condition_variable not_empty_;
std::condition_variable not_full_;
};
+using TokenQueue = SpscQueue;
+
+// One emitted token's log-probability data (OpenAI logprobs). `id`/`logprob` are
+// the chosen token and its log-prob; `top` holds the requested alternatives as
+// (id, log-prob) pairs in descending order (empty when only the chosen logprob
+// was requested). Carried on Request::logprobs in lockstep with Request::tokens.
+struct TokenLogprob {
+ int id = 0;
+ float logprob = 0.0f;
+ std::vector> top;
+};
+
+using LogprobQueue = SpscQueue;
+
struct Request {
std::vector prompt_ids;
SamplingParams params;
@@ -103,6 +122,9 @@ struct Request {
std::atomic cancelled{false};
TokenQueue tokens; // worker pushes generated ids, then close()s
+ // Per-token log-probs, pushed in lockstep with `tokens` when
+ // params.top_logprobs >= 0 (otherwise never touched). Closed alongside `tokens`.
+ LogprobQueue logprobs;
std::string finish_reason; // "stop" | "length" | "cancel" | "embed"
// Metrics: enqueue stamped on submit, first_token/finish stamped by the worker.
diff --git a/src/server/http_server.cpp b/src/server/http_server.cpp
index bfffc56..9b06f4c 100644
--- a/src/server/http_server.cpp
+++ b/src/server/http_server.cpp
@@ -50,6 +50,10 @@ std::string HttpServer::next_id(const char* prefix) {
std::shared_ptr HttpServer::make_request(const ChatRequest& cr) const {
auto req = std::make_shared();
req->params = cr.params;
+ // OpenAI logprobs -> engine top_logprobs: off (-1) unless `logprobs` is set,
+ // then the requested alternatives count (0 = chosen-token logprob only).
+ // Unsupported on the single-stream vision path, so off when an image is present.
+ req->params.top_logprobs = (cr.logprobs && !cr.has_images()) ? cr.top_logprobs : -1;
req->max_tokens = cr.max_tokens;
req->eos_ids = cfg_.eos_token_ids;
@@ -94,9 +98,18 @@ nlohmann::json HttpServer::run_blocking(const std::shared_ptr& req,
const ChatRequest& cr) {
const int prompt_tokens = static_cast(req->prompt_ids.size());
+ // Drain tokens and (when enabled) their log-probs in lockstep.
+ const bool want_lp = req->params.top_logprobs >= 0;
std::vector out;
+ std::vector out_lp;
int tok = 0;
- while (req->tokens.pop(tok)) out.push_back(tok);
+ while (req->tokens.pop(tok)) {
+ out.push_back(tok);
+ if (want_lp) {
+ TokenLogprob lp;
+ if (req->logprobs.pop(lp)) out_lp.push_back(std::move(lp));
+ }
+ }
const std::string content = tok_->decode(out);
const std::string finish = finish_reason_of(req->finish_reason);
@@ -110,8 +123,11 @@ nlohmann::json HttpServer::run_blocking(const std::shared_ptr& req,
return make_chat_completion_tools(next_id("chatcmpl-"), created, model_name_, calls,
prompt_tokens, completion_tokens);
}
+ // Pass an array (possibly empty) when logprobs were requested, else null.
+ const json logprobs_content =
+ want_lp ? make_logprobs_content(out_lp, *tok_) : json(nullptr);
return make_chat_completion(next_id("chatcmpl-"), created, model_name_, content, finish,
- prompt_tokens, completion_tokens);
+ prompt_tokens, completion_tokens, logprobs_content);
}
// Legacy text completion shape.
return {{"id", next_id("cmpl-")},
@@ -138,9 +154,15 @@ void HttpServer::stream_chat(const std::shared_ptr& req, httplib::Respo
req->cancelled.store(true); // client disconnected -> worker evicts
return false;
};
+ // Per-token log-probs, popped in lockstep with the token ids and attached
+ // (as choices[0].logprobs.content) to the next content delta we emit.
+ const bool want_lp = req->params.top_logprobs >= 0;
+ std::vector pending_lp;
auto send_content = [&](const std::string& s) {
- return s.empty() || send(sse_frame(make_chat_chunk(id, created, model,
- {{"content", s}}, nullptr)));
+ if (s.empty()) return true;
+ json lp = want_lp ? make_logprobs_content(pending_lp, *tok) : json(nullptr);
+ pending_lp.clear();
+ return send(sse_frame(make_chat_chunk(id, created, model, {{"content", s}}, nullptr, lp)));
};
// First chunk announces the assistant role.
@@ -155,6 +177,12 @@ void HttpServer::stream_chat(const std::shared_ptr& req, httplib::Respo
bool buffering = allow_tools;
int t = 0;
while (req->tokens.pop(t)) {
+ // Pop this token's log-prob in lockstep (the worker pushes one per
+ // emitted token) so it stays aligned with the detokenized text.
+ if (want_lp) {
+ TokenLogprob lp;
+ if (req->logprobs.pop(lp)) pending_lp.push_back(std::move(lp));
+ }
std::string piece = detok.add(t);
if (piece.empty()) continue;
if (buffering) {
diff --git a/src/server/openai.cpp b/src/server/openai.cpp
index 0ee4095..503fbb1 100644
--- a/src/server/openai.cpp
+++ b/src/server/openai.cpp
@@ -148,6 +148,14 @@ void parse_common(const json& body, ChatRequest& r) {
if (body.contains("seed") && !body["seed"].is_null())
r.params.seed = body["seed"].get();
+ // OpenAI logprobs: `logprobs` (bool) turns on per-token reporting; with it,
+ // `top_logprobs` (0–20) is the alternatives count. We carry both on the request
+ // for serialization and let the HTTP layer fold them into params.top_logprobs.
+ r.logprobs = body.value("logprobs", false);
+ r.top_logprobs = body.value("top_logprobs", 0);
+ if (r.top_logprobs < 0 || r.top_logprobs > 20)
+ throw std::runtime_error("'top_logprobs' must be in [0, 20]");
+
r.stream = body.value("stream", false);
r.n = body.value("n", 1);
r.stop = parse_stop(body);
@@ -233,18 +241,45 @@ nlohmann::json make_embeddings_response(const std::string& model,
};
}
+namespace {
+// One {token, logprob, bytes} entry: decode `id` to its text and raw UTF-8 bytes.
+json logprob_entry(int id, float logprob, const Tokenizer& tok) {
+ const std::string text = tok.decode({id});
+ json bytes = json::array();
+ for (unsigned char c : text) bytes.push_back(static_cast(c));
+ return {{"token", text}, {"logprob", logprob}, {"bytes", std::move(bytes)}};
+}
+} // namespace
+
+nlohmann::json make_logprobs_content(const std::vector& logprobs,
+ const Tokenizer& tok) {
+ json content = json::array();
+ for (const TokenLogprob& lp : logprobs) {
+ json entry = logprob_entry(lp.id, lp.logprob, tok);
+ json top = json::array();
+ for (const auto& alt : lp.top) top.push_back(logprob_entry(alt.first, alt.second, tok));
+ entry["top_logprobs"] = std::move(top);
+ content.push_back(std::move(entry));
+ }
+ return content;
+}
+
nlohmann::json make_chat_completion(const std::string& id, long created, const std::string& model,
const std::string& content, const std::string& finish_reason,
- int prompt_tokens, int completion_tokens) {
+ int prompt_tokens, int completion_tokens,
+ const nlohmann::json& logprobs_content) {
+ json choice = {{"index", 0},
+ {"message", {{"role", "assistant"}, {"content", content}}},
+ {"finish_reason", finish_reason}};
+ // null => logprobs disabled (omit the field); an array (even empty) => enabled.
+ choice["logprobs"] =
+ logprobs_content.is_null() ? json(nullptr) : json{{"content", logprobs_content}};
return {
{"id", id},
{"object", "chat.completion"},
{"created", created},
{"model", model},
- {"choices",
- json::array({{{"index", 0},
- {"message", {{"role", "assistant"}, {"content", content}}},
- {"finish_reason", finish_reason}}})},
+ {"choices", json::array({std::move(choice)})},
{"usage", make_usage(prompt_tokens, completion_tokens)},
};
}
@@ -327,13 +362,16 @@ nlohmann::json make_chat_completion_tools(const std::string& id, long created,
}
nlohmann::json make_chat_chunk(const std::string& id, long created, const std::string& model,
- const nlohmann::json& delta, const nlohmann::json& finish_reason) {
+ const nlohmann::json& delta, const nlohmann::json& finish_reason,
+ const nlohmann::json& logprobs_content) {
+ json choice = {{"index", 0}, {"delta", delta}, {"finish_reason", finish_reason}};
+ // Attach logprobs for the tokens in this delta when present (null => omit).
+ if (!logprobs_content.is_null()) choice["logprobs"] = json{{"content", logprobs_content}};
return {{"id", id},
{"object", "chat.completion.chunk"},
{"created", created},
{"model", model},
- {"choices",
- json::array({{{"index", 0}, {"delta", delta}, {"finish_reason", finish_reason}}})}};
+ {"choices", json::array({std::move(choice)})}};
}
std::string sse_frame(const nlohmann::json& payload) { return "data: " + payload.dump() + "\n\n"; }
diff --git a/src/server/openai.h b/src/server/openai.h
index fbd5940..bc50b09 100644
--- a/src/server/openai.h
+++ b/src/server/openai.h
@@ -9,6 +9,7 @@
#include
#include "sample/sampler.h"
+#include "scheduler/request.h" // TokenLogprob
#include "tokenizer/tokenizer.h"
namespace mlxforge {
@@ -31,6 +32,11 @@ struct ChatRequest {
}
SamplingParams params;
int max_tokens = 128;
+ // OpenAI logprobs: `logprobs` enables per-token log-prob reporting; when set,
+ // `top_logprobs` (0–20) is how many alternatives to include per token. Mapped
+ // onto params.top_logprobs (-1 when off) by the HTTP layer.
+ bool logprobs = false;
+ int top_logprobs = 0;
bool stream = false;
std::vector stop;
int n = 1;
@@ -73,10 +79,19 @@ EmbeddingsRequest parse_embeddings_request(const nlohmann::json& body);
// chat.completion and text_completion response shapes.
nlohmann::json make_usage(int prompt_tokens, int completion_tokens);
-// Serialize a finished completion into the OpenAI chat.completion shape.
+// Build the OpenAI logprobs `content` array from a request's per-token log-probs:
+// one entry {token, logprob, bytes, top_logprobs:[{token, logprob, bytes}]} per
+// token, where the token text and `bytes` come from decoding the id with `tok`.
+nlohmann::json make_logprobs_content(const std::vector& logprobs,
+ const Tokenizer& tok);
+
+// Serialize a finished completion into the OpenAI chat.completion shape. When
+// `logprobs_content` is non-null it is attached as choices[0].logprobs.content
+// (pass an empty array for an enabled-but-empty result; null/omitted = disabled).
nlohmann::json make_chat_completion(const std::string& id, long created, const std::string& model,
const std::string& content, const std::string& finish_reason,
- int prompt_tokens, int completion_tokens);
+ int prompt_tokens, int completion_tokens,
+ const nlohmann::json& logprobs_content = nlohmann::json());
// Detect a Llama-3.2 tool call in the model's decoded output. Returns the parsed
// calls, or an empty vector when the text is not a tool call (treat as content).
@@ -105,9 +120,11 @@ nlohmann::json make_models_list(const std::string& model);
// One streaming chunk object. `delta` is the partial message delta (e.g.
// {{"content","..."}} or {{"role","assistant"}}); `finish_reason` is the JSON
-// finish reason (null until the final chunk).
+// finish reason (null until the final chunk). When `logprobs_content` is non-null
+// it is attached as choices[0].logprobs.content for the tokens in this delta.
nlohmann::json make_chat_chunk(const std::string& id, long created, const std::string& model,
- const nlohmann::json& delta, const nlohmann::json& finish_reason);
+ const nlohmann::json& delta, const nlohmann::json& finish_reason,
+ const nlohmann::json& logprobs_content = nlohmann::json());
// Wrap a JSON payload as an SSE frame: "data: \n\n".
std::string sse_frame(const nlohmann::json& payload);
diff --git a/tests/capi/capi_test.cpp b/tests/capi/capi_test.cpp
index 549cbde..e5c9089 100644
--- a/tests/capi/capi_test.cpp
+++ b/tests/capi/capi_test.cpp
@@ -134,6 +134,69 @@ TEST_CASE("C ABI generates text and batches concurrent requests deterministicall
mlxforge_engine_free(eng);
}
+TEST_CASE("C ABI reports per-token logprobs aligned with the generated tokens") {
+ if (!model_available()) {
+ MESSAGE("MLXFORGE_MODEL_DIR not present; skipping");
+ return;
+ }
+ char* err = nullptr;
+ mlxforge_engine* eng = mlxforge_engine_create(model_dir().c_str(), nullptr, &err);
+ REQUIRE_MESSAGE(eng != nullptr, (err ? err : "engine_create failed"));
+ while (!mlxforge_engine_ready(eng)) std::this_thread::sleep_for(std::chrono::milliseconds(10));
+
+ mlxforge_sampling s = {}; // greedy
+ s.max_tokens = 12;
+ s.logprobs = 4; // the chosen token's logprob + 3 alternatives
+
+ mlxforge_msg msg = {"user", "What is the capital of France?"};
+ mlxforge_request* r = mlxforge_submit_chat(eng, &msg, 1, &s, &err);
+ REQUIRE_MESSAGE(r != nullptr, (err ? err : "submit_chat failed"));
+
+ drain(r); // logprobs accumulate as the stream is drained
+ const std::string reason = mlxforge_request_finish_reason(r);
+
+ char* lj = mlxforge_request_logprobs(r);
+ REQUIRE_MESSAGE(lj != nullptr, "expected a logprobs JSON payload");
+ const std::string lp_json = lj;
+ mlxforge_string_free(lj);
+ mlxforge_request_free(r);
+
+ auto content = nlohmann::json::parse(lp_json, nullptr, /*allow_exceptions=*/false);
+ REQUIRE_FALSE(content.is_discarded());
+ REQUIRE(content.is_array());
+ CHECK(content.size() >= 1);
+ // One entry per emitted token: a "length" stop means exactly max_tokens tokens.
+ if (reason == "length") CHECK(content.size() == 12);
+ CHECK(content.size() <= 12);
+
+ for (const auto& e : content) {
+ CHECK(e.contains("token"));
+ CHECK(e["logprob"].is_number());
+ CHECK(e["logprob"].get() <= 0.0); // a log-probability
+ CHECK(e["bytes"].is_array());
+ REQUIRE(e["top_logprobs"].is_array());
+ CHECK(e["top_logprobs"].size() == 3); // logprobs=4 => 3 alternatives
+ // Greedy: the chosen token is the most likely, so it equals the top entry.
+ CHECK(e["top_logprobs"][0]["token"] == e["token"]);
+ CHECK(e["top_logprobs"][0]["logprob"].get() == doctest::Approx(e["logprob"].get()));
+ // Alternatives are in descending log-prob order.
+ const auto& top = e["top_logprobs"];
+ for (size_t i = 1; i < top.size(); ++i)
+ CHECK(top[i - 1]["logprob"].get() >= top[i]["logprob"].get());
+ }
+
+ // A request without logprobs returns no payload.
+ mlxforge_sampling plain = {};
+ plain.max_tokens = 4;
+ mlxforge_request* r2 = mlxforge_submit_chat(eng, &msg, 1, &plain, nullptr);
+ REQUIRE(r2 != nullptr);
+ drain(r2);
+ CHECK(mlxforge_request_logprobs(r2) == nullptr);
+ mlxforge_request_free(r2);
+
+ mlxforge_engine_free(eng);
+}
+
TEST_CASE("C ABI embeddings are unit-normalized, deterministic, and semantic") {
if (!model_available()) {
MESSAGE("MLXFORGE_MODEL_DIR not present; skipping");
diff --git a/tests/runtime/single_stream_test.cpp b/tests/runtime/single_stream_test.cpp
index ef0cb9c..d73fab8 100644
--- a/tests/runtime/single_stream_test.cpp
+++ b/tests/runtime/single_stream_test.cpp
@@ -42,6 +42,36 @@ TEST_CASE("loop terminates on max_tokens") {
CHECK_FALSE(r.hit_eos);
}
+TEST_CASE("greedy_generate reports per-token logprobs aligned with the tokens") {
+ if (!model_available()) {
+ MESSAGE("MLXFORGE_MODEL_DIR not present; skipping");
+ return;
+ }
+ mlxforge::LlamaModel& model = shared_model();
+ std::vector prompt = load_token_ids("prompt_0_ids.npy");
+
+ mlxforge::GenerateResult r = mlxforge::greedy_generate(
+ model, prompt, /*max_tokens=*/8, model.config().eos_token_ids, {}, /*top_logprobs=*/3);
+
+ // One logprob record per emitted token, each aligned with its id.
+ REQUIRE(r.token_logprobs.size() == r.tokens.size());
+ for (size_t i = 0; i < r.tokens.size(); ++i) {
+ const mlxforge::TokenLogprob& lp = r.token_logprobs[i];
+ CHECK(lp.id == r.tokens[i]);
+ CHECK(lp.logprob <= 0.0f);
+ REQUIRE(lp.top.size() == 3);
+ // Greedy: the chosen token is the most likely, so it heads the alternatives.
+ CHECK(lp.top.front().first == lp.id);
+ CHECK(lp.top.front().second == doctest::Approx(lp.logprob));
+ for (size_t k = 1; k < lp.top.size(); ++k) CHECK(lp.top[k - 1].second >= lp.top[k].second);
+ }
+
+ // Off by default: no logprob work, no records.
+ mlxforge::GenerateResult plain =
+ mlxforge::greedy_generate(model, prompt, /*max_tokens=*/4, model.config().eos_token_ids);
+ CHECK(plain.token_logprobs.empty());
+}
+
TEST_CASE("loop terminates on EOS") {
if (!model_available()) {
MESSAGE("MLXFORGE_MODEL_DIR not present; skipping");
diff --git a/tests/sample/sampler_test.cpp b/tests/sample/sampler_test.cpp
index c580c43..3160b0b 100644
--- a/tests/sample/sampler_test.cpp
+++ b/tests/sample/sampler_test.cpp
@@ -140,3 +140,64 @@ TEST_CASE("sample returns batched tokens and logprobs") {
CHECK(r.logprobs.shape() == mx::Shape{2});
for (float lp : floats(r.logprobs)) CHECK(lp <= 0.0f); // log-probabilities
}
+
+TEST_CASE("top_logprobs off leaves the top arrays empty") {
+ mx::array logits = logits2d({{1.0f, 2.0f, 3.0f}});
+ SampleResult r = Sampler::sample(logits, SamplingParams{}, mx::random::key(0)); // default -1
+ CHECK(r.top_tokens.shape() == mx::Shape{1, 0});
+ CHECK(r.top_logprobs.shape() == mx::Shape{1, 0});
+}
+
+TEST_CASE("top_logprobs == 0 reports the chosen logprob but no alternatives") {
+ mx::array logits = logits2d({{1.0f, 2.0f, 3.0f}});
+ SamplingParams p;
+ p.temperature = 0.0f; // greedy
+ p.top_logprobs = 0;
+ SampleResult r = Sampler::sample(logits, p, mx::random::key(0));
+ CHECK(ints(r.tokens) == std::vector{2}); // argmax
+ CHECK(floats(r.logprobs)[0] <= 0.0f);
+ CHECK(r.top_tokens.shape() == mx::Shape{1, 0});
+ CHECK(r.top_logprobs.shape() == mx::Shape{1, 0});
+}
+
+TEST_CASE("top_logprobs > 0 returns the top-k ids in descending logprob order") {
+ // logits 1,4,2,3,0 -> ranked ids 1(4), 3(3), 2(2), 0(1), 4(0).
+ mx::array logits = logits2d({{1.0f, 4.0f, 2.0f, 3.0f, 0.0f}});
+ SamplingParams p;
+ p.temperature = 0.0f;
+ p.top_logprobs = 3;
+ SampleResult r = Sampler::sample(logits, p, mx::random::key(0));
+ CHECK(r.top_tokens.shape() == mx::Shape{1, 3});
+ CHECK(r.top_logprobs.shape() == mx::Shape{1, 3});
+ CHECK(ints(r.top_tokens) == std::vector{1, 3, 2});
+
+ std::vector lp = floats(r.top_logprobs);
+ CHECK(lp[0] >= lp[1]);
+ CHECK(lp[1] >= lp[2]);
+ for (float x : lp) CHECK(x <= 0.0f);
+ // The chosen token is the argmax and its logprob equals the top alternative.
+ CHECK(ints(r.tokens) == std::vector{1});
+ CHECK(floats(r.logprobs)[0] == doctest::Approx(lp[0]));
+}
+
+TEST_CASE("a full top-k recovers a normalized distribution") {
+ mx::array logits = logits2d({{0.5f, 1.5f, -0.5f, 2.0f}});
+ SamplingParams p;
+ p.temperature = 0.0f;
+ p.top_logprobs = 4; // == vocab
+ SampleResult r = Sampler::sample(logits, p, mx::random::key(0));
+ float sum = 0.0f;
+ for (float x : floats(r.top_logprobs)) sum += std::exp(x);
+ CHECK(sum == doctest::Approx(1.0f).epsilon(0.01));
+}
+
+TEST_CASE("top_logprobs is reported under temperature sampling too") {
+ mx::array logits = logits2d({{1.0f, 4.0f, 2.0f, 3.0f}});
+ SamplingParams p;
+ p.temperature = 0.8f;
+ p.top_logprobs = 2;
+ SampleResult r = Sampler::sample(logits, p, mx::random::key(3));
+ CHECK(r.top_tokens.shape() == mx::Shape{1, 2});
+ // Alternatives come from the pre-temperature distribution: highest-logit first.
+ CHECK(ints(r.top_tokens) == std::vector{1, 3});
+}
diff --git a/tests/server/openai_test.cpp b/tests/server/openai_test.cpp
index 7b1b310..eb1451c 100644
--- a/tests/server/openai_test.cpp
+++ b/tests/server/openai_test.cpp
@@ -1,7 +1,10 @@
// OpenAI request parsing + response serialization (pure, no GPU).
#include
+#include
+
#include "server/openai.h"
+#include "tokenizer/tokenizer.h"
using namespace mlxforge;
using nlohmann::json;
@@ -78,6 +81,93 @@ TEST_CASE("make_chat_completion emits the OpenAI chat.completion shape") {
CHECK(c["usage"]["total_tokens"] == 13);
}
+TEST_CASE("parse_chat_request reads logprobs and top_logprobs") {
+ const std::string msgs = R"("messages":[{"role":"user","content":"x"}])";
+ ChatRequest off = parse_chat_request(json::parse("{" + msgs + "}"));
+ CHECK_FALSE(off.logprobs);
+ CHECK(off.top_logprobs == 0);
+
+ ChatRequest on =
+ parse_chat_request(json::parse("{" + msgs + R"(,"logprobs":true,"top_logprobs":5})"));
+ CHECK(on.logprobs);
+ CHECK(on.top_logprobs == 5);
+
+ // top_logprobs is validated to the OpenAI [0, 20] range.
+ CHECK_THROWS_AS(parse_chat_request(json::parse("{" + msgs + R"(,"top_logprobs":21})")),
+ std::runtime_error);
+ CHECK_THROWS_AS(parse_chat_request(json::parse("{" + msgs + R"(,"top_logprobs":-1})")),
+ std::runtime_error);
+}
+
+TEST_CASE("make_chat_completion attaches a logprobs block when content is given") {
+ // A hand-built content array (the shape make_logprobs_content produces).
+ json content = json::array(
+ {{{"token", "Paris"}, {"logprob", -0.1}, {"bytes", json::array({80, 97, 114, 105, 115})},
+ {"top_logprobs", json::array({{{"token", "Paris"}, {"logprob", -0.1}, {"bytes", json::array()}},
+ {{"token", "Lyon"}, {"logprob", -2.0}, {"bytes", json::array()}}})}}});
+ json c = make_chat_completion("chatcmpl-1", 1, "mlxforge", "Paris", "stop", 4, 1, content);
+ REQUIRE(c["choices"][0]["logprobs"].is_object());
+ const json& lc = c["choices"][0]["logprobs"]["content"];
+ REQUIRE(lc.is_array());
+ CHECK(lc.size() == 1);
+ CHECK(lc[0]["token"] == "Paris");
+ CHECK(lc[0]["logprob"] == doctest::Approx(-0.1));
+ CHECK(lc[0]["bytes"].size() == 5);
+ CHECK(lc[0]["top_logprobs"].size() == 2);
+
+ // Omitted (null default) => logprobs is present but null (OpenAI convention).
+ json off = make_chat_completion("chatcmpl-1", 1, "mlxforge", "Paris", "stop", 4, 1);
+ CHECK(off["choices"][0]["logprobs"].is_null());
+}
+
+TEST_CASE("make_chat_chunk attaches logprobs to the streamed choice") {
+ json content = json::array({{{"token", " a"}, {"logprob", -0.5}, {"bytes", json::array({32, 97})},
+ {"top_logprobs", json::array()}}});
+ json chunk = make_chat_chunk("chatcmpl-1", 1, "mlxforge", {{"content", " a"}}, nullptr, content);
+ REQUIRE(chunk["choices"][0]["logprobs"].is_object());
+ CHECK(chunk["choices"][0]["logprobs"]["content"][0]["token"] == " a");
+
+ // Without logprobs the field is omitted (back-compat with existing chunks).
+ json plain = make_chat_chunk("chatcmpl-1", 1, "mlxforge", {{"content", " a"}}, nullptr);
+ CHECK_FALSE(plain["choices"][0].contains("logprobs"));
+}
+
+TEST_CASE("make_logprobs_content decodes tokens, bytes, and alternatives") {
+ // Needs a real tokenizer to decode ids; uses the cached model's tokenizer.
+ const std::string dir = MLXFORGE_MODEL_DIR;
+ if (dir.empty() || !std::ifstream(dir + "/tokenizer.json").good()) {
+ MESSAGE("MLXFORGE_MODEL_DIR not present; skipping logprobs content test");
+ return;
+ }
+ Tokenizer tok = Tokenizer::from_file(dir + "/tokenizer.json");
+
+ // Two tokens; the first carries two alternatives, the second none.
+ std::vector lps;
+ TokenLogprob a;
+ a.id = tok.encode("Paris").back();
+ a.logprob = -0.25f;
+ a.top = {{a.id, -0.25f}, {tok.encode("Lyon").back(), -1.5f}};
+ TokenLogprob b;
+ b.id = tok.encode(" France").back();
+ b.logprob = -0.5f;
+ lps = {a, b};
+
+ json content = make_logprobs_content(lps, tok);
+ REQUIRE(content.is_array());
+ REQUIRE(content.size() == 2);
+
+ // Entry 0: token text matches a decode, bytes are the UTF-8 of that text, and
+ // top_logprobs has both alternatives.
+ CHECK(content[0]["token"] == tok.decode({a.id}));
+ CHECK(content[0]["logprob"] == doctest::Approx(-0.25));
+ CHECK(content[0]["bytes"].size() == tok.decode({a.id}).size());
+ REQUIRE(content[0]["top_logprobs"].is_array());
+ CHECK(content[0]["top_logprobs"].size() == 2);
+ CHECK(content[0]["top_logprobs"][0]["logprob"] == doctest::Approx(-0.25));
+ // Entry 1: no alternatives.
+ CHECK(content[1]["top_logprobs"].empty());
+}
+
TEST_CASE("parse_completion_request reads a prompt string") {
ChatRequest r = parse_completion_request(json::parse(R"({"prompt":"once upon","max_tokens":8})"));
CHECK_FALSE(r.is_chat);