diff --git a/swift/Sources/CoreAILanguageModels/Output/LogitsWriter.swift b/swift/Sources/CoreAILanguageModels/Output/LogitsWriter.swift index b49d0a8..673a7ed 100644 --- a/swift/Sources/CoreAILanguageModels/Output/LogitsWriter.swift +++ b/swift/Sources/CoreAILanguageModels/Output/LogitsWriter.swift @@ -77,16 +77,22 @@ public enum LogitsLength: Sendable { /// Represents logits information for a single generated token public struct TokenLogits: Sendable { - let tokenId: Int32 - let tokenText: String - let topLogits: [TopLogitEntry] + public let tokenId: Int32 + public let tokenText: String + public let topLogits: [TopLogitEntry] + + public init(tokenId: Int32, tokenText: String, topLogits: [TopLogitEntry]) { + self.tokenId = tokenId + self.tokenText = tokenText + self.topLogits = topLogits + } } /// Represents a single entry in top-K logits public struct TopLogitEntry: Codable, Sendable { - let tokenId: Int32 - let tokenText: String - let logit: Float + public let tokenId: Int32 + public let tokenText: String + public let logit: Float enum CodingKeys: String, CodingKey { case tokenId = "token_id" diff --git a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift index 6b87b9a..cff76fa 100644 --- a/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift +++ b/swift/Sources/Tools/llm-runner/LLMRunnerMain.swift @@ -179,6 +179,9 @@ struct LLMRunner: AsyncParsableCommand, Sendable { @Option(name: .customLong("image"), help: "Path to an image file for vision-language models") var imagePath: String? + @Option( + help: "Maximum tiles for image splitting (overrides model config). 1 = single crop, no tiling.") + @Flag(help: "Enable verbose logging") var verbose: Bool = false @@ -886,27 +889,73 @@ struct LLMRunner: AsyncParsableCommand, Sendable { let inferenceID = InstrumentsProfiler.beginInference( promptTokens: vlmTokens.count, maxTokens: maxTokens) + await PerformanceMetrics.shared.setPromptTokenCount(vlmTokens.count) + let tokenStream = try vlmEngine.generate( with: embeddedInput, tokens: vlmTokens, samplingConfiguration: samplingConfiguration, - inferenceOptions: InferenceOptions(maxTokens: maxTokens) + inferenceOptions: InferenceOptions( + maxTokens: maxTokens, + includeLogits: printLogits || saveLogits != nil + ) ) + CLILogger.log("VLM generate started, maxTokens=\(maxTokens)", component: "VLM") + + // Prompt (prefill) timing — first token latency + var promptSpan: ProfileSpan? = InstrumentsProfiler.beginPrompt(tokens: vlmTokens.count, engine: "CoreAIVLM") + var extendSpan: ProfileSpan? + let needsLogits = printLogits || saveLogits != nil + let topKCount = saveLogitsLength.topKForFile ?? 5 + var generatedTokens: [Int] = [] + var allTokenLogits: [TokenLogits] = [] var previousText = "" for try await output in tokenStream { + if promptSpan != nil { + promptSpan?.end() + promptSpan = nil + extendSpan = InstrumentsProfiler.beginExtend(step: 0, tokens: 1) + } + let token = output.tokenId if eosTokenIds.contains(token) { break } generatedTokens.append(Int(token)) + + if needsLogits, let logits = output.logits { + let floatLogits = logits.map { Float($0) } + let topEntries = LogitsWriter.extractTopK( + from: floatLogits, tokenizer: tokenizer, k: topKCount) + let tokenText = tokenizer.decode(tokens: [Int(token)]) + allTokenLogits.append( + TokenLogits( + tokenId: token, tokenText: tokenText, topLogits: topEntries)) + + if printLogits { + let desc = topEntries.prefix(5).map { + "[\($0.tokenId)]=\(String(format: "%.3f", $0.logit))" + }.joined(separator: " ") + print("\n logits top5: \(desc)", terminator: "") + } + } + let fullText = tokenizer.decode(tokens: generatedTokens) let delta = String(fullText.dropFirst(previousText.count)) previousText = fullText print(delta, terminator: "") fflush(stdout) } + promptSpan?.end() + extendSpan?.end() print() + // Save logits to JSON if requested + if let path = saveLogits, !allTokenLogits.isEmpty { + try LogitsWriter.saveTopKJSON(tokenLogits: allTokenLogits, path: path) + } + + // Record generation stats InstrumentsProfiler.endInference( generatedTokens: generatedTokens.count, signpostID: inferenceID) await PerformanceMetrics.shared.setGeneratedTokenCount(generatedTokens.count)