From e7932c97f7e78e8d98012c72fb54c4d4d1e8500f Mon Sep 17 00:00:00 2001 From: AI Edge Eval Team Date: Sat, 27 Jun 2026 09:54:33 -0700 Subject: [PATCH] Add enable_scoring config to LiteRtLmRunner. PiperOrigin-RevId: 939094193 --- model_eval/runners/litert_lm/litert_lm.py | 4 +++- model_eval/tests/unit/pipeline_test.py | 7 ++++--- .../unit/runners/litert_lm/litert_lm_test.py | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/model_eval/runners/litert_lm/litert_lm.py b/model_eval/runners/litert_lm/litert_lm.py index a36d5b9..fd4abed 100644 --- a/model_eval/runners/litert_lm/litert_lm.py +++ b/model_eval/runners/litert_lm/litert_lm.py @@ -156,6 +156,8 @@ class Config(base.RunnerConfig): min_log_severity: int = 1000 # Whether to skip the slow run_decode step for greedy verification. always_return_not_greedy: bool = True + # Whether to enable text scoring; multimodal scoring is always disabled. + enable_scoring: bool = True @classmethod def from_unified_args( @@ -267,7 +269,7 @@ def model_name(self) -> str: @property def capabilities(self) -> base.RunnerCapabilities: return base.RunnerCapabilities( - text_scoring=True, + text_scoring=self._config.enable_scoring, text_generation=True, multimodal_scoring=False, multimodal_generation=True, diff --git a/model_eval/tests/unit/pipeline_test.py b/model_eval/tests/unit/pipeline_test.py index abeffe8..c6606e8 100644 --- a/model_eval/tests/unit/pipeline_test.py +++ b/model_eval/tests/unit/pipeline_test.py @@ -285,6 +285,7 @@ def test_run_produces_jsonl_files( "model_name": "litert-lm-model", "min_log_severity": 1000, "always_return_not_greedy": True, + "enable_scoring": True, }, ) @@ -293,12 +294,12 @@ def test_run_produces_jsonl_files( metric_files = [f for f in files if f.endswith("_metrics.jsonl")] sample_files = [f for f in files if f.endswith("_samples.jsonl")] - self.assertEqual(len(metric_files), 1) - self.assertEqual(len(sample_files), 1) + self.assertLen(metric_files, 1) + self.assertLen(sample_files, 1) with 
open(os.path.join(self.output_dir, metric_files[0]), "r") as f: lines = f.readlines() - self.assertEqual(len(lines), 2) + self.assertLen(lines, 2) data = json.loads(lines[0]) self.assertEqual(data["task"], "mmlu") self.assertEqual(data["metric"], "acc") diff --git a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py index 9a9b1ae..2525456 100644 --- a/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py +++ b/model_eval/tests/unit/runners/litert_lm/litert_lm_test.py @@ -469,6 +469,22 @@ def test_initialization_with_invalid_activation_data_type( runner = litert_lm.LiteRtLmRunner(config) with self.assertRaisesRegex(ValueError, "Unsupported activation data type"): runner.start() + def test_capabilities_enable_scoring(self): + # Default is True + config_default = litert_lm.LiteRtLmRunner.Config( + runner_type="litert-lm", model_path="/foo" + ) + runner_default = litert_lm.LiteRtLmRunner(config_default) + self.assertTrue(runner_default.capabilities.text_scoring) + self.assertFalse(runner_default.capabilities.multimodal_scoring) + + # Set to False + config_disabled = litert_lm.LiteRtLmRunner.Config( + runner_type="litert-lm", model_path="/foo", enable_scoring=False + ) + runner_disabled = litert_lm.LiteRtLmRunner(config_disabled) + self.assertFalse(runner_disabled.capabilities.text_scoring) + self.assertFalse(runner_disabled.capabilities.multimodal_scoring) if __name__ == "__main__":