Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions swift/Sources/CoreAILanguageModels/Output/LogitsWriter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,16 +77,22 @@ public enum LogitsLength: Sendable {

/// Represents logits information for a single generated token
public struct TokenLogits: Sendable {
let tokenId: Int32
let tokenText: String
let topLogits: [TopLogitEntry]
public let tokenId: Int32
public let tokenText: String
public let topLogits: [TopLogitEntry]

public init(tokenId: Int32, tokenText: String, topLogits: [TopLogitEntry]) {
self.tokenId = tokenId
self.tokenText = tokenText
self.topLogits = topLogits
}
}

/// Represents a single entry in top-K logits
public struct TopLogitEntry: Codable, Sendable {
let tokenId: Int32
let tokenText: String
let logit: Float
public let tokenId: Int32
public let tokenText: String
public let logit: Float

enum CodingKeys: String, CodingKey {
case tokenId = "token_id"
Expand Down
51 changes: 50 additions & 1 deletion swift/Sources/Tools/llm-runner/LLMRunnerMain.swift
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ struct LLMRunner: AsyncParsableCommand, Sendable {
@Option(name: .customLong("image"), help: "Path to an image file for vision-language models")
var imagePath: String?

@Option(
help: "Maximum tiles for image splitting (overrides model config). 1 = single crop, no tiling.")

@Flag(help: "Enable verbose logging")
var verbose: Bool = false

Expand Down Expand Up @@ -886,27 +889,73 @@ struct LLMRunner: AsyncParsableCommand, Sendable {
let inferenceID = InstrumentsProfiler.beginInference(
promptTokens: vlmTokens.count, maxTokens: maxTokens)

await PerformanceMetrics.shared.setPromptTokenCount(vlmTokens.count)

let tokenStream = try vlmEngine.generate(
with: embeddedInput,
tokens: vlmTokens,
samplingConfiguration: samplingConfiguration,
inferenceOptions: InferenceOptions(maxTokens: maxTokens)
inferenceOptions: InferenceOptions(
maxTokens: maxTokens,
includeLogits: printLogits || saveLogits != nil
Comment thread
stikves marked this conversation as resolved.
)
)

CLILogger.log("VLM generate started, maxTokens=\(maxTokens)", component: "VLM")

// Prompt (prefill) timing — first token latency
var promptSpan: ProfileSpan? = InstrumentsProfiler.beginPrompt(tokens: vlmTokens.count, engine: "CoreAIVLM")
var extendSpan: ProfileSpan?
let needsLogits = printLogits || saveLogits != nil
let topKCount = saveLogitsLength.topKForFile ?? 5

var generatedTokens: [Int] = []
var allTokenLogits: [TokenLogits] = []
var previousText = ""
for try await output in tokenStream {
if promptSpan != nil {
promptSpan?.end()
promptSpan = nil
extendSpan = InstrumentsProfiler.beginExtend(step: 0, tokens: 1)
}

let token = output.tokenId
if eosTokenIds.contains(token) { break }
generatedTokens.append(Int(token))

if needsLogits, let logits = output.logits {
let floatLogits = logits.map { Float($0) }
let topEntries = LogitsWriter.extractTopK(
from: floatLogits, tokenizer: tokenizer, k: topKCount)
let tokenText = tokenizer.decode(tokens: [Int(token)])
allTokenLogits.append(
TokenLogits(
tokenId: token, tokenText: tokenText, topLogits: topEntries))

if printLogits {
let desc = topEntries.prefix(5).map {
"[\($0.tokenId)]=\(String(format: "%.3f", $0.logit))"
}.joined(separator: " ")
print("\n logits top5: \(desc)", terminator: "")
}
}

let fullText = tokenizer.decode(tokens: generatedTokens)
let delta = String(fullText.dropFirst(previousText.count))
previousText = fullText
print(delta, terminator: "")
fflush(stdout)
}
promptSpan?.end()
extendSpan?.end()
print()

// Save logits to JSON if requested
if let path = saveLogits, !allTokenLogits.isEmpty {
try LogitsWriter.saveTopKJSON(tokenLogits: allTokenLogits, path: path)
}

// Record generation stats
InstrumentsProfiler.endInference(
generatedTokens: generatedTokens.count, signpostID: inferenceID)
await PerformanceMetrics.shared.setGeneratedTokenCount(generatedTokens.count)
Expand Down