From 11e460b47a25b149d9ea42afa45ede9b2182eeb6 Mon Sep 17 00:00:00 2001
From: AI Edge Eval Team
Date: Wed, 24 Jun 2026 20:39:07 -0700
Subject: [PATCH] Add activation data type for litert-lm runner

PiperOrigin-RevId: 937717527
---
NOTE(review): the leading whitespace in this patch was re-created from a
whitespace-collapsed copy (2-space indentation assumed; the Config fields and
the pipeline_test dict entries are best guesses). Confirm the context lines
against the tree, or apply with `git apply --ignore-whitespace`. The `index`
blob ids are unchanged from the original patch and do not reflect the revised
_parse_activation_data_type body.

 model_eval/runners/litert_lm/litert_lm.py    |  30 ++++
 model_eval/tests/unit/pipeline_test.py       |   1 +
 .../unit/runners/litert_lm/litert_lm_test.py | 136 ++++++++++++++++++
 3 files changed, 167 insertions(+)

diff --git a/model_eval/runners/litert_lm/litert_lm.py b/model_eval/runners/litert_lm/litert_lm.py
index 4135d25..a36d5b9 100644
--- a/model_eval/runners/litert_lm/litert_lm.py
+++ b/model_eval/runners/litert_lm/litert_lm.py
@@ -93,6 +93,30 @@ def _parse_backend(backend_str: str) -> litert_lm.Backend:
   )
 
 
+def _parse_activation_data_type(
+    activation_data_type: str,
+) -> litert_lm.ActivationDataType:
+  """Parses a string activation data type to the litert_lm.ActivationDataType.
+
+  Args:
+    activation_data_type: Name of the activation data type. Only the exact
+      strings "fp32" and "fp16" are accepted.
+
+  Returns:
+    The result of litert_lm.ActivationDataType.from_str for the given name.
+
+  Raises:
+    ValueError: If activation_data_type is not one of the accepted strings.
+  """
+  valid_types = ["fp32", "fp16"]
+  if activation_data_type not in valid_types:
+    raise ValueError(
+        f"Unsupported activation data type: '{activation_data_type}'. "
+        f"Must be one of {valid_types}."
+    )
+  return litert_lm.ActivationDataType.from_str(activation_data_type)
+
+
 def _clamp_log_severity(severity: int) -> litert_lm.LogSeverity:
   """Clamps a given log severity level to the supported litert_lm.LogSeverity."""
   if severity <= 0:
@@ -125,6 +149,8 @@ class Config(base.RunnerConfig):
     max_num_tokens: int = 4096
     # Whether to enable speculative decoding.
     enable_speculative_decoding: bool | None = None
+    # Optional activation data type to use for the model (e.g. 'fp32', 'fp16').
+    activation_data_type: str | None = None
     # Host for the runner's server.
     host: str = "127.0.0.1"
     # Port for the runner's server.
@@ -192,6 +218,10 @@ def start(self) -> None:
       engine_kwargs["enable_speculative_decoding"] = (
           self._config.enable_speculative_decoding
       )
+    if self._config.activation_data_type is not None:
+      engine_kwargs["activation_data_type"] = _parse_activation_data_type(
+          self._config.activation_data_type
+      )
 
     # Resolve the model path (download from HuggingFace if necessary).
     path = _resolve_model_path(self._config.model_path)
diff --git a/model_eval/tests/unit/pipeline_test.py b/model_eval/tests/unit/pipeline_test.py
index 7f08b19..abeffe8 100644
--- a/model_eval/tests/unit/pipeline_test.py
+++ b/model_eval/tests/unit/pipeline_test.py
@@ -281,6 +281,7 @@ def test_run_produces_jsonl_files(
             "host": "127.0.0.1",
             "max_num_tokens": 4096,
             "enable_speculative_decoding": None,
+            "activation_data_type": None,
             "model_name": "litert-lm-model",
             "min_log_severity": 1000,
             "always_return_not_greedy": True,
diff --git a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
index 147fc19..9a9b1ae 100644
--- a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
+++ b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
@@ -334,6 +334,142 @@ def test_absolute_path_resolution(
         max_num_tokens=mock.ANY,
     )
 
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
+  )
+  @mock.patch("model_eval.runners.base.requests.post")
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.threading.Thread"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.uvicorn.Server"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.uvicorn.Config"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm._litert_lm_server.build_app"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm._litert_lm_server.wait_for_server"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
+  )
+  def test_initialization_with_activation_data_type_string(
+      self,
+      mock_engine,
+      mock_wait_for_server,
+      mock_build_app,
+      mock_uvicorn_config,
+      mock_uvicorn_server,
+      mock_thread,
+      mock_post,
+      mock_set_min_log_severity,
+  ):
+    mock_post.return_value.json.return_value = {
+        "choices": [{"score": 0.9, "logprobs": []}]
+    }
+    config = litert_lm.LiteRtLmRunner.Config(
+        runner_type="litert-lm",
+        model_path="/path/to/model",
+        model_name="my-test-model",
+        backend="cpu",
+        activation_data_type="fp16",
+    )
+    runner = litert_lm.LiteRtLmRunner(config)
+    with mock.patch(
+        "model_eval.runners.litert_lm.litert_lm.os.path.exists",
+        return_value=True,
+    ):
+      runner.start()
+    mock_engine.assert_called_once_with(
+        "/path/to/model",
+        backend=litert_lm._resolve_backend(litert_lm.litert_lm.Backend.CPU),
+        max_num_tokens=4096,
+        activation_data_type=litert_lm.litert_lm.ActivationDataType.FLOAT16,
+    )
+    runner.stop()
+
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
+  )
+  @mock.patch("model_eval.runners.base.requests.post")
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.threading.Thread"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.uvicorn.Server"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.uvicorn.Config"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm._litert_lm_server.build_app"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm._litert_lm_server.wait_for_server"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
+  )
+  def test_initialization_with_activation_data_type_fp32(
+      self,
+      mock_engine,
+      mock_wait_for_server,
+      mock_build_app,
+      mock_uvicorn_config,
+      mock_uvicorn_server,
+      mock_thread,
+      mock_post,
+      mock_set_min_log_severity,
+  ):
+    mock_post.return_value.json.return_value = {
+        "choices": [{"score": 0.9, "logprobs": []}]
+    }
+    config = litert_lm.LiteRtLmRunner.Config(
+        runner_type="litert-lm",
+        model_path="/path/to/model",
+        model_name="my-test-model",
+        backend="cpu",
+        activation_data_type="fp32",
+    )
+    runner = litert_lm.LiteRtLmRunner(config)
+    with mock.patch(
+        "model_eval.runners.litert_lm.litert_lm.os.path.exists",
+        return_value=True,
+    ):
+      runner.start()
+    mock_engine.assert_called_once_with(
+        "/path/to/model",
+        backend=litert_lm._resolve_backend(litert_lm.litert_lm.Backend.CPU),
+        max_num_tokens=4096,
+        activation_data_type=litert_lm.litert_lm.ActivationDataType.FLOAT32,
+    )
+    runner.stop()
+
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.os.path.exists"
+  )
+  @mock.patch(
+      "model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
+  )
+  def test_initialization_with_invalid_activation_data_type(
+      self, mock_engine, mock_exists, mock_set_min_log_severity
+  ):
+    mock_exists.return_value = True
+    config = litert_lm.LiteRtLmRunner.Config(
+        runner_type="litert-lm",
+        model_path="/path/to/model",
+        activation_data_type="invalid_type",
+    )
+    runner = litert_lm.LiteRtLmRunner(config)
+    with self.assertRaisesRegex(ValueError, "Unsupported activation data type"):
+      runner.start()
+
 
 if __name__ == "__main__":
   unittest.main()