# NOTE(review): this span was a whitespace-collapsed unified diff of
# tools/build_upstream_runtime.py. The four definitions below are the complete
# additions from that diff, re-flowed as Python. Indentation inside the
# BUILD/C++ literals was collapsed in the diff and is reconstructed here in
# conventional Bazel / Google C++ style, and C++ template arguments (after
# std::optional, std::unique_ptr, std::make_shared, absl::flat_hash_map, ...)
# appear to have been stripped before review and are reproduced as found.
# Confirm both against the real upstream tree: markers only match exact text.
# The same diff also appends b"litert_lm_session_config_set_lora_file" and
# b"litert_lm_session_config_set_audio_lora_file" to REQUIRED_C_API_SYMBOLS and
# calls patch_upstream_lora_support(source_root) at the end of
# patch_upstream_build(); those partial hunks are not reproduced here.


def _replace_once(text: str, old: str, new: str, path: Path) -> str:
    """Replace the first occurrence of *old* with *new*, idempotently.

    Args:
        text: Current contents of the file being patched.
        old: Exact text expected in an unpatched file.
        new: Replacement text.
        path: File the text came from; used only in the error message.

    Returns:
        The patched text, or *text* unchanged when the patch is already
        present.

    Raises:
        RuntimeError: If neither *old* nor *new* can be found in *text*.
    """
    # Check the applied state first, like the insert helpers do. Otherwise a
    # replacement whose *new* contains *old* would be applied again on every
    # re-run.
    if new and new in text:
        return text
    if old in text:
        return text.replace(old, new, 1)
    if not new:
        # A pure deletion whose target is already gone counts as applied.
        return text
    raise RuntimeError(f"Could not find patch target in {path}: {old[:80]!r}")


def _insert_before_once(text: str, marker: str, insertion: str, path: Path) -> str:
    """Insert *insertion* immediately before the first *marker*, idempotently.

    Args:
        text: Current contents of the file being patched.
        marker: Exact text that locates the insertion point.
        insertion: Text to place directly in front of *marker*.
        path: File the text came from; used only in the error message.

    Returns:
        The patched text, or *text* unchanged when *insertion* already sits
        directly in front of a *marker*.

    Raises:
        RuntimeError: If *marker* does not occur in *text*.
    """
    if insertion + marker in text:
        return text
    if marker not in text:
        raise RuntimeError(f"Could not find insertion point in {path}: {marker[:80]!r}")
    return text.replace(marker, insertion + marker, 1)


def _insert_after_once(text: str, marker: str, insertion: str, path: Path) -> str:
    """Insert *insertion* immediately after the first *marker*, idempotently.

    Args:
        text: Current contents of the file being patched.
        marker: Exact text that locates the insertion point.
        insertion: Text to place directly behind *marker*.
        path: File the text came from; used only in the error message.

    Returns:
        The patched text, or *text* unchanged when *insertion* already sits
        directly behind a *marker*.

    Raises:
        RuntimeError: If *marker* does not occur in *text*.
    """
    if marker + insertion in text:
        return text
    if marker not in text:
        raise RuntimeError(f"Could not find insertion point in {path}: {marker[:80]!r}")
    return text.replace(marker, marker + insertion, 1)


def patch_upstream_lora_support(source_root: Path) -> None:
    """Patch an upstream source tree to expose LoRA adapters via the C API.

    Rewrites, in place and in this order: ``c/BUILD``, ``c/engine.h``,
    ``c/engine.cc``, ``runtime/executor/BUILD``,
    ``runtime/executor/llm_executor_base.h``,
    ``runtime/executor/llm_litert_compiled_model_executor.{h,cc}`` and
    ``runtime/framework/resource_management/resource_manager.{h,cc}``.
    The edits add ``litert_lm_session_config_set_lora_file`` and
    ``litert_lm_session_config_set_audio_lora_file`` to the C API, add
    ``LoadLoRA``/``UseLoRA`` to the executor classes, and replace the
    "Lora is not supported." early return in the resource manager with a
    call to the executor's ``LoadLoRA``.

    Every edit goes through the ``*_once`` helpers, so running this on an
    already patched tree leaves the file contents unchanged. Files are
    written one at a time; a failure part-way leaves earlier files patched.

    Args:
        source_root: Root directory of the upstream checkout.

    Raises:
        RuntimeError: If an expected marker is missing from a file.
        OSError: If a file cannot be read or written.
    """
    c_build_path = source_root / "c" / "BUILD"
    text = c_build_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        '        "//runtime/proto:token_cc_proto",\n',
        '        "//runtime/util:scoped_file",\n',
        c_build_path,
    )
    c_build_path.write_text(text, encoding="utf-8")

    engine_h_path = source_root / "c" / "engine.h"
    text = engine_h_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        """void litert_lm_session_config_set_sampler_params(
    LiteRtLmSessionConfig* config, const LiteRtLmSamplerParams* sampler_params);

""",
        """// Sets a LoRA adapter file for this session config.
// @param config The config to modify.
// @param path Path to a LiteRT-LM LoRA adapter file.
// @return true when the file is opened and attached to the session config.
LITERT_LM_C_API_EXPORT
bool litert_lm_session_config_set_lora_file(LiteRtLmSessionConfig* config,
                                            const char* path);

// Sets an audio LoRA adapter file for this session config.
// @param config The config to modify.
// @param path Path to a LiteRT-LM audio LoRA adapter file.
// @return true when the file is opened and attached to the session config.
LITERT_LM_C_API_EXPORT
bool litert_lm_session_config_set_audio_lora_file(
    LiteRtLmSessionConfig* config, const char* path);

""",
        engine_h_path,
    )
    engine_h_path.write_text(text, encoding="utf-8")

    engine_cc_path = source_root / "c" / "engine.cc"
    text = engine_cc_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        '#include "runtime/util/logging.h"\n',
        '#include "runtime/util/scoped_file.h"\n',
        engine_cc_path,
    )
    # The new setters are placed directly in front of the existing
    # litert_lm_session_config_delete definition.
    text = _insert_before_once(
        text,
        "void litert_lm_session_config_delete(LiteRtLmSessionConfig* config) {\n",
        """bool litert_lm_session_config_set_lora_file(
    LiteRtLmSessionConfig* config, const char* path) {
  if (!config || !config->config || !path || path[0] == '\\0') {
    ABSL_LOG(ERROR) << "Invalid session config or LoRA path";
    return false;
  }
  auto scoped_file = litert::lm::ScopedFile::Open(path);
  if (!scoped_file.ok()) {
    ABSL_LOG(ERROR) << "Failed to open LoRA file: "
                    << scoped_file.status();
    return false;
  }
  config->config->SetScopedLoraFile(
      std::make_shared(std::move(*scoped_file)));
  return true;
}

bool litert_lm_session_config_set_audio_lora_file(
    LiteRtLmSessionConfig* config, const char* path) {
  if (!config || !config->config || !path || path[0] == '\\0') {
    ABSL_LOG(ERROR) << "Invalid session config or audio LoRA path";
    return false;
  }
  auto scoped_file = litert::lm::ScopedFile::Open(path);
  if (!scoped_file.ok()) {
    ABSL_LOG(ERROR) << "Failed to open audio LoRA file: "
                    << scoped_file.status();
    return false;
  }
  config->config->SetAudioScopedLoraFile(
      std::make_shared(std::move(*scoped_file)));
  return true;
}

""",
        engine_cc_path,
    )
    engine_cc_path.write_text(text, encoding="utf-8")

    executor_build_path = source_root / "runtime" / "executor" / "BUILD"
    text = executor_build_path.read_text(encoding="utf-8")
    text = _replace_once(
        text,
        """        "//runtime/components:model_resources",
        "//runtime/components:model_resources_litert_lm",
        "//runtime/components:model_resources_task",
        "//runtime/components:sampler",
""",
        """        "//runtime/components:model_resources",
        "//runtime/components:model_resources_litert_lm",
        "//runtime/components:model_resources_task",
        "//runtime/components:lora_manager",
        "//runtime/components:sampler",
""",
        executor_build_path,
    )
    executor_build_path.write_text(text, encoding="utf-8")

    executor_base_path = (
        source_root / "runtime" / "executor" / "llm_executor_base.h"
    )
    text = executor_base_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        "namespace litert::lm {\n\n",
        "class ModelAssets;\n\n",
        executor_base_path,
    )
    text = _insert_before_once(
        text,
        """  // Resets all of the internal states (e.g. KVCache). Loaded and used LoRA
  // models are not affected (remain loaded and in use).
""",
        """  // Loads a LoRA model into executor-owned resources.
  virtual absl::Status LoadLoRA(uint32_t lora_id,
                                const ModelAssets& model_assets) {
    return absl::UnimplementedError(absl::StrCat(
        "LoadLoRA not implemented for backend: ", ExecutorBackendName()));
  }

  // Selects the LoRA model to use for the active context. Passing nullopt
  // clears the active LoRA for base-model sessions.
  virtual absl::Status UseLoRA(std::optional lora_id) {
    if (!lora_id.has_value()) {
      return absl::OkStatus();
    }
    return absl::UnimplementedError(absl::StrCat(
        "UseLoRA not implemented for backend: ", ExecutorBackendName()));
  }

""",
        executor_base_path,
    )
    executor_base_path.write_text(text, encoding="utf-8")

    compiled_h_path = (
        source_root
        / "runtime"
        / "executor"
        / "llm_litert_compiled_model_executor.h"
    )
    text = compiled_h_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        '#include "runtime/components/embedding_lookup/embedding_lookup_manager.h"\n',
        '#include "runtime/components/lora_manager.h"\n',
        compiled_h_path,
    )
    text = _insert_after_once(
        text,
        """  absl::Status RestoreContext(
      std::unique_ptr context_data) override;

""",
        """  absl::Status LoadLoRA(uint32_t lora_id,
                        const ModelAssets& model_assets) override;

  absl::Status UseLoRA(std::optional lora_id) override;

""",
        compiled_h_path,
    )
    text = _insert_after_once(
        text,
        """  absl::Status BindTensorsAndRunDecode(TensorBuffer* output_logits);
""",
        """  absl::Status AppendActiveLoraInputBuffers(
      absl::flat_hash_map& input_buffers);
""",
        compiled_h_path,
    )
    text = _insert_before_once(
        text,
        """  // The MTP drafter model.
  std::unique_ptr mtp_drafter_;
""",
        """  std::unique_ptr lora_manager_;
  std::optional active_lora_id_;

""",
        compiled_h_path,
    )
    compiled_h_path.write_text(text, encoding="utf-8")

    compiled_cc_path = (
        source_root
        / "runtime"
        / "executor"
        / "llm_litert_compiled_model_executor.cc"
    )
    text = compiled_cc_path.read_text(encoding="utf-8")
    text = _replace_once(
        text,
        """absl::Status LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecode(
    TensorBuffer* output_logits) {
  absl::flat_hash_map decode_input_buffers;
  for (const auto& [input_name, input_buffer] : decode_input_buffers_) {
    LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
    decode_input_buffers[input_name] = std::move(input_buffer_dup);
  }
  for (const auto& [input_name, input_buffer] : *input_kv_cache_buffers_) {
    LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
    decode_input_buffers[input_name] = std::move(input_buffer_dup);
  }
  absl::flat_hash_map decode_output_buffers;
""",
        """absl::Status LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecode(
    TensorBuffer* output_logits) {
  absl::flat_hash_map decode_input_buffers;
  for (const auto& [input_name, input_buffer] : decode_input_buffers_) {
    LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
    decode_input_buffers[input_name] = std::move(input_buffer_dup);
  }
  for (const auto& [input_name, input_buffer] : *input_kv_cache_buffers_) {
    LITERT_ASSIGN_OR_RETURN(auto input_buffer_dup, input_buffer.Duplicate());
    decode_input_buffers[input_name] = std::move(input_buffer_dup);
  }
  RETURN_IF_ERROR(AppendActiveLoraInputBuffers(decode_input_buffers));
  absl::flat_hash_map decode_output_buffers;
""",
        compiled_cc_path,
    )
    text = _insert_before_once(
        text,
        "int LlmLiteRtCompiledModelExecutorBase::BindTensorsAndRunDecodeStatic(\n",
        """absl::Status LlmLiteRtCompiledModelExecutorBase::AppendActiveLoraInputBuffers(
    absl::flat_hash_map& input_buffers) {
  if (!active_lora_id_.has_value()) {
    return absl::OkStatus();
  }
  RET_CHECK_NE(lora_manager_, nullptr) << "LoRA manager is not initialized.";
  ASSIGN_OR_RETURN(auto lora_buffers, lora_manager_->GetLoRABuffers());
  for (auto& [input_name, input_buffer] : lora_buffers) {
    input_buffers[input_name] = std::move(input_buffer);
  }
  return absl::OkStatus();
}

""",
        compiled_cc_path,
    )
    text = _replace_once(
        text,
        """absl::StatusOr>
LlmLiteRtCompiledModelExecutorBase::CloneContext() const {
  std::optional lora_id;
  ASSIGN_OR_RETURN(auto kv_cache_buffers, CloneKVCacheBuffers());
""",
        """absl::StatusOr>
LlmLiteRtCompiledModelExecutorBase::CloneContext() const {
  std::optional lora_id =
      llm_context_->processed_context().lora_id();
  ASSIGN_OR_RETURN(auto kv_cache_buffers, CloneKVCacheBuffers());
""",
        compiled_cc_path,
    )
    # This insertion must stay ahead of the RestoreContext replacement below:
    # its marker is the unpatched opening of RestoreContext, which remains a
    # prefix of the patched function, so re-runs still detect it as applied.
    text = _insert_before_once(
        text,
        """absl::Status LlmLiteRtCompiledModelExecutorBase::RestoreContext(
    std::unique_ptr context_data) {
  llm_context_ = std::move(context_data);
""",
        """absl::Status LlmLiteRtCompiledModelExecutorBase::LoadLoRA(
    uint32_t lora_id, const ModelAssets& model_assets) {
  if (lora_manager_ == nullptr) {
    ASSIGN_OR_RETURN(lora_manager_,
                     LoraManager::Create(*compiled_model_,
                                         kDecodeSignatureRunner));
  }
  return lora_manager_->LoadLoRA(lora_id, model_assets);
}

absl::Status LlmLiteRtCompiledModelExecutorBase::UseLoRA(
    std::optional lora_id) {
  if (!lora_id.has_value()) {
    active_lora_id_ = std::nullopt;
    return absl::OkStatus();
  }
  if (lora_manager_ == nullptr) {
    return absl::FailedPreconditionError("LoRA manager is not initialized.");
  }
  RETURN_IF_ERROR(lora_manager_->UseLoRA(*lora_id));
  active_lora_id_ = lora_id;
  return absl::OkStatus();
}

""",
        compiled_cc_path,
    )
    text = _replace_once(
        text,
        """absl::Status LlmLiteRtCompiledModelExecutorBase::RestoreContext(
    std::unique_ptr context_data) {
  llm_context_ = std::move(context_data);

  // We can keep our kv cache buffers if this is the first step. This lets us
""",
        """absl::Status LlmLiteRtCompiledModelExecutorBase::RestoreContext(
    std::unique_ptr context_data) {
  llm_context_ = std::move(context_data);
  RETURN_IF_ERROR(UseLoRA(llm_context_->processed_context().lora_id()));

  // We can keep our kv cache buffers if this is the first step. This lets us
""",
        compiled_cc_path,
    )
    compiled_cc_path.write_text(text, encoding="utf-8")

    resource_manager_h_path = (
        source_root
        / "runtime"
        / "framework"
        / "resource_management"
        / "resource_manager.h"
    )
    text = resource_manager_h_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        '#include "absl/container/flat_hash_map.h"  // from @com_google_absl\n',
        '#include "absl/container/flat_hash_set.h"  // from @com_google_absl\n',
        resource_manager_h_path,
    )
    text = _insert_after_once(
        text,
        """  absl::flat_hash_map lora_hash_to_id_;

""",
        """  absl::flat_hash_set loaded_lora_ids_;

""",
        resource_manager_h_path,
    )
    resource_manager_h_path.write_text(text, encoding="utf-8")

    resource_manager_cc_path = (
        source_root
        / "runtime"
        / "framework"
        / "resource_management"
        / "resource_manager.cc"
    )
    text = resource_manager_cc_path.read_text(encoding="utf-8")
    text = _insert_after_once(
        text,
        """  absl::Status RestoreContext(
      std::unique_ptr llm_context) override {
    return llm_executor_->RestoreContext(std::move(llm_context));
  }

""",
        """  absl::Status LoadLoRA(uint32_t lora_id,
                        const ModelAssets& model_assets) override {
    return llm_executor_->LoadLoRA(lora_id, model_assets);
  }

  absl::Status UseLoRA(std::optional lora_id) override {
    return llm_executor_->UseLoRA(lora_id);
  }

""",
        resource_manager_cc_path,
    )
    # Swaps the placeholder that returned "Lora is not supported." for a real
    # LoadLoRA call, tracked through loaded_lora_ids_.
    text = _replace_once(
        text,
        """  // TODO: b/462499294 -
  // 1. Check if lora is loaded or not.
  // 2. Get the lora id.
  // 3. If lora is not loaded, load the lora.

  // Check if the lora is already loaded.
  // TODO: b/462499294 - Use the real lora path.
  bool lora_is_loaded =
      lora_hash_to_id_.find("fake_lora_path") != lora_hash_to_id_.end();

  // Find the lora id. If lora_id is not nullopt, it means the lora is used.
  std::optional lora_id = AssignLoraId(
      /*lora_path=*/"",
      /*has_scoped_lora_file=*/session_config.GetScopedLoraFile() != nullptr);

  // If lora is used and not loaded, load the lora.
  if (lora_id.has_value() && !lora_is_loaded) {
    RET_CHECK(session_config.GetScopedLoraFile() != nullptr);
    ASSIGN_OR_RETURN(ModelAssets model_assets,
                     ModelAssets::Create(session_config.GetScopedLoraFile(),
                                         /*model_path=*/""));
    return absl::InvalidArgumentError("Lora is not supported.");
  }
""",
        """  // Find the lora id. If lora_id is not nullopt, it means the lora is used.
  std::optional lora_id = AssignLoraId(
      /*lora_path=*/"",
      /*has_scoped_lora_file=*/session_config.GetScopedLoraFile() != nullptr);

  if (lora_id.has_value()) {
    RET_CHECK(session_config.GetScopedLoraFile() != nullptr);
    ASSIGN_OR_RETURN(ModelAssets model_assets,
                     ModelAssets::Create(session_config.GetScopedLoraFile(),
                                         /*model_path=*/""));
    MovableMutexLock lock(&executor_mutex_);
    if (!loaded_lora_ids_.contains(*lora_id)) {
      RETURN_IF_ERROR(llm_executor_->LoadLoRA(*lora_id, model_assets));
      loaded_lora_ids_.insert(*lora_id);
    }
  }
""",
        resource_manager_cc_path,
    )
    resource_manager_cc_path.write_text(text, encoding="utf-8")