diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a048de5..f8a1c892 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,23 @@ default; iOS, macOS, Linux, and Windows now default to `llama_cpp` only. * Added native release pin automation so the maintainer sync workflow updates Apple SPM checksums from published native release asset digests. + * Added `SpeculativeDecodingConfig` as a backend-neutral generation option for + selecting speculative decoding strategies such as MTP while keeping the + existing `GenerationParams.speculativeDecoding` flag as a compatibility + switch. + * Added llama.cpp native MTP speculative decoding for compatible GGUF models + through `SpeculativeDecodingConfig.mtp(...)`, defaulting to a conservative + one-token draft depth unless callers tune `draftTokenMax`. + * Updated the default llama.cpp native runtime pin to + `leehack/llamadart-native@b9547`, including the MTP wrapper exports and + `llama-common` runtime packaging. + * Added `ModelParams.speculativeRollbackTokenMax` so llama.cpp contexts can + reserve recurrent-state rollback snapshots required by Qwen3.5 MTP-style + models. + * Guarded llama.cpp MTP on Android Vulkan by default because the upstream + `draft-mtp` backend-sampling path can abort with `vk::DeviceLostError`; + CPU and other supported backends remain available, and a dart-define debug + override is available for reproductions. * **CI reliability**: * Cached and retried tiny GGUF test-model downloads used by VM integration tests so main-branch CI is less exposed to Hugging Face 429 rate limits. diff --git a/README.md b/README.md index 616f04c4..ee8e2596 100644 --- a/README.md +++ b/README.md @@ -114,7 +114,7 @@ hooks: llamadart: # Optional. Defaults to llamadart's tested native runtime pin. # Use a leehack/llamadart-native release tag when testing another build. - llamadart_native_tag: b9536 + llamadart_native_tag: b9547 # Optional. GitHub repository slug or github.com URL. llamadart_native_repository: leehack/llamadart-native @@ -142,7 +142,7 @@ the native-assets hook fails while downloading that asset. Native source overrides are for compatibility testing. They do not regenerate Dart FFI bindings or symbol lookups, so the selected binary still must be ABI- -and symbol-compatible with the default `leehack/llamadart-native@b9536` runtime. +and symbol-compatible with the default `leehack/llamadart-native@b9547` runtime. Available native tags are published on the [`leehack/llamadart-native` releases page](https://github.com/leehack/llamadart-native/releases). @@ -154,7 +154,7 @@ gh release list --repo leehack/llamadart-native --limit 20 Before overriding, confirm the release includes the asset for your target. The hook downloads files named `llamadart-native--.tar.gz`, for example -`llamadart-native-windows-x64-b9536.tar.gz`. +`llamadart-native-windows-x64-b9547.tar.gz`. For local testing, `llamadart_native_path` may point directly at a bundle archive, at an extracted bundle directory, or at a directory containing `//`, `/`, or the expected archive file. @@ -204,6 +204,44 @@ other models, or to override detection, pass `ModelParams.chatTemplate`. See for the template support matrix, real-model smoke commands, and how to add a family. +llama.cpp MTP speculative decoding is available for compatible GGUF models. For +Qwen3.5 MTP-style models, reserve rollback snapshots on the context and enable +MTP on the generation request: + +```dart +await engine.loadModel( + 'path/to/Qwen3.5-0.8B-MTP-Q4_K_M.gguf', + modelParams: const ModelParams( + contextSize: 2048, + batchSize: 512, + microBatchSize: 512, + speculativeRollbackTokenMax: 1, + ), +); + +await for (final token in engine.generate( + 'Explain local inference in one paragraph.', + params: const GenerationParams( + maxTokens: 128, + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp( + draftTokenMax: 1, + ), + ), +)) { + stdout.write(token); +} +``` + +Higher `draftTokenMax` values can be faster on some models/devices, but they +should be benchmarked with the target model because excess draft depth can add +verification overhead. + +Android Vulkan MTP is currently disabled by default. The upstream llama.cpp +`draft-mtp` backend-sampling path can abort Android Vulkan processes with +`vk::DeviceLostError`; use CPU for Android MTP validation, or rebuild with +`--dart-define=LLAMADART_ANDROID_VULKAN_ALLOW_MTP=true` only when reproducing +or benchmarking that upstream path. + ### 6. Download and cache a remote model file ```dart @@ -372,7 +410,7 @@ overrides are rejected instead of being silently ignored. `.litertlm` generation honors `GenerationParams` `maxTokens`, `temp`, `topK`, `topP`, and `seed` on native and web, with `stopSequences` enforced by llamadart. Native LiteRT-LM also honors stream -batching thresholds and the opt-in `speculativeDecoding` flag; Web LiteRT-LM +batching thresholds and the opt-in speculative decoding APIs; Web LiteRT-LM rejects speculative decoding until the browser runtime exposes an equivalent control. llama.cpp-only sampling and constrained-decoding controls such as Min-P, repeat penalty overrides, grammar/lazy grammar triggers, @@ -384,7 +422,7 @@ the current strict structured-output boundary.
Full module matrix (available modules by target) -Available llama.cpp module matrix from the default native tag `b9536`: +Available llama.cpp module matrix from the default native tag `b9547`: | Target | Available backend modules in bundle | |--------|-------------------------------------| @@ -504,6 +542,8 @@ Notes: - `ModelParams.splitMode` passes through to llama.cpp `split_mode`; it defaults to upstream `layer` behavior. - `ModelParams.mainGpu` passes through to llama.cpp `main_gpu`. To select one GPU for the full model, use `splitMode: ModelSplitMode.none` with the desired `mainGpu` index. - `ModelParams.batchSize` (`n_batch`) and `ModelParams.microBatchSize` (`n_ubatch`) can be set independently for memory/performance tuning; defaults keep legacy behavior (`n_batch = n_ctx`, `n_ubatch = n_batch`). +- `ModelParams.speculativeRollbackTokenMax` passes through to llama.cpp `n_rs_seq`. Keep the default `0` for normal generation; set it to at least the MTP draft token max when a llama.cpp MTP model needs bounded rollback snapshots, such as Qwen3.5 MTP. +- Android Vulkan MTP is guarded by default because the upstream llama.cpp MTP backend-sampling path can crash the process. The debug-only escape hatch is `--dart-define=LLAMADART_ANDROID_VULKAN_ALLOW_MTP=true`. - `ModelParams.preferMemory64` and `ModelParams.modelBytesHint` are web/WebGPU only (ignored on native). They select the 64-bit (wasm64/mem64) bridge core so models larger than the ~4 GiB wasm32 address space (for example Gemma 4 E2B) can load; `null` auto-decides from the size hint (size-driven, no hardcoded model names). See the [WebGPU bridge docs](https://leehack.github.io/llamadart/docs/platforms/webgpu-bridge). - Apple targets use consolidated llama.cpp native libraries, so `llamadart_native_backends` does not split Apple backend modules. Use @@ -700,9 +740,9 @@ Current pinned runtime artifacts: | Runtime path | Published artifact | |--------------|--------------------| -| Native llama.cpp / GGUF | `leehack/llamadart-native@b9536` | +| Native llama.cpp / GGUF | `leehack/llamadart-native@b9547` | | Native LiteRT-LM / `.litertlm` | `leehack/litert-lm-native@v0.13.1` | -| Apple SPM llama.cpp / GGUF | `leehack/llamadart-native@b9536` Apple XCFramework | +| Apple SPM llama.cpp / GGUF | `leehack/llamadart-native@b9547` Apple XCFramework | | Apple SPM LiteRT-LM / `.litertlm` | `leehack/litert-lm-native@v0.13.1` Apple XCFrameworks | | Web llama.cpp / GGUF | `leehack/llama-web-bridge-assets@v0.1.16` | | Web LiteRT-LM / `.litertlm` | App-provided `@litert-lm/core` module URL; the chat app defaults to jsDelivr `@litert-lm/core/+esm` | diff --git a/darwin/llamadart/Package.swift b/darwin/llamadart/Package.swift index 2591c822..a3b90198 100644 --- a/darwin/llamadart/Package.swift +++ b/darwin/llamadart/Package.swift @@ -4,7 +4,7 @@ import PackageDescription let packageRoot = URL(fileURLWithPath: #filePath).deletingLastPathComponent() let artifactsRoot = packageRoot.appendingPathComponent("Artifacts") -let llamaCppTag = "b9536" +let llamaCppTag = "b9547" let liteRtLmTag = "v0.13.1" func localArtifactPath(_ name: String) -> String? { @@ -54,7 +54,7 @@ let package = Package( repository: "leehack/llamadart-native", artifactName: "llamadart-native-apple-xcframework-\(llamaCppTag).zip", tag: llamaCppTag, - checksum: "e71058acca310999c1c5ee03e52e1992bd4c31b528d97ca019e2ea132fc79ae8" + checksum: "df326c10018c0ac739560d0744db52598b7ea8158fd935b02f769d3ac2905237" ), nativeRepoBinaryTarget( name: "LiteRtLm", diff --git a/example/chat_app/lib/litert_lm_benchmark_app.dart b/example/chat_app/lib/litert_lm_benchmark_app.dart index ac8eb6b0..2f247c2b 100644 --- a/example/chat_app/lib/litert_lm_benchmark_app.dart +++ b/example/chat_app/lib/litert_lm_benchmark_app.dart @@ -183,7 +183,12 @@ Map _summarizeRuns(List> runs) { 'decodeWithSamplingTokensPerSecond', ), 'wallMilliseconds': _numericSummary(runs, 'wallMilliseconds'), + 'outputTokens': _numericSummary(runs, 'outputTokens'), 'evalTokens': _numericSummary(runs, 'evalTokens'), + 'targetWallTokensPerSecond': _numericSummary( + runs, + 'targetWallTokensPerSecond', + ), }; } @@ -479,6 +484,7 @@ class _LiteRtLmBenchmarkAppState extends State { contextSize: _maxTokens, gpuLayers: ModelParams.maxGpuLayers, preferredBackend: backendPreference, + speculativeRollbackTokenMax: _speculative ? 1 : 0, ), ); loadSw.stop(); @@ -492,7 +498,13 @@ class _LiteRtLmBenchmarkAppState extends State { await engine .generate( _promptController.text, - params: GenerationParams(maxTokens: _outputTokens, seed: 1), + params: GenerationParams( + maxTokens: _outputTokens, + seed: 1, + speculativeDecodingConfig: _speculative + ? const SpeculativeDecodingConfig.mtp() + : null, + ), ) .drain(); } @@ -506,22 +518,31 @@ class _LiteRtLmBenchmarkAppState extends State { final sw = Stopwatch()..start(); await for (final chunk in engine.generate( _promptController.text, - params: GenerationParams(maxTokens: _outputTokens, seed: 1), + params: GenerationParams( + maxTokens: _outputTokens, + seed: 1, + speculativeDecodingConfig: _speculative + ? const SpeculativeDecodingConfig.mtp() + : null, + ), )) { buffer.write(chunk); } sw.stop(); wallMs = sw.elapsedMilliseconds; lastText = buffer.toString(); + final outputTokenCount = lastText.isEmpty + ? 0 + : await engine.getTokenCount(lastText); perf = await engine.getPerformanceContext(); final runMetrics = { 'index': i, 'wallMilliseconds': wallMs, + 'speculativeDecoding': _speculative, + 'outputTokens': outputTokenCount, 'promptEvalTokens': perf?.promptEvalTokens, 'evalTokens': perf?.evalTokens, - 'hitEosBeforeTarget': perf == null - ? null - : perf.evalTokens < _outputTokens, + 'hitEosBeforeTarget': outputTokenCount < _outputTokens, 'promptEvalMs': perf?.promptEvalMs, 'evalMs': perf?.evalMs, 'sampleMs': perf?.sampleMs, @@ -535,9 +556,12 @@ class _LiteRtLmBenchmarkAppState extends State { perf == null || perf.evalMs + perf.sampleMs <= 0 ? null : perf.evalTokens / ((perf.evalMs + perf.sampleMs) / 1000.0), - 'wallTokensPerSecond': wallMs <= 0 || perf == null + 'wallTokensPerSecond': wallMs <= 0 || outputTokenCount <= 0 ? null - : perf.evalTokens / (wallMs / 1000.0), + : outputTokenCount / (wallMs / 1000.0), + 'targetWallTokensPerSecond': wallMs <= 0 + ? null + : _outputTokens / (wallMs / 1000.0), }; runsDetail.add(runMetrics); _append('RUN llamadart ${jsonEncode(runMetrics)}'); @@ -550,11 +574,15 @@ class _LiteRtLmBenchmarkAppState extends State { 'backendName': backendName, 'resolvedGpuLayers': resolvedGpuLayers, 'targetDecodeTokens': _outputTokens, + 'speculativeDecoding': _speculative, + 'outputTokens': runsDetail.isEmpty + ? null + : runsDetail.last['outputTokens'], 'promptEvalTokens': perf?.promptEvalTokens, 'evalTokens': perf?.evalTokens, - 'hitEosBeforeTarget': perf == null + 'hitEosBeforeTarget': runsDetail.isEmpty ? null - : perf.evalTokens < _outputTokens, + : runsDetail.last['hitEosBeforeTarget'], 'promptEvalMs': perf?.promptEvalMs, 'evalMs': perf?.evalMs, 'sampleMs': perf?.sampleMs, @@ -568,9 +596,12 @@ class _LiteRtLmBenchmarkAppState extends State { perf == null || perf.evalMs + perf.sampleMs <= 0 ? null : perf.evalTokens / ((perf.evalMs + perf.sampleMs) / 1000.0), - 'wallTokensPerSecond': wallMs <= 0 || perf == null + 'wallTokensPerSecond': runsDetail.isEmpty ? null - : perf.evalTokens / (wallMs / 1000.0), + : runsDetail.last['wallTokensPerSecond'], + 'targetWallTokensPerSecond': runsDetail.isEmpty + ? null + : runsDetail.last['targetWallTokensPerSecond'], 'runs': _runs, 'warmups': _warmups, 'measured': _summarizeRuns(runsDetail), diff --git a/example/chat_app/pubspec.lock b/example/chat_app/pubspec.lock index e6a5020f..29d4ff4a 100644 --- a/example/chat_app/pubspec.lock +++ b/example/chat_app/pubspec.lock @@ -349,7 +349,7 @@ packages: path: "../.." relative: true source: path - version: "0.7.1" + version: "0.7.2" logging: dependency: transitive description: diff --git a/hook/build.dart b/hook/build.dart index 51f014fe..8f3a29aa 100644 --- a/hook/build.dart +++ b/hook/build.dart @@ -11,7 +11,7 @@ import 'package:path/path.dart' as path; import 'package:llamadart/src/hook/native_bundle_config.dart'; -const _llamaCppTag = 'b9536'; +const _llamaCppTag = 'b9547'; const _nativeRepoSlug = 'leehack/llamadart-native'; const _packageName = 'llamadart'; diff --git a/lib/src/backends/litert_lm/litert_lm_backend_web.dart b/lib/src/backends/litert_lm/litert_lm_backend_web.dart index 869170d1..aae4370d 100644 --- a/lib/src/backends/litert_lm/litert_lm_backend_web.dart +++ b/lib/src/backends/litert_lm/litert_lm_backend_web.dart @@ -961,6 +961,9 @@ class LiteRtLmBackend if (params.speculativeDecoding) { unsupported.add('speculativeDecoding'); } + if (params.speculativeDecodingConfig != null) { + unsupported.add('speculativeDecodingConfig'); + } if (params.streamBatchTokenThreshold != defaults.streamBatchTokenThreshold) { unsupported.add('streamBatchTokenThreshold'); diff --git a/lib/src/backends/litert_lm/litert_lm_service.dart b/lib/src/backends/litert_lm/litert_lm_service.dart index 332e20a7..c30cd07c 100644 --- a/lib/src/backends/litert_lm/litert_lm_service.dart +++ b/lib/src/backends/litert_lm/litert_lm_service.dart @@ -522,7 +522,7 @@ class LiteRtLmService { ) { return _ensureClientForRuntime( outputTokens: params.maxTokens, - speculativeDecoding: params.speculativeDecoding, + speculativeDecoding: params.isSpeculativeDecodingEnabled, ); } @@ -843,6 +843,7 @@ class LiteRtLmService { if (params.grammarRoot != defaults.grammarRoot) { unsupported.add('grammarRoot'); } + _addUnsupportedSpeculativeDecodingOptions(params, unsupported); if (unsupported.isEmpty) { return; @@ -851,10 +852,30 @@ class LiteRtLmService { 'LiteRtLmBackend does not support llama.cpp-specific GenerationParams: ' '${unsupported.join(', ')}. Supported LiteRT-LM generation options are ' 'maxTokens, temp, topK, topP, seed, stopSequences, ' - 'speculativeDecoding, and native stream batching thresholds.', + 'speculativeDecoding, speculativeDecodingConfig, and native stream ' + 'batching thresholds.', ); } + void _addUnsupportedSpeculativeDecodingOptions( + GenerationParams params, + List unsupported, + ) { + final config = params.resolvedSpeculativeDecodingConfig; + if (config == null) { + return; + } + if (config.draftTokenMax != null) { + unsupported.add('speculativeDecodingConfig.draftTokenMax'); + } + if (config.draftTokenMin != null) { + unsupported.add('speculativeDecodingConfig.draftTokenMin'); + } + if (config.minProbability != null) { + unsupported.add('speculativeDecodingConfig.minProbability'); + } + } + int _defaultSamplerSeed() { return DateTime.now().microsecondsSinceEpoch & 0x7fffffff; } diff --git a/lib/src/backends/llama_cpp/bindings.dart b/lib/src/backends/llama_cpp/bindings.dart index dc70505d..8d37aefe 100644 --- a/lib/src/backends/llama_cpp/bindings.dart +++ b/lib/src/backends/llama_cpp/bindings.dart @@ -7589,12 +7589,121 @@ external int mtmd_helper_decode_image_chunk( @ffi.Native() external void llama_dart_set_log_level(int level); +@ffi.Native< + ffi.Pointer Function( + ffi.Pointer, + ffi.Pointer, + llama_context_params, + ffi.Int32, + ffi.Int32, + ffi.Float, + ffi.Bool, + ) +>() +external ffi.Pointer llama_dart_mtp_init( + ffi.Pointer model, + ffi.Pointer ctx_tgt, + llama_context_params ctx_params, + int draft_token_max, + int draft_token_min, + double min_probability, + bool backend_sampling, +); + +@ffi.Native)>() +external void llama_dart_mtp_free(ffi.Pointer mtp); + +@ffi.Native Function(ffi.Pointer)>() +external ffi.Pointer llama_dart_mtp_get_draft_context( + ffi.Pointer mtp, +); + +@ffi.Native< + ffi.Bool Function( + ffi.Pointer, + llama_seq_id, + ffi.Pointer, + ffi.Int32, + ) +>() +external bool llama_dart_mtp_begin( + ffi.Pointer mtp, + int seq_id, + ffi.Pointer prompt, + int prompt_count, +); + +@ffi.Native, llama_batch)>() +external bool llama_dart_mtp_process_batch( + ffi.Pointer mtp, + llama_batch batch, +); + +@ffi.Native< + ffi.Int32 Function( + ffi.Pointer, + llama_seq_id, + llama_pos, + llama_token, + ffi.Pointer, + ffi.Int32, + ffi.Int32, + ffi.Pointer, + ffi.Int32, + ) +>() +external int llama_dart_mtp_draft( + ffi.Pointer mtp, + int seq_id, + int n_past, + int id_last, + ffi.Pointer prompt, + int prompt_count, + int draft_token_max, + ffi.Pointer out_tokens, + int out_capacity, +); + +@ffi.Native< + ffi.Void Function(ffi.Pointer, llama_seq_id, ffi.Uint16) +>() +external void llama_dart_mtp_accept( + ffi.Pointer mtp, + int seq_id, + int accepted_count, +); + +@ffi.Native< + ffi.Int32 Function( + ffi.Pointer, + ffi.Pointer, + ffi.Pointer, + ffi.Int32, + ffi.Pointer, + ffi.Int32, + ffi.Pointer, + ffi.Int32, + ) +>() +external int llama_dart_sampler_sample_and_accept_n( + ffi.Pointer sampler, + ffi.Pointer ctx, + ffi.Pointer idxs, + int idx_count, + ffi.Pointer draft_tokens, + int draft_count, + ffi.Pointer out_tokens, + int out_capacity, +); + final class llama_vocab extends ffi.Opaque {} final class llama_model extends ffi.Opaque {} final class llama_context extends ffi.Opaque {} +final class llama_dart_mtp extends ffi.Opaque {} + typedef llama_token = ffi.Int32; typedef Dartllama_token = int; diff --git a/lib/src/backends/llama_cpp/llama_cpp_service.dart b/lib/src/backends/llama_cpp/llama_cpp_service.dart index 06a8ce23..72c94a33 100644 --- a/lib/src/backends/llama_cpp/llama_cpp_service.dart +++ b/lib/src/backends/llama_cpp/llama_cpp_service.dart @@ -16,6 +16,8 @@ import '../../core/models/inference/model_params.dart'; import 'load_param_helpers.dart'; import 'bindings.dart'; +const _llamadartWrapperAssetId = 'package:llamadart/llamadart_wrapper'; + typedef _GgmlBackendLoadNative = ggml_backend_reg_t Function(Pointer); typedef _GgmlBackendLoadDart = ggml_backend_reg_t Function(Pointer); typedef _GgmlBackendInitNative = ggml_backend_reg_t Function(); @@ -155,11 +157,196 @@ typedef _MtmdHelperEvalChunksDart = ); typedef _MtmdLogSetNative = Void Function(ggml_log_callback, Pointer); typedef _MtmdLogSetDart = void Function(ggml_log_callback, Pointer); +typedef _LlamaDartMtpInitNative = + Pointer Function( + Pointer, + Pointer, + llama_context_params, + Int32, + Int32, + Float, + Bool, + ); +typedef _LlamaDartMtpInitDart = + Pointer Function( + Pointer, + Pointer, + llama_context_params, + int, + int, + double, + bool, + ); +typedef _LlamaDartMtpFreeNative = Void Function(Pointer); +typedef _LlamaDartMtpFreeDart = void Function(Pointer); +typedef _LlamaDartMtpGetDraftContextNative = + Pointer Function(Pointer); +typedef _LlamaDartMtpGetDraftContextDart = + Pointer Function(Pointer); +typedef _LlamaDartMtpBeginNative = + Bool Function(Pointer, llama_seq_id, Pointer, Int32); +typedef _LlamaDartMtpBeginDart = + bool Function(Pointer, int, Pointer, int); +typedef _LlamaDartMtpProcessBatchNative = + Bool Function(Pointer, llama_batch); +typedef _LlamaDartMtpProcessBatchDart = + bool Function(Pointer, llama_batch); +typedef _LlamaDartMtpDraftNative = + Int32 Function( + Pointer, + llama_seq_id, + llama_pos, + llama_token, + Pointer, + Int32, + Int32, + Pointer, + Int32, + ); +typedef _LlamaDartMtpDraftDart = + int Function( + Pointer, + int, + int, + int, + Pointer, + int, + int, + Pointer, + int, + ); +typedef _LlamaDartMtpAcceptNative = + Void Function(Pointer, llama_seq_id, Uint16); +typedef _LlamaDartMtpAcceptDart = + void Function(Pointer, int, int); +typedef _LlamaDartSamplerSampleAndAcceptNNative = + Int32 Function( + Pointer, + Pointer, + Pointer, + Int32, + Pointer, + Int32, + Pointer, + Int32, + ); +typedef _LlamaDartSamplerSampleAndAcceptNDart = + int Function( + Pointer, + Pointer, + Pointer, + int, + Pointer, + int, + Pointer, + int, + ); + +@Native<_LlamaDartMtpInitNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_init', +) +external Pointer _llamadartWrapperMtpInit( + Pointer model, + Pointer context, + llama_context_params contextParams, + int draftTokenMax, + int draftTokenMin, + double minProbability, + bool backendSampling, +); + +@Native<_LlamaDartMtpFreeNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_free', +) +external void _llamadartWrapperMtpFree(Pointer mtp); + +@Native<_LlamaDartMtpGetDraftContextNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_get_draft_context', +) +external Pointer _llamadartWrapperMtpGetDraftContext( + Pointer mtp, +); + +@Native<_LlamaDartMtpBeginNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_begin', +) +external bool _llamadartWrapperMtpBegin( + Pointer mtp, + int seqId, + Pointer prompt, + int promptCount, +); + +@Native<_LlamaDartMtpProcessBatchNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_process_batch', +) +external bool _llamadartWrapperMtpProcessBatch( + Pointer mtp, + llama_batch batch, +); + +@Native<_LlamaDartMtpDraftNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_draft', +) +external int _llamadartWrapperMtpDraft( + Pointer mtp, + int seqId, + int nPast, + int idLast, + Pointer prompt, + int promptCount, + int draftTokenMax, + Pointer outTokens, + int outCapacity, +); + +@Native<_LlamaDartMtpAcceptNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_accept', +) +external void _llamadartWrapperMtpAccept( + Pointer mtp, + int seqId, + int acceptedCount, +); + +@Native<_LlamaDartSamplerSampleAndAcceptNNative>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_sampler_sample_and_accept_n', +) +external int _llamadartWrapperSamplerSampleAndAcceptN( + Pointer sampler, + Pointer context, + Pointer indexes, + int indexCount, + Pointer draftTokens, + int draftCount, + Pointer outTokens, + int outCapacity, +); final RegExp _linuxLlamadartProcMapsPattern = RegExp( r'/libllamadart\.so(?:\.\d+)?$', ); +class _LlamaCppMtpConfig { + const _LlamaCppMtpConfig({ + required this.draftTokenMax, + required this.draftTokenMin, + required this.minProbability, + }); + + final int draftTokenMax; + final int draftTokenMin; + final double minProbability; +} + /// Service responsible for managing Llama.cpp models and contexts. /// /// This service handles the direct interaction with the native Llama.cpp library, @@ -177,6 +364,10 @@ class LlamaCppService { 'LLAMADART_ANDROID_VULKAN_ALLOW_FLASH_ATTN', defaultValue: false, ); + static const bool _androidVulkanAllowMtp = bool.fromEnvironment( + 'LLAMADART_ANDROID_VULKAN_ALLOW_MTP', + defaultValue: false, + ); static const int _maxStartupDiagnostics = 32; static const Map _androidCpuVariantPriority = { @@ -232,6 +423,8 @@ class LlamaCppService { bool _mtmdFallbackLookupAttempted = false; bool _mtmdPrimarySymbolsUnavailable = false; _MtmdApi? _mtmdFallbackApi; + bool _mtpApiLookupAttempted = false; + _MtpApi? _mtpApi; final List _startupDiagnostics = []; // --- Internal State --- @@ -340,6 +533,24 @@ class LlamaCppService { normalized.contains('qwen_qwen3.5-4b'); } + /// Returns whether Android Vulkan MTP should be rejected before generation. + /// + /// Upstream llama.cpp `draft-mtp` backend sampling currently can abort Android + /// Vulkan processes with `vk::DeviceLostError`. Keep the public feature usable + /// on CPU and other backends, but fail fast for this backend combination unless + /// explicitly enabled for debugging/benchmarking. + static bool shouldRejectAndroidVulkanMtp( + String? backendName, { + int? resolvedGpuLayers, + bool isAndroid = false, + bool allowMtp = false, + }) { + if (!isAndroid || allowMtp || (resolvedGpuLayers ?? 0) <= 0) { + return false; + } + return (backendName ?? '').toLowerCase().contains('vulkan'); + } + /// Resolves effective context batch parameters. /// /// Uses the shared non-FFI helper so native and WebGPU batch semantics stay @@ -2044,6 +2255,75 @@ class LlamaCppService { } } + _MtpApi _resolveMtpApi() { + final cached = _mtpApi; + if (cached != null) { + return cached; + } + + if (_mtpApiLookupAttempted) { + throw UnsupportedError(_mtpUnavailableMessage()); + } + _mtpApiLookupAttempted = true; + + if (!Platform.isWindows) { + try { + final direct = _MtpApi.direct(); + _mtpApi = direct; + return direct; + } catch (_) {} + } + + if (Platform.isWindows) { + try { + final asset = _MtpApi.windowsAsset(); + _mtpApi = asset; + return asset; + } catch (_) {} + } + + for (final candidate in _llamadartWrapperLibraryCandidates()) { + try { + final library = DynamicLibrary.open(candidate); + final api = _MtpApi.tryLoad(library); + if (api != null) { + _mtpApi = api; + return api; + } + } catch (_) { + continue; + } + } + + throw UnsupportedError(_mtpUnavailableMessage()); + } + + String _mtpUnavailableMessage() { + return 'llama.cpp MTP speculative decoding is unavailable in this native ' + 'runtime bundle (missing llama_dart_mtp_* wrapper symbols).'; + } + + List _llamadartWrapperLibraryCandidates() { + final candidates = [..._llamadartAssetUriCandidates()]; + final fileNameCandidates = _llamadartLibraryCandidateFileNames(); + final pattern = _llamadartLibraryPattern(); + for (final directoryPath in _llamadartFallbackLookupDirectories()) { + for (final fileName in fileNameCandidates) { + candidates.add(path.join(directoryPath, fileName)); + } + for (final fileName in _matchingLibraryNames(directoryPath, pattern)) { + candidates.add(path.join(directoryPath, fileName)); + } + } + candidates.addAll(fileNameCandidates); + + final seen = {}; + return [ + for (final candidate in candidates) + if (seen.add(candidate)) candidate, + ]; + } + List _llamadartAssetUriCandidates() { // Prefer asset-URI resolution so Windows split bundles can reliably resolve // the wrapper helper library without relying on process cwd/search paths. @@ -2066,6 +2346,7 @@ class LlamaCppService { final executableDir = path.dirname(Platform.resolvedExecutable); directories.add(executableDir); directories.add(Directory.current.path); + directories.add(path.join(Directory.current.path, '.dart_tool', 'lib')); if (Platform.isIOS) { directories.add(path.normalize(path.join(executableDir, 'Frameworks'))); @@ -2620,6 +2901,7 @@ class LlamaCppService { ctxParams.n_batch = resolvedBatchSizes.batchSize; ctxParams.n_ubatch = resolvedBatchSizes.microBatchSize; ctxParams.n_seq_max = resolvedMaxParallelSequences; + ctxParams.n_rs_seq = params.speculativeRollbackTokenMax; ctxParams.n_threads = params.numberOfThreads; ctxParams.n_threads_batch = params.numberOfThreadsBatch; if (resolvedMaxParallelSequences > 1) { @@ -2709,6 +2991,67 @@ class LlamaCppService { _contexts.remove(handle)?.dispose(); } + _LlamaCppMtpConfig? _resolveLlamaCppMtpConfig( + GenerationParams params, { + required bool hasMediaParts, + }) { + final speculativeConfig = params.resolvedSpeculativeDecodingConfig; + if (speculativeConfig == null) { + return null; + } + + if (hasMediaParts) { + throw UnsupportedError( + 'llama.cpp MTP speculative decoding currently supports text-only ' + 'generation in llamadart.', + ); + } + if (params.grammar != null) { + throw UnsupportedError( + 'llama.cpp MTP speculative decoding does not yet support grammar ' + 'sampling in llamadart.', + ); + } + + switch (speculativeConfig.strategy) { + case SpeculativeDecodingStrategy.backendDefault: + case SpeculativeDecodingStrategy.mtp: + break; + } + + final draftTokenMax = speculativeConfig.draftTokenMax ?? 1; + final draftTokenMin = speculativeConfig.draftTokenMin ?? 0; + final minProbability = speculativeConfig.minProbability ?? 0.0; + + if (draftTokenMax <= 0) { + throw RangeError.value( + draftTokenMax, + 'draftTokenMax', + 'must be greater than zero for llama.cpp MTP', + ); + } + if (draftTokenMin < 0 || draftTokenMin > draftTokenMax) { + throw RangeError.value( + draftTokenMin, + 'draftTokenMin', + 'must be between zero and draftTokenMax for llama.cpp MTP', + ); + } + if (minProbability < 0.0 || minProbability > 1.0) { + throw RangeError.value( + minProbability, + 'minProbability', + 'must be between 0.0 and 1.0 for llama.cpp MTP', + ); + } + + return _LlamaCppMtpConfig( + draftTokenMax: draftTokenMax, + draftTokenMin: draftTokenMin, + minProbability: minProbability, + ); + } + /// Generates text based on the given [prompt] and [params]. /// /// Returns a [Stream] of token bytes. @@ -2720,14 +3063,6 @@ class LlamaCppService { int cancelTokenAddress, { List? parts, }) async* { - if (params.speculativeDecoding) { - throw UnsupportedError( - 'llama.cpp speculative decoding is not exposed by llamadart yet. ' - 'Use the LiteRT-LM native backend or track llama.cpp support in ' - 'issues #168/#190.', - ); - } - var ctx = _contexts[contextHandle]; if (ctx == null) throw Exception("Invalid context handle"); _generatingContexts.update( @@ -2742,6 +3077,8 @@ class LlamaCppService { Pointer rootPtr = nullptr; _LazyGrammarConfig? lazyGrammarConfig; Pointer sampler = nullptr; + Pointer mtpSession = nullptr; + _MtpApi? mtpApi; try { final modelHandle = _contextToModel[contextHandle]!; @@ -2751,13 +3088,33 @@ class LlamaCppService { final hasMediaParts = parts?.any((p) => p is LlamaImageContent || p is LlamaAudioContent) ?? false; + final mtpConfig = _resolveLlamaCppMtpConfig( + params, + hasMediaParts: hasMediaParts, + ); + if (mtpConfig != null && + shouldRejectAndroidVulkanMtp( + _modelBackendNames[modelHandle], + resolvedGpuLayers: _modelResolvedGpuLayers[modelHandle], + isAndroid: Platform.isAndroid, + allowMtp: _androidVulkanAllowMtp, + )) { + throw UnsupportedError( + 'llama.cpp MTP speculative decoding is disabled for Android Vulkan ' + 'because the upstream draft-mtp backend-sampling path can abort with ' + 'vk::DeviceLostError. Use the CPU backend, disable MTP, or rebuild ' + 'with LLAMADART_ANDROID_VULKAN_ALLOW_MTP=true for debugging.', + ); + } // 1. Reset Context ctx = _resetContext( contextHandle, ctx, - clearMemory: hasMediaParts || !params.reusePromptPrefix, + clearMemory: + hasMediaParts || mtpConfig != null || !params.reusePromptPrefix, ); + ctx.resetLastPerf(); llama_perf_context_reset(ctx.pointer); final existingSampler = _samplers[contextHandle]; if (existingSampler != null) { @@ -2770,6 +3127,29 @@ class LlamaCppService { tokensPtr = malloc(nCtx); pieceBuf = malloc(256); + if (mtpConfig != null) { + mtpApi = _resolveMtpApi(); + mtpSession = mtpApi.init( + model.pointer, + ctx.pointer, + modelParams, + mtpConfig.draftTokenMax, + mtpConfig.draftTokenMin, + mtpConfig.minProbability, + true, + ); + if (mtpSession == nullptr) { + throw UnsupportedError( + 'llama.cpp MTP speculative decoding is not available for this ' + 'model/context. Use an MTP GGUF model, a native libllamadart build ' + 'that includes llama-common, and set ' + 'ModelParams.speculativeRollbackTokenMax >= draftTokenMax when the ' + 'target architecture needs bounded rollback snapshots.', + ); + } + llama_perf_context_reset(ctx.pointer); + } + if (params.grammar != null) { grammarPtr = params.grammar!.toNativeUtf8(); rootPtr = params.grammarRoot.toNativeUtf8(); @@ -2791,7 +3171,10 @@ class LlamaCppService { tokensPtr, nCtx, modelParams, - allowTextPromptReuse: !hasMediaParts && params.reusePromptPrefix, + allowTextPromptReuse: + mtpConfig == null && !hasMediaParts && params.reusePromptPrefix, + mtpSession: mtpSession, + mtpApi: mtpApi, ); promptEvalStopwatch.stop(); ctx.lastPerfPromptEvalMs = @@ -2799,6 +3182,10 @@ class LlamaCppService { ctx.lastPerfPromptEvalTokens = initialTokens; _ensureLogitsAvailableAfterPromptEval(ctx.pointer); + if (mtpSession != nullptr && + !mtpApi!.begin(mtpSession, 0, tokensPtr, initialTokens)) { + throw Exception("Failed to initialize llama.cpp MTP prompt state"); + } // 4. Initialize and Run Sampler Loop sampler = _initializeSampler( @@ -2820,21 +3207,42 @@ class LlamaCppService { params.preservedTokens, ); - yield* _runInferenceLoop( - ctx, - batch, - vocab, - sampler, - params, - initialTokens, - nCtx, - cancelTokenAddress, - pieceBuf, - grammarPtr, - preservedTokenIds, - effectiveStopSequences, - ); + if (mtpSession != nullptr && mtpConfig != null) { + yield* _runMtpInferenceLoop( + ctx, + batch, + vocab, + sampler, + params, + mtpConfig, + initialTokens, + nCtx, + cancelTokenAddress, + pieceBuf, + preservedTokenIds, + effectiveStopSequences, + mtpSession, + mtpApi!, + tokensPtr, + ); + } else { + yield* _runInferenceLoop( + ctx, + batch, + vocab, + sampler, + params, + initialTokens, + nCtx, + cancelTokenAddress, + pieceBuf, + grammarPtr, + preservedTokenIds, + effectiveStopSequences, + ); + } } finally { + if (mtpSession != nullptr) mtpApi?.free(mtpSession); if (sampler != nullptr) llama_sampler_free(sampler); final remaining = (_generatingContexts[contextHandle] ?? 1) - 1; if (remaining <= 0) { @@ -3245,6 +3653,8 @@ class LlamaCppService { int nCtx, llama_context_params modelParams, { required bool allowTextPromptReuse, + required Pointer mtpSession, + required _MtpApi? mtpApi, }) { final mediaParts = parts @@ -3273,6 +3683,8 @@ class LlamaCppService { ctx, maxBatchTokens: modelParams.n_batch, allowPromptReuse: allowTextPromptReuse, + mtpSession: mtpSession, + mtpApi: mtpApi, ); } } @@ -3511,6 +3923,8 @@ class LlamaCppService { _LlamaContextWrapper ctx, { required int maxBatchTokens, required bool allowPromptReuse, + required Pointer mtpSession, + required _MtpApi? mtpApi, }) { final promptPtr = prompt.toNativeUtf8(); final shouldAddSpecial = !_promptStartsWithBosToken(vocab, prompt); @@ -3536,6 +3950,9 @@ class LlamaCppService { ctx, nTokens, maxBatchTokens: maxBatchTokens, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, + mtpApi: mtpApi, ); } @@ -3547,6 +3964,9 @@ class LlamaCppService { ctx, nTokens, maxBatchTokens: maxBatchTokens, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, + mtpApi: mtpApi, ); } @@ -3571,6 +3991,9 @@ class LlamaCppService { nTokens, maxBatchTokens: maxBatchTokens, existingCachedTokens: canReuseCachedCopy ? cachedTokens : null, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, + mtpApi: mtpApi, ); } @@ -3582,6 +4005,9 @@ class LlamaCppService { ctx, nTokens, maxBatchTokens: maxBatchTokens, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, + mtpApi: mtpApi, ); } @@ -3597,6 +4023,8 @@ class LlamaCppService { ctx, nTokens, maxBatchTokens: maxBatchTokens, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, ); } @@ -3608,6 +4036,9 @@ class LlamaCppService { startTokenIndex: decodeStart, tokenCount: suffixTokenCount, maxBatchTokens: maxBatchTokens, + outputAllLogits: mtpSession != nullptr, + mtpSession: mtpSession, + mtpApi: mtpApi, ); ctx.cachedPromptTokens = exactStateLoadMatch @@ -3625,6 +4056,9 @@ class LlamaCppService { int nTokens, { required int maxBatchTokens, List? existingCachedTokens, + bool outputAllLogits = false, + Pointer? mtpSession, + _MtpApi? mtpApi, }) { _clearContextMemory(ctx.pointer); _decodePromptSegment( @@ -3634,6 +4068,9 @@ class LlamaCppService { startTokenIndex: 0, tokenCount: nTokens, maxBatchTokens: maxBatchTokens, + outputAllLogits: outputAllLogits, + mtpSession: mtpSession, + mtpApi: mtpApi, ); ctx.cachedPromptTokens = existingCachedTokens ?? _copyPromptTokens(tokensPtr, nTokens); @@ -3655,6 +4092,9 @@ class LlamaCppService { required int startTokenIndex, required int tokenCount, required int maxBatchTokens, + bool outputAllLogits = false, + Pointer? mtpSession, + _MtpApi? mtpApi, }) { if (tokenCount <= 0) { return; @@ -3679,12 +4119,17 @@ class LlamaCppService { batch.n_seq_id[i] = 1; batch.seq_id[i][0] = 0; final isLastTokenInPrompt = decoded + i == tokenCount - 1; - batch.logits[i] = isLastTokenInPrompt ? 1 : 0; + batch.logits[i] = outputAllLogits || isLastTokenInPrompt ? 1 : 0; } if (llama_decode(ctx.pointer, batch) != 0) { throw Exception("Initial decode failed"); } + if (mtpSession != null && + mtpSession != nullptr && + !mtpApi!.processBatch(mtpSession, batch)) { + throw Exception("MTP prompt decode processing failed"); + } decoded += chunkTokenCount; } @@ -3853,6 +4298,276 @@ class LlamaCppService { ctx.lastPerfSampleCount = generatedTokens; } + Stream> _runMtpInferenceLoop( + _LlamaContextWrapper ctx, + llama_batch batch, + Pointer vocab, + Pointer sampler, + GenerationParams params, + _LlamaCppMtpConfig mtpConfig, + int startPos, + int nCtx, + int cancelTokenAddress, + Pointer pieceBuf, + Set preservedTokenIds, + List stopSequences, + Pointer mtpSession, + _MtpApi mtpApi, + Pointer tokensPtr, + ) async* { + final cancelToken = Pointer.fromAddress(cancelTokenAddress); + final draftCapacity = mtpConfig.draftTokenMax; + final draftPtr = malloc(draftCapacity); + final idxPtr = malloc(draftCapacity + 1); + final acceptedPtr = malloc(draftCapacity + 1); + + int currentPos = startPos; + int? pendingSampledToken; + final accumulatedBytes = []; + final evalStopwatch = Stopwatch()..start(); + var sampleMicros = 0; + var evalMicros = 0; + var generatedTokens = 0; + var shouldStop = false; + + try { + while (!shouldStop && generatedTokens < params.maxTokens) { + if (cancelToken.value == 1) break; + if (currentPos >= nCtx) break; + + int selectedToken; + if (pendingSampledToken != null) { + selectedToken = pendingSampledToken; + pendingSampledToken = null; + } else { + final sampleTick = Stopwatch()..start(); + selectedToken = llama_sampler_sample(sampler, ctx.pointer, -1); + sampleTick.stop(); + sampleMicros += sampleTick.elapsedMicroseconds; + if (llama_vocab_is_eog(vocab, selectedToken)) break; + + final pieceTick = Stopwatch()..start(); + final n = llama_token_to_piece( + vocab, + selectedToken, + pieceBuf.cast(), + 256, + 0, + preservedTokenIds.contains(selectedToken), + ); + pieceTick.stop(); + sampleMicros += pieceTick.elapsedMicroseconds; + generatedTokens++; + + if (n > 0) { + final bytes = pieceBuf.asTypedList(n).toList(); + yield bytes; + if (stopSequences.isNotEmpty) { + accumulatedBytes.addAll(bytes); + if (accumulatedBytes.length > 64) { + accumulatedBytes.removeRange(0, accumulatedBytes.length - 64); + } + final text = utf8.decode(accumulatedBytes, allowMalformed: true); + if (stopSequences.any((s) => text.endsWith(s))) { + shouldStop = true; + } + } + } + + if (shouldStop || generatedTokens >= params.maxTokens) { + break; + } + } + + final remainingToGenerate = params.maxTokens - generatedTokens; + final batchCapacity = math.max(1, llama_n_batch(ctx.pointer)); + final rollbackCapacity = llama_n_rs_seq(ctx.pointer); + final contextDraftCapacity = rollbackCapacity > 0 + ? math.min(nCtx - currentPos - 2, rollbackCapacity) + : nCtx - currentPos - 2; + final draftLimit = math.min( + mtpConfig.draftTokenMax, + math.min( + math.min(remainingToGenerate - 1, contextDraftCapacity), + batchCapacity - 1, + ), + ); + + var draftCount = 0; + if (draftLimit > 0) { + final draftTick = Stopwatch()..start(); + draftCount = mtpApi.draft( + mtpSession, + 0, + currentPos, + selectedToken, + tokensPtr, + currentPos, + draftLimit, + draftPtr, + draftCapacity, + ); + draftTick.stop(); + sampleMicros += draftTick.elapsedMicroseconds; + if (draftCount < 0) { + throw Exception("llama.cpp MTP draft failed"); + } + } + + if (draftCount <= 0) { + batch.n_tokens = 1; + batch.token[0] = selectedToken; + batch.pos[0] = currentPos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + + final evalTick = Stopwatch()..start(); + final decodeStatus = llama_decode(ctx.pointer, batch); + if (decodeStatus == 0 && !mtpApi.processBatch(mtpSession, batch)) { + throw Exception("MTP decode processing failed"); + } + evalTick.stop(); + evalMicros += evalTick.elapsedMicroseconds; + if (decodeStatus != 0) break; + + tokensPtr[currentPos] = selectedToken; + currentPos++; + continue; + } + + final draftContext = mtpApi.getDraftContext(mtpSession); + final draftMemory = draftContext == nullptr + ? nullptr + : llama_get_memory(draftContext); + if (draftMemory == nullptr || + !llama_memory_seq_rm(draftMemory, 0, currentPos, -1)) { + throw UnsupportedError( + 'llama.cpp MTP draft rollback failed for this context.', + ); + } + + final batchTokens = draftCount + 1; + batch.n_tokens = batchTokens; + batch.token[0] = selectedToken; + batch.pos[0] = currentPos; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = 0; + batch.logits[0] = 1; + for (int i = 0; i < draftCount; i++) { + final batchIndex = i + 1; + batch.token[batchIndex] = draftPtr[i]; + batch.pos[batchIndex] = currentPos + batchIndex; + batch.n_seq_id[batchIndex] = 1; + batch.seq_id[batchIndex][0] = 0; + batch.logits[batchIndex] = 1; + } + + final evalTick = Stopwatch()..start(); + final decodeStatus = llama_decode(ctx.pointer, batch); + if (decodeStatus == 0 && !mtpApi.processBatch(mtpSession, batch)) { + throw Exception("MTP decode processing failed"); + } + evalTick.stop(); + evalMicros += evalTick.elapsedMicroseconds; + if (decodeStatus != 0) break; + + for (int i = 0; i < batchTokens; i++) { + idxPtr[i] = i; + } + + final verifyTick = Stopwatch()..start(); + final acceptedCount = mtpApi.sampleAndAcceptN( + sampler, + ctx.pointer, + idxPtr, + batchTokens, + draftPtr, + draftCount, + acceptedPtr, + batchTokens, + ); + verifyTick.stop(); + sampleMicros += verifyTick.elapsedMicroseconds; + if (acceptedCount <= 0) { + throw Exception("llama.cpp MTP draft verification failed"); + } + + final acceptedDraftCount = acceptedCount - 1; + mtpApi.accept(mtpSession, 0, acceptedDraftCount); + + final keepUntil = currentPos + 1 + acceptedDraftCount; + final targetMemory = llama_get_memory(ctx.pointer); + if (targetMemory == nullptr || + !llama_memory_seq_rm(targetMemory, 0, keepUntil, -1) || + !llama_memory_seq_rm(draftMemory, 0, keepUntil, -1)) { + throw UnsupportedError( + 'llama.cpp MTP target rollback failed for this context.', + ); + } + + tokensPtr[currentPos] = selectedToken; + for (int i = 0; i < acceptedDraftCount; i++) { + tokensPtr[currentPos + 1 + i] = acceptedPtr[i]; + } + currentPos = keepUntil; + + for (int i = 0; i < acceptedCount; i++) { + final token = acceptedPtr[i]; + if (llama_vocab_is_eog(vocab, token)) { + shouldStop = true; + break; + } + + final pieceTick = Stopwatch()..start(); + final n = llama_token_to_piece( + vocab, + token, + pieceBuf.cast(), + 256, + 0, + preservedTokenIds.contains(token), + ); + pieceTick.stop(); + sampleMicros += pieceTick.elapsedMicroseconds; + generatedTokens++; + + if (n > 0) { + final bytes = pieceBuf.asTypedList(n).toList(); + yield bytes; + if (stopSequences.isNotEmpty) { + accumulatedBytes.addAll(bytes); + if (accumulatedBytes.length > 64) { + accumulatedBytes.removeRange(0, accumulatedBytes.length - 64); + } + final text = utf8.decode(accumulatedBytes, allowMalformed: true); + if (stopSequences.any((s) => text.endsWith(s))) { + shouldStop = true; + } + } + } + + if (shouldStop || generatedTokens >= params.maxTokens) { + break; + } + } + + if (!shouldStop && generatedTokens < params.maxTokens) { + pendingSampledToken = acceptedPtr[acceptedCount - 1]; + } + } + } finally { + malloc.free(draftPtr); + malloc.free(idxPtr); + malloc.free(acceptedPtr); + evalStopwatch.stop(); + ctx.lastPerfEvalMs = evalMicros / 1000.0; + ctx.lastPerfSampleMs = sampleMicros / 1000.0; + ctx.lastPerfEvalTokens = generatedTokens; + ctx.lastPerfSampleCount = generatedTokens; + } + } + _LazyGrammarConfig? _buildLazyGrammarConfig(GenerationParams params) { final triggerPatterns = []; final triggerTokens = []; @@ -4881,20 +5596,22 @@ class LlamaCppService { final sampler = _samplers[contextHandle]; final samplerPerf = sampler != null ? llama_perf_sampler(sampler) : null; - final promptEvalMs = perf.t_p_eval_ms > 0 - ? perf.t_p_eval_ms - : ctx.lastPerfPromptEvalMs; - final evalMs = perf.t_eval_ms > 0 ? perf.t_eval_ms : ctx.lastPerfEvalMs; - final sampleMs = (samplerPerf?.t_sample_ms ?? 0) > 0 - ? samplerPerf!.t_sample_ms - : ctx.lastPerfSampleMs; + final promptEvalMs = ctx.lastPerfPromptEvalMs > 0 + ? ctx.lastPerfPromptEvalMs + : perf.t_p_eval_ms; + final evalMs = ctx.lastPerfEvalMs > 0 ? ctx.lastPerfEvalMs : perf.t_eval_ms; + final sampleMs = ctx.lastPerfSampleMs > 0 + ? ctx.lastPerfSampleMs + : (samplerPerf?.t_sample_ms ?? 0); final promptEvalTokens = perf.n_p_eval > 0 ? perf.n_p_eval : ctx.lastPerfPromptEvalTokens; - final evalTokens = perf.n_eval > 0 ? perf.n_eval : ctx.lastPerfEvalTokens; - final sampleCount = (samplerPerf?.n_sample ?? 0) > 0 - ? samplerPerf!.n_sample - : ctx.lastPerfSampleCount; + final evalTokens = ctx.lastPerfEvalTokens > 0 + ? ctx.lastPerfEvalTokens + : perf.n_eval; + final sampleCount = ctx.lastPerfSampleCount > 0 + ? ctx.lastPerfSampleCount + : (samplerPerf?.n_sample ?? 0); return ( loadMs: perf.t_load_ms, @@ -4981,6 +5698,132 @@ class _LazyGrammarConfig { } } +class _MtpApi { + final _LlamaDartMtpInitDart init; + final _LlamaDartMtpFreeDart free; + final _LlamaDartMtpGetDraftContextDart getDraftContext; + final _LlamaDartMtpBeginDart begin; + final _LlamaDartMtpProcessBatchDart processBatch; + final _LlamaDartMtpDraftDart draft; + final _LlamaDartMtpAcceptDart accept; + final _LlamaDartSamplerSampleAndAcceptNDart sampleAndAcceptN; + + const _MtpApi({ + required this.init, + required this.free, + required this.getDraftContext, + required this.begin, + required this.processBatch, + required this.draft, + required this.accept, + required this.sampleAndAcceptN, + }); + + factory _MtpApi.direct() { + final api = _MtpApi( + init: llama_dart_mtp_init, + free: llama_dart_mtp_free, + getDraftContext: llama_dart_mtp_get_draft_context, + begin: llama_dart_mtp_begin, + processBatch: llama_dart_mtp_process_batch, + draft: llama_dart_mtp_draft, + accept: llama_dart_mtp_accept, + sampleAndAcceptN: llama_dart_sampler_sample_and_accept_n, + ); + api.probe(); + return api; + } + + factory _MtpApi.windowsAsset() { + _llamadartWrapperMtpFree(nullptr.cast()); + return _MtpApi( + init: _llamadartWrapperMtpInit, + free: _llamadartWrapperMtpFree, + getDraftContext: _llamadartWrapperMtpGetDraftContext, + begin: _llamadartWrapperMtpBegin, + processBatch: _llamadartWrapperMtpProcessBatch, + draft: _llamadartWrapperMtpDraft, + accept: _llamadartWrapperMtpAccept, + sampleAndAcceptN: _llamadartWrapperSamplerSampleAndAcceptN, + ); + } + + void probe() { + final nullMtp = nullptr.cast(); + final nullModel = nullptr.cast(); + final nullContext = nullptr.cast(); + final nullSampler = nullptr.cast(); + final nullTokenArray = nullptr.cast(); + final ctxParams = llama_context_default_params(); + final batch = llama_batch_init(1, 0, 1); + try { + init(nullModel, nullContext, ctxParams, 1, 0, 0.0, true); + free(nullMtp); + getDraftContext(nullMtp); + begin(nullMtp, 0, nullTokenArray, 0); + processBatch(nullMtp, batch); + draft(nullMtp, 0, 0, 0, nullTokenArray, 0, 1, nullTokenArray, 0); + accept(nullMtp, 0, 0); + sampleAndAcceptN( + nullSampler, + nullContext, + nullTokenArray, + 0, + nullTokenArray, + 0, + nullTokenArray, + 0, + ); + } finally { + llama_batch_free(batch); + } + } + + static _MtpApi? tryLoad(DynamicLibrary library) { + try { + return _MtpApi( + init: library + .lookupFunction<_LlamaDartMtpInitNative, _LlamaDartMtpInitDart>( + 'llama_dart_mtp_init', + ), + free: library + .lookupFunction<_LlamaDartMtpFreeNative, _LlamaDartMtpFreeDart>( + 'llama_dart_mtp_free', + ), + getDraftContext: library + .lookupFunction< + _LlamaDartMtpGetDraftContextNative, + _LlamaDartMtpGetDraftContextDart + >('llama_dart_mtp_get_draft_context'), + begin: library + .lookupFunction<_LlamaDartMtpBeginNative, _LlamaDartMtpBeginDart>( + 'llama_dart_mtp_begin', + ), + processBatch: library + .lookupFunction< + _LlamaDartMtpProcessBatchNative, + _LlamaDartMtpProcessBatchDart + >('llama_dart_mtp_process_batch'), + draft: library + .lookupFunction<_LlamaDartMtpDraftNative, _LlamaDartMtpDraftDart>( + 'llama_dart_mtp_draft', + ), + accept: library + .lookupFunction<_LlamaDartMtpAcceptNative, _LlamaDartMtpAcceptDart>( + 'llama_dart_mtp_accept', + ), + sampleAndAcceptN: library + .lookupFunction< + _LlamaDartSamplerSampleAndAcceptNNative, + _LlamaDartSamplerSampleAndAcceptNDart + >('llama_dart_sampler_sample_and_accept_n'), + ); + } catch (_) { + return null; + } + } +} + class _MtmdApi { final _MtmdDefaultMarkerDart defaultMarker; final _MtmdContextParamsDefaultDart contextParamsDefault; @@ -5142,6 +5985,15 @@ class _LlamaContextWrapper { int lastPerfEvalTokens = 0; int lastPerfSampleCount = 0; _LlamaContextWrapper(this.pointer, this._modelKeepAlive); + void resetLastPerf() { + lastPerfPromptEvalMs = 0; + lastPerfEvalMs = 0; + lastPerfSampleMs = 0; + lastPerfPromptEvalTokens = 0; + lastPerfEvalTokens = 0; + lastPerfSampleCount = 0; + } + void dispose() { // ignore: unused_local_variable final _ = _modelKeepAlive; diff --git a/lib/src/backends/webgpu/webgpu_backend.dart b/lib/src/backends/webgpu/webgpu_backend.dart index 4f437524..c1fce724 100644 --- a/lib/src/backends/webgpu/webgpu_backend.dart +++ b/lib/src/backends/webgpu/webgpu_backend.dart @@ -1649,7 +1649,7 @@ class WebGpuLlamaBackend GenerationParams params, { List? parts, }) { - if (params.speculativeDecoding) { + if (params.isSpeculativeDecodingEnabled) { throw UnsupportedError( 'WebGPU speculative decoding is not supported yet.', ); diff --git a/lib/src/core/models/inference/generation_params.dart b/lib/src/core/models/inference/generation_params.dart index 908be2b3..621fe1a3 100644 --- a/lib/src/core/models/inference/generation_params.dart +++ b/lib/src/core/models/inference/generation_params.dart @@ -32,6 +32,76 @@ class GenerationGrammarTrigger { }); } +/// Backend-neutral speculative decoding strategy. +enum SpeculativeDecodingStrategy { + /// Let the selected backend choose its native speculative decoding mode. + backendDefault, + + /// Multi-token prediction. + /// + /// llama.cpp maps this to its `draft-mtp` speculative path. LiteRT-LM native + /// currently maps this to its runtime speculative decoding switch. + mtp, +} + +/// Backend-neutral speculative decoding configuration. +/// +/// Backends map the strategy and knobs they support to their native runtime. +/// Unsupported strategy/option combinations must fail explicitly instead of +/// silently falling back. +class SpeculativeDecodingConfig { + /// Strategy to use when speculative decoding is enabled. + final SpeculativeDecodingStrategy strategy; + + /// Maximum number of draft tokens to propose per speculative step. + /// + /// `null` lets the backend choose its default. + final int? draftTokenMax; + + /// Minimum number of draft tokens required for speculative verification. + /// + /// `null` lets the backend choose its default. + final int? draftTokenMin; + + /// Minimum draft-token probability accepted by the backend. + /// + /// `null` lets the backend choose its default. + final double? minProbability; + + /// Creates a backend-neutral speculative decoding configuration. + const SpeculativeDecodingConfig({ + this.strategy = SpeculativeDecodingStrategy.backendDefault, + this.draftTokenMax, + this.draftTokenMin, + this.minProbability, + }) : assert(draftTokenMax == null || draftTokenMax >= 0), + assert(draftTokenMin == null || draftTokenMin >= 0), + assert( + minProbability == null || + (minProbability >= 0.0 && minProbability <= 1.0), + ); + + /// Enables the backend's default speculative decoding behavior. + const SpeculativeDecodingConfig.backendDefault() + : strategy = SpeculativeDecodingStrategy.backendDefault, + draftTokenMax = null, + draftTokenMin = null, + minProbability = null; + + /// Enables multi-token prediction speculative decoding. + const SpeculativeDecodingConfig.mtp({ + this.draftTokenMax, + this.draftTokenMin, + this.minProbability, + }) : strategy = SpeculativeDecodingStrategy.mtp, + assert(draftTokenMax == null || draftTokenMax >= 0), + assert(draftTokenMin == null || draftTokenMin >= 0), + assert( + minProbability == null || + (minProbability >= 0.0 && minProbability <= 1.0), + ); +} + /// Parameters controlling the token sampling and generation process. class GenerationParams { /// Default prompt prefix reuse behavior for native generation. @@ -92,11 +162,22 @@ class GenerationParams { /// Enables backend-native speculative decoding when supported. /// - /// Native LiteRT-LM currently honors this flag by forwarding it to the - /// runtime's speculative decoding setting. llama.cpp, WebGPU, and LiteRT-LM - /// web reject this option until their speculative paths are implemented. + /// Native LiteRT-LM forwards this flag to the runtime's speculative decoding + /// setting. llama.cpp maps it to the backend-default speculative strategy + /// when the active model/context supports that path. WebGPU and LiteRT-LM web + /// reject this option until their runtimes expose equivalent controls. + /// + /// Prefer [speculativeDecodingConfig] for new code that needs a specific + /// strategy or runtime-neutral options. final bool speculativeDecoding; + /// Strategy and knobs for backend-native speculative decoding. + /// + /// `null` disables speculative decoding unless [speculativeDecoding] is true. + /// When [speculativeDecoding] is true and this is null, backends should treat + /// the request as [SpeculativeDecodingStrategy.backendDefault]. + final SpeculativeDecodingConfig? speculativeDecodingConfig; + /// Reuses matching prompt prefixes from previous requests in the same native /// context to reduce prompt ingestion latency. /// @@ -133,11 +214,26 @@ class GenerationParams { this.preservedTokens = const [], this.grammarRoot = 'root', this.speculativeDecoding = false, + this.speculativeDecodingConfig, this.reusePromptPrefix = defaultReusePromptPrefix, this.streamBatchTokenThreshold = defaultStreamBatchTokenThreshold, this.streamBatchByteThreshold = defaultStreamBatchByteThreshold, }); + /// Whether speculative decoding is requested by either public API shape. + bool get isSpeculativeDecodingEnabled => + speculativeDecoding || speculativeDecodingConfig != null; + + /// Resolved speculative decoding configuration, if enabled. + /// + /// Legacy [speculativeDecoding] requests resolve to backend-default + /// speculative decoding. + SpeculativeDecodingConfig? get resolvedSpeculativeDecodingConfig => + speculativeDecodingConfig ?? + (speculativeDecoding + ? const SpeculativeDecodingConfig.backendDefault() + : null); + /// Creates a copy of this [GenerationParams] with updated fields. GenerationParams copyWith({ int? maxTokens, @@ -154,6 +250,8 @@ class GenerationParams { List? preservedTokens, String? grammarRoot, bool? speculativeDecoding, + SpeculativeDecodingConfig? speculativeDecodingConfig, + bool clearSpeculativeDecodingConfig = false, bool? reusePromptPrefix, int? streamBatchTokenThreshold, int? streamBatchByteThreshold, @@ -173,6 +271,9 @@ class GenerationParams { preservedTokens: preservedTokens ?? this.preservedTokens, grammarRoot: grammarRoot ?? this.grammarRoot, speculativeDecoding: speculativeDecoding ?? this.speculativeDecoding, + speculativeDecodingConfig: clearSpeculativeDecodingConfig + ? null + : (speculativeDecodingConfig ?? this.speculativeDecodingConfig), reusePromptPrefix: reusePromptPrefix ?? this.reusePromptPrefix, streamBatchTokenThreshold: streamBatchTokenThreshold ?? this.streamBatchTokenThreshold, diff --git a/lib/src/core/models/inference/model_params.dart b/lib/src/core/models/inference/model_params.dart index 8826e466..d11e88ce 100644 --- a/lib/src/core/models/inference/model_params.dart +++ b/lib/src/core/models/inference/model_params.dart @@ -193,6 +193,15 @@ class ModelParams { /// Set to 1 to preserve single-sequence behavior. final int maxParallelSequences; + /// llama.cpp recurrent-state rollback snapshots per sequence (`n_rs_seq`). + /// + /// Set this to at least the MTP draft token max when using llama.cpp MTP + /// speculative decoding with architectures that need bounded rollback + /// snapshots, such as Qwen3.5 MTP. The default `0` preserves legacy context + /// memory use. llama.cpp may clamp this to zero for unsupported + /// architectures. + final int speculativeRollbackTokenMax; + /// `llama_model_params.use_mmap`. Default `true`. final bool useMmap; @@ -258,6 +267,7 @@ class ModelParams { this.batchSize = 0, this.microBatchSize = 0, this.maxParallelSequences = 1, + this.speculativeRollbackTokenMax = 0, this.useMmap = true, this.useMlock = false, this.flashAttention = FlashAttention.auto, @@ -292,6 +302,13 @@ class ModelParams { 'must be non-empty when provided', ); } + if (speculativeRollbackTokenMax < 0) { + throw ArgumentError.value( + speculativeRollbackTokenMax, + 'speculativeRollbackTokenMax', + 'must be non-negative', + ); + } if ((cacheTypeK != KvCacheType.f16 || cacheTypeV != KvCacheType.f16) && flashAttention == FlashAttention.disabled) { throw ArgumentError( @@ -332,6 +349,7 @@ class ModelParams { int? batchSize, int? microBatchSize, int? maxParallelSequences, + int? speculativeRollbackTokenMax, bool? useMmap, bool? useMlock, FlashAttention? flashAttention, @@ -378,6 +396,8 @@ class ModelParams { batchSize: batchSize ?? this.batchSize, microBatchSize: microBatchSize ?? this.microBatchSize, maxParallelSequences: maxParallelSequences ?? this.maxParallelSequences, + speculativeRollbackTokenMax: + speculativeRollbackTokenMax ?? this.speculativeRollbackTokenMax, useMmap: useMmap ?? this.useMmap, useMlock: useMlock ?? this.useMlock, flashAttention: flashAttention ?? this.flashAttention, diff --git a/lib/src/hook/native_bundle_config.dart b/lib/src/hook/native_bundle_config.dart index 2e202525..22ad03ae 100644 --- a/lib/src/hook/native_bundle_config.dart +++ b/lib/src/hook/native_bundle_config.dart @@ -21,6 +21,7 @@ const List defaultNativeRuntimes = [nativeRuntimeLlamaCpp]; const Set _coreLibraries = { 'llamadart', 'llama', + 'llama-common', 'ggml', 'ggml-base', 'mtmd', diff --git a/test/integration/backends/llama_cpp/native_symbol_integration_test.dart b/test/integration/backends/llama_cpp/native_symbol_integration_test.dart index 2ebff8cf..12057ad0 100644 --- a/test/integration/backends/llama_cpp/native_symbol_integration_test.dart +++ b/test/integration/backends/llama_cpp/native_symbol_integration_test.dart @@ -1,13 +1,240 @@ @TestOn('vm') library; +import 'dart:convert'; +import 'dart:ffi' as ffi; import 'dart:io'; import 'package:test/test.dart'; import 'package:llamadart/src/backends/llama_cpp/bindings.dart'; +const _llamadartWrapperAssetId = 'package:llamadart/llamadart_wrapper'; + +const _mtpSymbols = [ + 'llama_dart_mtp_init', + 'llama_dart_mtp_free', + 'llama_dart_mtp_get_draft_context', + 'llama_dart_mtp_begin', + 'llama_dart_mtp_process_batch', + 'llama_dart_mtp_draft', + 'llama_dart_mtp_accept', + 'llama_dart_sampler_sample_and_accept_n', +]; + +@ffi.Native)>( + assetId: _llamadartWrapperAssetId, + symbol: 'llama_dart_mtp_free', +) +external void _windowsMtpFree(ffi.Pointer mtp); + +File _windowsMtpWrapperLibraryFile() { + final dartToolLibPath = [ + Directory.current.path, + '.dart_tool', + 'lib', + ].join(Platform.pathSeparator); + final dartToolLibDir = Directory(dartToolLibPath); + final candidates = [ + ?_nativeAssetFilePath(_llamadartWrapperAssetId), + ..._matchingWindowsLibraryPaths( + dartToolLibDir, + RegExp(r'^llamadart(?:[-_][^.\\/]+)*\.dll$'), + ), + [dartToolLibPath, 'llamadart_wrapper.dll'].join(Platform.pathSeparator), + [dartToolLibPath, 'llamadart.dll'].join(Platform.pathSeparator), + 'llamadart_wrapper.dll', + 'llamadart.dll', + ]; + + final tried = {}; + for (final candidate in candidates) { + if (!tried.add(candidate)) { + continue; + } + + final file = File(candidate); + if (file.existsSync()) { + return file; + } + } + + throw StateError( + 'Unable to find Windows llama.cpp MTP wrapper library. ' + 'Tried: ${tried.join(', ')}.', + ); +} + +bool _fileContainsAscii(File file, String text) { + final bytes = file.readAsBytesSync(); + final pattern = ascii.encode(text); + if (pattern.isEmpty || bytes.length < pattern.length) { + return false; + } + + for (var i = 0; i <= bytes.length - pattern.length; i++) { + var matched = true; + for (var j = 0; j < pattern.length; j++) { + if (bytes[i + j] != pattern[j]) { + matched = false; + break; + } + } + if (matched) { + return true; + } + } + return false; +} + +String? _nativeAssetFilePath(String assetId) { + final configFile = File('.dart_tool/native_assets.yaml'); + if (!configFile.existsSync()) { + return null; + } + + final source = configFile + .readAsLinesSync() + .where((line) => !line.trimLeft().startsWith('#')) + .join('\n'); + final Object? decoded; + try { + decoded = jsonDecode(source); + } on FormatException { + return null; + } + if (decoded is! Map) { + return null; + } + + final nativeAssets = decoded['native-assets']; + if (nativeAssets is! Map) { + return null; + } + + for (final platformAssets in nativeAssets.values) { + if (platformAssets is! Map) { + continue; + } + final entry = platformAssets[assetId]; + if (entry is List && entry.length >= 2 && entry[0] == 'absolute') { + final filePath = entry[1]; + if (filePath is String && filePath.isNotEmpty) { + return filePath; + } + } + } + + return null; +} + +List _matchingWindowsLibraryPaths(Directory directory, RegExp regex) { + try { + return directory + .listSync() + .whereType() + .map((file) => file.path) + .where((filePath) { + final separatorIndex = filePath.lastIndexOf(Platform.pathSeparator); + final name = separatorIndex == -1 + ? filePath + : filePath.substring(separatorIndex + 1); + return regex.hasMatch(name); + }) + .toList(growable: false); + } catch (_) { + return const []; + } +} + void main() { group('Native Symbol Availability', () { + test('Verify MTP symbols are declared in generated bindings', () { + final bindingsSource = File( + 'lib/src/backends/llama_cpp/bindings.dart', + ).readAsStringSync(); + + for (final symbol in _mtpSymbols) { + expect( + bindingsSource, + matches(RegExp(r'external\s+[\s\S]*?\b' + RegExp.escape(symbol))), + reason: symbol, + ); + } + }); + + test('Verify MTP wrapper symbols are resolvable', () { + if (Platform.isWindows) { + expect(() => llama_context_default_params(), returnsNormally); + expect( + () => _windowsMtpFree(ffi.nullptr.cast()), + returnsNormally, + ); + final wrapper = _windowsMtpWrapperLibraryFile(); + for (final symbol in _mtpSymbols) { + expect(_fileContainsAscii(wrapper, symbol), isTrue, reason: symbol); + } + return; + } + + final nullMtp = ffi.nullptr.cast(); + final nullModel = ffi.nullptr.cast(); + final nullContext = ffi.nullptr.cast(); + final nullSampler = ffi.nullptr.cast(); + final nullTokenArray = ffi.nullptr.cast(); + final ctxParams = llama_context_default_params(); + + expect( + llama_dart_mtp_init( + nullModel, + nullContext, + ctxParams, + 1, + 0, + 0.0, + true, + ).address, + 0, + ); + expect(() => llama_dart_mtp_free(nullMtp), returnsNormally); + expect(llama_dart_mtp_get_draft_context(nullMtp).address, 0); + expect(llama_dart_mtp_begin(nullMtp, 0, nullTokenArray, 0), isFalse); + expect( + llama_dart_mtp_draft( + nullMtp, + 0, + 0, + 0, + nullTokenArray, + 0, + 1, + nullTokenArray, + 0, + ), + -1, + ); + expect(() => llama_dart_mtp_accept(nullMtp, 0, 0), returnsNormally); + expect( + llama_dart_sampler_sample_and_accept_n( + nullSampler, + nullContext, + nullTokenArray, + 0, + nullTokenArray, + 0, + nullTokenArray, + 0, + ), + -1, + ); + + final batch = llama_batch_init(1, 0, 1); + try { + expect(llama_dart_mtp_process_batch(nullMtp, batch), isFalse); + } finally { + llama_batch_free(batch); + } + }); + test('Verify multimodal symbols are resolvable', () { // Some bundles export mtmd via the primary llama asset while others ship // it as a dedicated mtmd shared library loaded via runtime fallback. diff --git a/test/unit/backends/litert_lm/litert_lm_backend_web_test.dart b/test/unit/backends/litert_lm/litert_lm_backend_web_test.dart index 01021a37..5402d3c4 100644 --- a/test/unit/backends/litert_lm/litert_lm_backend_web_test.dart +++ b/test/unit/backends/litert_lm/litert_lm_backend_web_test.dart @@ -295,6 +295,41 @@ void main() { } }); + test('rejects speculative decoding config on LiteRT-LM web', () async { + _installFakeEngine(chunks: []); + + final backend = LiteRtLmBackend(); + try { + final modelHandle = await backend.modelLoadFromUrl( + 'https://example.com/model.litertlm', + const ModelParams(), + ); + final contextHandle = await backend.contextCreate( + modelHandle, + const ModelParams(), + ); + + await expectLater( + backend.generate( + contextHandle, + 'hello', + const GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp(), + ), + ), + emitsError( + isA().having( + (error) => error.message.toString(), + 'message', + contains('speculativeDecodingConfig'), + ), + ), + ); + } finally { + await backend.dispose(); + } + }); + test('rejects unsupported context-time model params', () async { _installFakeEngine(chunks: []); diff --git a/test/unit/backends/litert_lm/litert_lm_service_test.dart b/test/unit/backends/litert_lm/litert_lm_service_test.dart index eac59502..58f2729b 100644 --- a/test/unit/backends/litert_lm/litert_lm_service_test.dart +++ b/test/unit/backends/litert_lm/litert_lm_service_test.dart @@ -868,6 +868,39 @@ void main() { } }); + test('passes LiteRT-LM MTP config as speculative decoding', () async { + final fakeClient = _FakeLiteRtLmRuntimeClient(); + final service = LiteRtLmService(clientFactory: () => fakeClient); + + try { + final modelHandle = await service.loadModel( + modelFile.path, + const ModelParams(preferredBackend: GpuBackend.cpu), + ); + final contextHandle = service.createContext( + modelHandle, + const ModelParams(preferredBackend: GpuBackend.cpu), + ); + + final subscription = service + .generate( + contextHandle, + 'hello', + const GenerationParams( + maxTokens: 7, + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp(), + ), + ) + .listen((_) {}); + + await fakeClient.generateStarted.future; + expect(fakeClient.lastSpeculativeDecoding, isTrue); + unawaited(subscription.cancel()); + } finally { + service.dispose(); + } + }); + test('passes supported LiteRT-LM generation options to the client', () async { final fakeClient = _FakeLiteRtLmRuntimeClient(); final service = LiteRtLmService(clientFactory: () => fakeClient); @@ -1620,6 +1653,25 @@ void main() { ), ), ); + + await expectLater( + service.generate( + contextHandle, + 'hello', + const GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp( + draftTokenMax: 3, + ), + ), + ), + emitsError( + isA().having( + (error) => error.message.toString(), + 'message', + contains('speculativeDecodingConfig.draftTokenMax'), + ), + ), + ); } finally { service.dispose(); } diff --git a/test/unit/backends/llama_cpp/llama_cpp_service_test.dart b/test/unit/backends/llama_cpp/llama_cpp_service_test.dart index 2c76b68e..0098ce4e 100644 --- a/test/unit/backends/llama_cpp/llama_cpp_service_test.dart +++ b/test/unit/backends/llama_cpp/llama_cpp_service_test.dart @@ -32,6 +32,47 @@ void main() { }); }); + group('Android Vulkan MTP guard', () { + test('rejects Android Vulkan with GPU layers by default', () { + expect( + LlamaCppService.shouldRejectAndroidVulkanMtp( + 'Vulkan', + resolvedGpuLayers: 999, + isAndroid: true, + ), + isTrue, + ); + }); + + test('allows CPU and explicit debug override', () { + expect( + LlamaCppService.shouldRejectAndroidVulkanMtp( + 'CPU', + resolvedGpuLayers: 0, + isAndroid: true, + ), + isFalse, + ); + expect( + LlamaCppService.shouldRejectAndroidVulkanMtp( + 'Vulkan', + resolvedGpuLayers: 999, + isAndroid: true, + allowMtp: true, + ), + isFalse, + ); + expect( + LlamaCppService.shouldRejectAndroidVulkanMtp( + 'Vulkan', + resolvedGpuLayers: 999, + isAndroid: false, + ), + isFalse, + ); + }); + }); + group('loadModel preflight validation', () { late Directory tempDir; @@ -101,25 +142,41 @@ void main() { ); }); - test('generate rejects speculative decoding', () async { - expect( - service - .generate( - -1, - 'hello', - const GenerationParams(speculativeDecoding: true), - 0, - ) - .drain(), - throwsA( - isA().having( - (error) => error.message.toString(), - 'message', - contains('speculative decoding'), - ), - ), - ); - }); + test( + 'generate reports unknown context before speculative decoding', + () async { + expect( + service + .generate( + -1, + 'hello', + const GenerationParams(speculativeDecoding: true), + 0, + ) + .drain(), + throwsA(isA()), + ); + }, + ); + + test( + 'generate reports unknown context before speculative config', + () async { + expect( + service + .generate( + -1, + 'hello', + const GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp(), + ), + 0, + ) + .drain(), + throwsA(isA()), + ); + }, + ); test('embed and embedBatch throw for unknown context handle', () { expect(() => service.embed(-1, 'hello'), throwsA(isA())); diff --git a/test/unit/backends/webgpu/webgpu_backend_test.dart b/test/unit/backends/webgpu/webgpu_backend_test.dart index 27c74e87..52f2db7a 100644 --- a/test/unit/backends/webgpu/webgpu_backend_test.dart +++ b/test/unit/backends/webgpu/webgpu_backend_test.dart @@ -802,6 +802,25 @@ void main() { ); }); + test('rejects speculative decoding config', () { + expect( + () => backend.generate( + 1, + 'Hello', + const GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp(), + ), + ), + throwsA( + isA().having( + (error) => error.message.toString(), + 'message', + contains('speculative decoding'), + ), + ), + ); + }); + test( 'canceling generation subscription aborts active bridge completion', () async { diff --git a/test/unit/core/models/inference/generation_params_test.dart b/test/unit/core/models/inference/generation_params_test.dart index 53a6fbdd..a3670a8e 100644 --- a/test/unit/core/models/inference/generation_params_test.dart +++ b/test/unit/core/models/inference/generation_params_test.dart @@ -10,6 +10,9 @@ void main() { grammarRoot: 'main', grammarLazy: true, speculativeDecoding: true, + speculativeDecodingConfig: const SpeculativeDecodingConfig.mtp( + draftTokenMax: 3, + ), reusePromptPrefix: false, streamBatchTokenThreshold: 4, streamBatchByteThreshold: 256, @@ -26,6 +29,12 @@ void main() { expect(updated.grammarRoot, 'main'); expect(updated.grammarLazy, isTrue); expect(updated.speculativeDecoding, isTrue); + expect(updated.isSpeculativeDecodingEnabled, isTrue); + expect( + updated.resolvedSpeculativeDecodingConfig?.strategy, + SpeculativeDecodingStrategy.mtp, + ); + expect(updated.resolvedSpeculativeDecodingConfig?.draftTokenMax, 3); expect(updated.reusePromptPrefix, isFalse); expect(updated.streamBatchTokenThreshold, 4); expect(updated.streamBatchByteThreshold, 256); @@ -38,6 +47,9 @@ void main() { expect(params.minP, 0.0); expect(params.speculativeDecoding, isFalse); + expect(params.speculativeDecodingConfig, isNull); + expect(params.isSpeculativeDecodingEnabled, isFalse); + expect(params.resolvedSpeculativeDecodingConfig, isNull); }); test('GenerationParams defaults stream batching thresholds', () { @@ -47,4 +59,26 @@ void main() { expect(params.streamBatchTokenThreshold, 8); expect(params.streamBatchByteThreshold, 512); }); + + test('GenerationParams resolves legacy speculative decoding as default', () { + const params = GenerationParams(speculativeDecoding: true); + + expect(params.isSpeculativeDecodingEnabled, isTrue); + expect( + params.resolvedSpeculativeDecodingConfig?.strategy, + SpeculativeDecodingStrategy.backendDefault, + ); + }); + + test('GenerationParams copyWith can clear speculative decoding config', () { + const params = GenerationParams( + speculativeDecodingConfig: SpeculativeDecodingConfig.mtp( + draftTokenMax: 3, + ), + ); + final updated = params.copyWith(clearSpeculativeDecodingConfig: true); + + expect(updated.speculativeDecodingConfig, isNull); + expect(updated.isSpeculativeDecodingEnabled, isFalse); + }); } diff --git a/test/unit/core/models/inference/model_params_test.dart b/test/unit/core/models/inference/model_params_test.dart index 15182ede..2c13f0eb 100644 --- a/test/unit/core/models/inference/model_params_test.dart +++ b/test/unit/core/models/inference/model_params_test.dart @@ -24,6 +24,7 @@ void main() { expect(params.batchSize, 0); expect(params.microBatchSize, 0); expect(params.maxParallelSequences, 1); + expect(params.speculativeRollbackTokenMax, 0); expect(params.useMmap, isTrue); expect(params.useMlock, isFalse); expect(params.flashAttention, FlashAttention.auto); @@ -50,6 +51,7 @@ void main() { batchSize: 256, microBatchSize: 64, maxParallelSequences: 8, + speculativeRollbackTokenMax: 3, ); expect(updated.contextSize, 1024); @@ -65,6 +67,7 @@ void main() { expect(updated.batchSize, 256); expect(updated.microBatchSize, 64); expect(updated.maxParallelSequences, 8); + expect(updated.speculativeRollbackTokenMax, 3); }); test('ModelParams exposes load-time tuning knobs', () { @@ -130,6 +133,7 @@ void main() { batchSize: 512, microBatchSize: 128, maxParallelSequences: 4, + speculativeRollbackTokenMax: 3, useMmap: false, useMlock: true, flashAttention: FlashAttention.enabled, @@ -161,6 +165,7 @@ void main() { expect(updated.batchSize, 512); expect(updated.microBatchSize, 128); expect(updated.maxParallelSequences, 4); + expect(updated.speculativeRollbackTokenMax, 3); expect(updated.useMmap, isFalse); expect(updated.useMlock, isTrue); expect(updated.flashAttention, FlashAttention.enabled); @@ -171,6 +176,13 @@ void main() { expect(updated.ropeFrequencyScale, 0.5); }); + group('validate(): speculative rollback settings', () { + test('negative speculativeRollbackTokenMax throws ArgumentError', () { + const p = ModelParams(speculativeRollbackTokenMax: -1); + expect(p.validate, throwsArgumentError); + }); + }); + group('validate(): non-F16 KV requires flash attention', () { test('q8_0 K + flashAttention disabled throws ArgumentError', () { const p = ModelParams( diff --git a/test/unit/hook/build_hook_android_integration_test.dart b/test/unit/hook/build_hook_android_integration_test.dart index 9507c6be..5051c501 100644 --- a/test/unit/hook/build_hook_android_integration_test.dart +++ b/test/unit/hook/build_hook_android_integration_test.dart @@ -98,6 +98,7 @@ void main() { expect(emittedNames, contains('libllamadart.so')); expect(emittedNames, contains('libllama.so')); + expect(emittedNames, contains('libllama-common.so')); expect(emittedNames, contains('libggml.so')); expect(emittedNames, contains('libggml-base.so')); expect(emittedNames, contains('libggml-vulkan.so')); @@ -219,6 +220,7 @@ const List _androidCpuVariantLibraries = [ const List _androidArm64Libraries = [ 'libllamadart.so', 'libllama.so', + 'libllama-common.so', 'libggml.so', 'libggml-base.so', 'libggml-vulkan.so', diff --git a/test/unit/hook/build_hook_linux_integration_test.dart b/test/unit/hook/build_hook_linux_integration_test.dart index b4f9dc3f..e8fe81b1 100644 --- a/test/unit/hook/build_hook_linux_integration_test.dart +++ b/test/unit/hook/build_hook_linux_integration_test.dart @@ -47,6 +47,7 @@ void main() { await _writeBundleLibraries(bundleDir, const [ 'libllamadart.so', 'libllama.so', + 'libllama-common.so', 'libggml.so', 'libggml-base.so', 'libggml-cpu.so', @@ -93,6 +94,8 @@ void main() { expect(emittedNames, contains('libllamadart.so')); expect(emittedNames, contains('libllama.so')); expect(emittedNames, contains('libllama.so.0')); + expect(emittedNames, contains('libllama-common.so')); + expect(emittedNames, contains('libllama-common.so.0')); expect(emittedNames, contains('libggml.so')); expect(emittedNames, contains('libggml.so.0')); expect(emittedNames, contains('libggml-base.so')); diff --git a/test/unit/hook/native_bundle_config_test.dart b/test/unit/hook/native_bundle_config_test.dart index d5865b6c..7bdcfa6b 100644 --- a/test/unit/hook/native_bundle_config_test.dart +++ b/test/unit/hook/native_bundle_config_test.dart @@ -95,6 +95,14 @@ void main() { expect(descriptor.backend, isNull); }); + test('classifies llama common as a core runtime library', () { + final descriptor = describeNativeLibrary('/tmp/libllama-common.so'); + + expect(descriptor.canonicalName, 'llama-common'); + expect(descriptor.isCore, isTrue); + expect(descriptor.backend, isNull); + }); + test('normalizes Linux SONAME suffix for ggml base library', () { final descriptor = describeNativeLibrary('/tmp/libggml-base.so.1'); @@ -340,6 +348,7 @@ void main() { final libraries = [ describeNativeLibrary('/tmp/libllamadart.so'), describeNativeLibrary('/tmp/libllama.so'), + describeNativeLibrary('/tmp/libllama-common.so'), describeNativeLibrary('/tmp/libggml.so'), describeNativeLibrary('/tmp/libggml-base.so'), describeNativeLibrary('/tmp/libggml-cpu.so'), @@ -359,6 +368,7 @@ void main() { expect(selectedNames, contains('ggml-cpu')); expect(selectedNames, contains('ggml-vulkan')); expect(selectedNames, isNot(contains('ggml-opencl'))); + expect(selectedNames, contains('llama-common')); }); test('uses requested backend when available', () { @@ -556,6 +566,7 @@ void main() { final libraries = [ describeNativeLibrary('/tmp/libllamadart.so'), describeNativeLibrary('/tmp/libllama.so'), + describeNativeLibrary('/tmp/libllama-common.so'), describeNativeLibrary('/tmp/libggml.so'), describeNativeLibrary('/tmp/libggml-base.so'), describeNativeLibrary('/tmp/libggml-vulkan.so'),