Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 57 additions & 8 deletions apps/mlxforge_cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
// mlxforge-cli dump-weights <dir>
// - Loads a model's weights from the supplied directory, prints key/shape/dtype for each tensor,
// asserts that all tensors are fp16, and reports the peak resident memory used.
// mlxforge-cli generate <model> <prompt> [max_tokens]
// mlxforge-cli generate <model> <prompt> [max_tokens] [--logprobs [N]]
// - Runs greedy single-stream generation: pre-fills the prompt (as raw text using the chat template
// or as a pre-tokenized .npy of ids), then streams the detokenized text to stdout until EOS or
// max_tokens.
// max_tokens. With --logprobs [N], each emitted token's log-prob (and its N most-likely
// alternatives) is printed to stderr after generation; stdout stays the generated text.
// mlxforge-cli bench <model> [max_tokens] [runs]
// - Repeatable throughput benchmark over a fixed prompt: one discarded warmup run, then `runs`
// timed runs (defaults: max_tokens=128, runs=3) reporting time-to-first-token and decode tok/s.
Expand All @@ -22,6 +23,7 @@
// which will be downloaded on first use.

#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdio>
#include <string>
Expand Down Expand Up @@ -162,8 +164,24 @@ int run_dump_weights(const std::string& spec) {
return non_fp16 == 0 ? 0 : 1;
}

// Render a token's text for the logprobs dump: quoted, with the common control
// characters escaped so whitespace tokens stay legible on one line.
std::string show_token(const std::string& s) {
std::string out = "\"";
for (char c : s) {
if (c == '\n') out += "\\n";
else if (c == '\t') out += "\\t";
else if (c == '\r') out += "\\r";
else out += c;
}
return out + "\"";
}

// Performs generation using a loaded model, with either raw text or pre-tokenized prompts.
int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens) {
// `top_logprobs` mirrors the engine knob: -1 = off; 0 = each token's own log-prob;
// N > 0 = also its N most-likely alternatives (printed to stderr after generation).
int run_generate(const std::string& spec, const std::string& prompt_arg, int max_tokens,
int top_logprobs = -1) {
// Resolve and load the model (GGUF file or safetensors dir; downloads if needed)
LoadedModel lm = load_for_inference(spec);
mlxforge::DecoderModel& model = *lm.model;
Expand Down Expand Up @@ -194,13 +212,29 @@ int run_generate(const std::string& spec, const std::string& prompt_arg, int max
std::string piece = detok.add(id);
std::fwrite(piece.data(), 1, piece.size(), stdout);
std::fflush(stdout);
});
}, top_logprobs);

// Output any final detokenized tail remaining in the streaming detokenizer
std::string tail = detok.finish();
std::fwrite(tail.data(), 1, tail.size(), stdout);
std::fputc('\n', stdout);

// Per-token log-probs go to stderr (stdout stays the generated text). One line
// per emitted token: its text + log-prob, then any requested alternatives.
if (top_logprobs >= 0) {
std::fprintf(stderr, "logprobs (%zu tokens):\n", r.token_logprobs.size());
for (const mlxforge::TokenLogprob& lp : r.token_logprobs) {
std::fprintf(stderr, " %-16s logprob=%8.4f", show_token(tok.decode({lp.id})).c_str(),
lp.logprob);
if (!lp.top.empty()) {
std::fprintf(stderr, " top:");
for (const auto& alt : lp.top)
std::fprintf(stderr, " %s=%.4f", show_token(tok.decode({alt.first})).c_str(), alt.second);
}
std::fputc('\n', stderr);
}
}

// Log some generation statistics
mlxforge::log::info("generated {} tokens{}", r.tokens.size(),
r.hit_eos ? " (stopped at EOS)" : "");
Expand Down Expand Up @@ -334,12 +368,27 @@ int main(int argc, char** argv) {
if (cmd == "generate") {
// Print usage if not enough arguments; otherwise run generation logic
if (argc < 4) {
std::fprintf(stderr, "usage: mlxforge-cli generate <model_dir> <prompt_ids.npy> [max_tokens]\n");
std::fprintf(stderr,
"usage: mlxforge-cli generate <model_dir> <prompt_ids.npy> [max_tokens] "
"[--logprobs [N]]\n");
return 2;
}
// Parse max_tokens if provided, otherwise default to 64
const int max_tokens = argc >= 5 ? std::stoi(argv[4]) : 64;
return run_generate(argv[2], argv[3], max_tokens);
// Positional [max_tokens] (default 64) plus an optional --logprobs [N] flag (N
// alternatives, default 0 = the chosen token's own log-prob only).
int max_tokens = 64;
int top_logprobs = -1;
for (int i = 4; i < argc; ++i) {
const std::string a = argv[i];
if (a == "--logprobs") {
if (i + 1 < argc && std::isdigit(static_cast<unsigned char>(argv[i + 1][0])))
top_logprobs = std::stoi(argv[++i]);
else
top_logprobs = 0;
} else {
max_tokens = std::stoi(a);
}
}
return run_generate(argv[2], argv[3], max_tokens, top_logprobs);
}
if (cmd == "image") {
// Vision-language generation: describe / answer about an image.
Expand Down
23 changes: 23 additions & 0 deletions bindings/node/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ export interface SamplingOptions {
seed?: number;
/** Max new tokens (default 64). */
maxTokens?: number;
/**
* OpenAI logprobs. 0 (default) => off. N > 0 => report each emitted token's
* own log-prob plus its (N - 1) most-likely alternatives (so 1 = chosen-only).
* Retrieve them from the stream's {@link Stream.logprobs} after consumption.
*/
logprobs?: number;
/**
* Constrained decoding. "json" forces any valid JSON value; otherwise a
* JSON-Schema string (supported subset: a top-level object with ordered,
Expand All @@ -39,12 +45,29 @@ export interface ChatMessage {
content: string;
}

/** One token's log-probability (OpenAI logprobs `content` entry). */
export interface TokenLogprob {
/** The token's decoded text. */
token: string;
/** Natural-log probability of the token (<= 0). */
logprob: number;
/** The token's raw UTF-8 bytes. */
bytes: number[];
/** The most-likely alternatives at this position (may be empty). */
top_logprobs: { token: string; logprob: number; bytes: number[] }[];
}

/** An async-iterable stream of decoded text chunks for one request. */
export interface Stream extends AsyncIterable<string> {
/** Ask the engine to stop generating this request. */
cancel(): void;
/** "stop" | "length" | "cancel" | "" (while running). */
readonly finishReason: string;
/**
* Per-token log-probs, available after the stream has been fully consumed.
* Returns null when `logprobs` was not requested (or none were produced).
*/
logprobs(): TokenLogprob[] | null;
}

export class Engine {
Expand Down
11 changes: 11 additions & 0 deletions bindings/node/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ const native = require('./build/Release/mlxforge_node.node');
// Wrap a native Request as an async-iterable of decoded text chunks. The
// request is disposed when iteration finishes (or the consumer breaks out).
function streamRequest(req) {
// Captured once the stream is fully drained (the native handle is disposed
// right after, so we read logprobs while it is still valid).
let logprobs = null;
return {
// The underlying native handle, in case callers want cancel()/finishReason().
request: req,
Expand All @@ -17,10 +20,18 @@ function streamRequest(req) {
get finishReason() {
return req.finishReason();
},
/**
* Per-token log-probs (OpenAI logprobs `content` shape), or null when none
* were requested. Available after the stream has been fully consumed.
*/
logprobs() {
return logprobs;
},
async *[Symbol.asyncIterator]() {
try {
let chunk;
while ((chunk = await req.next()) !== null) yield chunk;
logprobs = req.logprobs(); // capture before dispose() invalidates the handle
} finally {
req.dispose();
}
Expand Down
17 changes: 17 additions & 0 deletions bindings/node/src/addon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ mlxforge_sampling parse_sampling(const Napi::Object& o, std::string& schema_out)
s.presence_penalty = static_cast<float>(num("presencePenalty", 0.0));
s.seed = static_cast<unsigned long long>(num("seed", 0.0));
s.max_tokens = static_cast<int>(num("maxTokens", 0.0));
// OpenAI logprobs: 0 = off; N > 0 = the chosen token's logprob plus (N - 1)
// alternatives (so 1 = chosen-only). Retrieved via Request.logprobs().
s.logprobs = static_cast<int>(num("logprobs", 0.0));
if (o.Has("jsonSchema") && o.Get("jsonSchema").IsString())
schema_out = o.Get("jsonSchema").As<Napi::String>().Utf8Value();
else if (o.Has("responseFormat") && o.Get("responseFormat").IsString())
Expand All @@ -53,6 +56,7 @@ class RequestWrap : public Napi::ObjectWrap<RequestWrap> {
InstanceMethod("next", &RequestWrap::Next),
InstanceMethod("cancel", &RequestWrap::Cancel),
InstanceMethod("finishReason", &RequestWrap::FinishReason),
InstanceMethod("logprobs", &RequestWrap::Logprobs),
InstanceMethod("dispose", &RequestWrap::Dispose),
});
g_request_ctor = Napi::Persistent(f);
Expand Down Expand Up @@ -93,6 +97,19 @@ class RequestWrap : public Napi::ObjectWrap<RequestWrap> {
return Napi::String::New(info.Env(), req_ ? mlxforge_request_finish_reason(req_) : "");
}

// The accumulated per-token log-probs, parsed from the C ABI's OpenAI-shaped
// JSON (or null when none / not requested). Call after the stream is drained,
// before dispose().
Napi::Value Logprobs(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
char* j = req_ ? mlxforge_request_logprobs(req_) : nullptr;
if (!j) return env.Null();
std::string s(j);
mlxforge_string_free(j);
Napi::Object json = env.Global().Get("JSON").As<Napi::Object>();
return json.Get("parse").As<Napi::Function>().Call(json, {Napi::String::New(env, s)});
}

void Dispose(const Napi::CallbackInfo&) { free_req(); }
};

Expand Down
64 changes: 64 additions & 0 deletions bindings/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ struct CSampling {
seed: u64,
max_tokens: c_int,
json_schema: *const c_char,
logprobs: c_int,
}

// Mirrors mlxforge_embed_opts. pooling/add_eos are tri-state: -1 = model default.
Expand Down Expand Up @@ -122,6 +123,7 @@ extern "C" {
) -> *mut mlxforge_request;
fn mlxforge_request_next(r: *mut mlxforge_request, text: *mut *mut c_char) -> c_int;
fn mlxforge_request_finish_reason(r: *mut mlxforge_request) -> *const c_char;
fn mlxforge_request_logprobs(r: *mut mlxforge_request) -> *mut c_char;
fn mlxforge_request_free(r: *mut mlxforge_request);
}

Expand Down Expand Up @@ -159,6 +161,9 @@ pub struct Sampling {
pub max_tokens: i32,
/// Constrained decoding: "json" or a JSON-Schema string (see the C ABI docs).
pub json_schema: Option<String>,
/// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1)
/// alternatives (so 1 = chosen-only). Retrieve via [`Engine::chat_with_logprobs`].
pub logprobs: i32,
}

impl Sampling {
Expand Down Expand Up @@ -234,6 +239,7 @@ impl Engine {
seed: s.seed,
max_tokens: s.max_tokens,
json_schema,
logprobs: s.logprobs,
}
}

Expand Down Expand Up @@ -268,6 +274,42 @@ impl Engine {
Ok(drain(req))
}

/// Run a chat completion and return both the text and the per-token log-probs
/// (the C ABI's OpenAI-shaped JSON `content` array, or `None` when `logprobs`
/// was not set on `sampling`).
pub fn chat_with_logprobs(
&self,
messages: &[(&str, &str)],
sampling: &Sampling,
) -> Result<(String, Option<String>), String> {
let owned: Vec<(CString, CString)> = messages
.iter()
.map(|(r, c)| {
(
CString::new(*r).unwrap_or_default(),
CString::new(*c).unwrap_or_default(),
)
})
.collect();
let msgs: Vec<Msg> = owned
.iter()
.map(|(r, c)| Msg {
role: r.as_ptr(),
content: c.as_ptr(),
})
.collect();

let mut schema_keep: Option<CString> = None;
let cs = Self::c_sampling(sampling, &mut schema_keep);
let mut err: *mut c_char = ptr::null_mut();
let req =
unsafe { mlxforge_submit_chat(self.handle, msgs.as_ptr(), msgs.len(), &cs, &mut err) };
if req.is_null() {
return Err(unsafe { take_string(err) }.unwrap_or_else(|| "submit failed".into()));
}
Ok(drain_with_logprobs(req))
}

/// Run a raw-text completion (no chat template) to completion.
pub fn text(&self, prompt: &str, sampling: &Sampling) -> Result<String, String> {
let cprompt = CString::new(prompt).map_err(|_| "prompt contains NUL".to_string())?;
Expand Down Expand Up @@ -436,3 +478,25 @@ fn drain(req: *mut mlxforge_request) -> String {
unsafe { mlxforge_request_free(req) };
out
}

// Drain a request to text and then read its accumulated per-token log-probs
// (the OpenAI-shaped JSON `content` array, or None when none were produced),
// freeing the request.
fn drain_with_logprobs(req: *mut mlxforge_request) -> (String, Option<String>) {
let mut out = String::new();
loop {
let mut text: *mut c_char = ptr::null_mut();
let rc = unsafe { mlxforge_request_next(req, &mut text) };
if rc == 0 {
if let Some(s) = unsafe { take_string(text) } {
out.push_str(&s);
}
} else {
break;
}
}
// Read logprobs while the request is still alive, then free it.
let logprobs = unsafe { take_string(mlxforge_request_logprobs(req)) };
unsafe { mlxforge_request_free(req) };
(out, logprobs)
}
36 changes: 35 additions & 1 deletion bindings/swift/Sources/MLXForge/MLXForge.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public struct Sampling {
/// JSON-Schema string (supported subset: a top-level object with ordered,
/// required, scalar-typed properties). Output is masked to be well-formed JSON.
public var jsonSchema: String? = nil
/// OpenAI logprobs: 0 = off; N > 0 = the chosen token's log-prob plus (N - 1)
/// alternatives (so 1 = chosen-only). Retrieved via `Engine.completeWithLogprobs`.
public var logprobs: Int32 = 0

public init() {}
public static var greedy: Sampling { Sampling() }
Expand All @@ -34,7 +37,7 @@ public struct Sampling {
temperature: temperature, top_k: topK, top_p: topP, min_p: minP,
repetition_penalty: repetitionPenalty, frequency_penalty: frequencyPenalty,
presence_penalty: presencePenalty, seed: seed, max_tokens: maxTokens,
json_schema: nil)
json_schema: nil, logprobs: logprobs)
}
}

Expand Down Expand Up @@ -168,6 +171,37 @@ public final class Engine {
return out
}

/// Run a chat to completion and return both the text and the per-token log-probs
/// (the C ABI's OpenAI-shaped JSON `content` string, or nil when `logprobs` was
/// not set on `sampling`). Logprobs are read once the stream is fully drained.
public func completeWithLogprobs(_ messages: [ChatMessage], sampling: Sampling = .greedy)
async throws -> (text: String, logprobs: String?)
{
let req = try submitChat(messages, sampling)
return try await withCheckedThrowingContinuation { cont in
DispatchQueue.global(qos: .userInitiated).async {
var out = ""
while true {
var text: UnsafeMutablePointer<CChar>?
let rc = mlxforge_request_next(req, &text)
if rc == 0, let t = text {
out += String(cString: t)
mlxforge_string_free(t)
} else {
break
}
}
var logprobs: String? = nil
if let lp = mlxforge_request_logprobs(req) {
logprobs = String(cString: lp)
mlxforge_string_free(lp)
}
mlxforge_request_free(req)
cont.resume(returning: (out, logprobs))
}
}
}

/// Embed text into a (by default unit-normalized) vector.
///
/// With all defaults the model self-selects its convention: a Qwen3-Embedding
Expand Down
1 change: 1 addition & 0 deletions cmake/abi-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mlxforge_floats_free
mlxforge_request_cancel
mlxforge_request_finish_reason
mlxforge_request_free
mlxforge_request_logprobs
mlxforge_request_next
mlxforge_string_free
mlxforge_submit_chat
Expand Down
Loading
Loading