Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions model_eval/runners/litert_lm/litert_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ def _parse_backend(backend_str: str) -> litert_lm.Backend:
)


def _parse_activation_data_type(
activation_data_type: str,
) -> litert_lm.ActivationDataType | None:
"""Parses a string activation data type to the litert_lm.ActivationDataType."""
valid_types = ["fp32", "fp16"]
if activation_data_type not in valid_types:
raise ValueError(
f"Unsupported activation data type: '{activation_data_type}'. "
f"Must be one of {valid_types} or a "
"litert_lm.ActivationDataType instance."
)
return litert_lm.ActivationDataType.from_str(activation_data_type)


def _clamp_log_severity(severity: int) -> litert_lm.LogSeverity:
"""Clamps a given log severity level to the supported litert_lm.LogSeverity."""
if severity <= 0:
Expand Down Expand Up @@ -125,6 +139,8 @@ class Config(base.RunnerConfig):
max_num_tokens: int = 4096
# Whether to enable speculative decoding.
enable_speculative_decoding: bool | None = None
# Optional activation data type to use for the model (e.g. 'fp32', 'fp16').
activation_data_type: str | None = None
# Host for the runner's server.
host: str = "127.0.0.1"
# Port for the runner's server.
Expand Down Expand Up @@ -192,6 +208,10 @@ def start(self) -> None:
engine_kwargs["enable_speculative_decoding"] = (
self._config.enable_speculative_decoding
)
if self._config.activation_data_type is not None:
engine_kwargs["activation_data_type"] = _parse_activation_data_type(
self._config.activation_data_type
)

# Resolve the model path (download from HuggingFace if necessary).
path = _resolve_model_path(self._config.model_path)
Expand Down
1 change: 1 addition & 0 deletions model_eval/tests/unit/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def test_run_produces_jsonl_files(
"host": "127.0.0.1",
"max_num_tokens": 4096,
"enable_speculative_decoding": None,
"activation_data_type": None,
"model_name": "litert-lm-model",
"min_log_severity": 1000,
"always_return_not_greedy": True,
Expand Down
136 changes: 136 additions & 0 deletions model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,142 @@ def test_absolute_path_resolution(
max_num_tokens=mock.ANY,
)

@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
)
@mock.patch("model_eval.runners.base.requests.post")
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.threading.Thread"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.uvicorn.Server"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.uvicorn.Config"
)
@mock.patch(
"model_eval.runners.litert_lm._litert_lm_server.build_app"
)
@mock.patch(
"model_eval.runners.litert_lm._litert_lm_server.wait_for_server"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
)
def test_initialization_with_activation_data_type_string(
self,
mock_engine,
mock_wait_for_server,
mock_build_app,
mock_uvicorn_config,
mock_uvicorn_server,
mock_thread,
mock_post,
mock_set_min_log_severity,
):
mock_post.return_value.json.return_value = {
"choices": [{"score": 0.9, "logprobs": []}]
}
config = litert_lm.LiteRtLmRunner.Config(
runner_type="litert-lm",
model_path="/path/to/model",
model_name="my-test-model",
backend="cpu",
activation_data_type="fp16",
)
runner = litert_lm.LiteRtLmRunner(config)
with mock.patch(
"model_eval.runners.litert_lm.litert_lm.os.path.exists",
return_value=True,
):
runner.start()
mock_engine.assert_called_once_with(
"/path/to/model",
backend=litert_lm._resolve_backend(litert_lm.litert_lm.Backend.CPU),
max_num_tokens=4096,
activation_data_type=litert_lm.litert_lm.ActivationDataType.FLOAT16,
)
runner.stop()

@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
)
@mock.patch("model_eval.runners.base.requests.post")
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.threading.Thread"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.uvicorn.Server"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.uvicorn.Config"
)
@mock.patch(
"model_eval.runners.litert_lm._litert_lm_server.build_app"
)
@mock.patch(
"model_eval.runners.litert_lm._litert_lm_server.wait_for_server"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
)
def test_initialization_with_activation_data_type_enum(
self,
mock_engine,
mock_wait_for_server,
mock_build_app,
mock_uvicorn_config,
mock_uvicorn_server,
mock_thread,
mock_post,
mock_set_min_log_severity,
):
mock_post.return_value.json.return_value = {
"choices": [{"score": 0.9, "logprobs": []}]
}
config = litert_lm.LiteRtLmRunner.Config(
runner_type="litert-lm",
model_path="/path/to/model",
model_name="my-test-model",
backend="cpu",
activation_data_type="fp32",
)
runner = litert_lm.LiteRtLmRunner(config)
with mock.patch(
"model_eval.runners.litert_lm.litert_lm.os.path.exists",
return_value=True,
):
runner.start()
mock_engine.assert_called_once_with(
"/path/to/model",
backend=litert_lm._resolve_backend(litert_lm.litert_lm.Backend.CPU),
max_num_tokens=4096,
activation_data_type=litert_lm.litert_lm.ActivationDataType.FLOAT32,
)
runner.stop()

@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.set_min_log_severity"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.os.path.exists"
)
@mock.patch(
"model_eval.runners.litert_lm.litert_lm.litert_lm.Engine"
)
def test_initialization_with_invalid_activation_data_type(
self, mock_engine, mock_exists, mock_set_min_log_severity
):
mock_exists.return_value = True
config = litert_lm.LiteRtLmRunner.Config(
runner_type="litert-lm",
model_path="/path/to/model",
activation_data_type="invalid_type",
)
runner = litert_lm.LiteRtLmRunner(config)
with self.assertRaisesRegex(ValueError, "Unsupported activation data type"):
runner.start()


if __name__ == "__main__":
unittest.main()
Loading