Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion model_eval/runners/litert_lm/litert_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ class Config(base.RunnerConfig):
min_log_severity: int = 1000
# Whether to skip the slow run_decode step for greedy verification.
always_return_not_greedy: bool = True
# Whether to report text scoring as a supported capability. Multimodal
# scoring is reported as unsupported regardless of this flag.
enable_scoring: bool = True

@classmethod
def from_unified_args(
Expand Down Expand Up @@ -267,7 +269,7 @@ def model_name(self) -> str:
@property
def capabilities(self) -> base.RunnerCapabilities:
return base.RunnerCapabilities(
text_scoring=True,
text_scoring=self._config.enable_scoring,
text_generation=True,
multimodal_scoring=False,
multimodal_generation=True,
Expand Down
7 changes: 4 additions & 3 deletions model_eval/tests/unit/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def test_run_produces_jsonl_files(
"model_name": "litert-lm-model",
"min_log_severity": 1000,
"always_return_not_greedy": True,
"enable_scoring": True,
},
)

Expand All @@ -293,12 +294,12 @@ def test_run_produces_jsonl_files(
metric_files = [f for f in files if f.endswith("_metrics.jsonl")]
sample_files = [f for f in files if f.endswith("_samples.jsonl")]

self.assertEqual(len(metric_files), 1)
self.assertEqual(len(sample_files), 1)
self.assertLen(metric_files, 1)
self.assertLen(sample_files, 1)

with open(os.path.join(self.output_dir, metric_files[0]), "r") as f:
lines = f.readlines()
self.assertEqual(len(lines), 2)
self.assertLen(lines, 2)
data = json.loads(lines[0])
self.assertEqual(data["task"], "mmlu")
self.assertEqual(data["metric"], "acc")
Expand Down
16 changes: 16 additions & 0 deletions model_eval/tests/unit/runners/litert_lm/litert_lm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,22 @@ def test_initialization_with_invalid_activation_data_type(
runner = litert_lm.LiteRtLmRunner(config)
with self.assertRaisesRegex(ValueError, "Unsupported activation data type"):
runner.start()
def test_capabilities_enable_scoring(self):
  """Checks that `enable_scoring` controls the text-scoring capability.

  Multimodal scoring is expected to be reported as unsupported in both cases.
  """
  # Leaving `enable_scoring` unset exercises the config's default value.
  implicit_cfg = litert_lm.LiteRtLmRunner.Config(
      model_path="/foo",
      runner_type="litert-lm",
  )
  implicit_runner = litert_lm.LiteRtLmRunner(implicit_cfg)
  self.assertTrue(implicit_runner.capabilities.text_scoring)
  self.assertFalse(implicit_runner.capabilities.multimodal_scoring)

  # Explicitly opting out must turn the text-scoring capability off.
  opted_out_cfg = litert_lm.LiteRtLmRunner.Config(
      enable_scoring=False,
      model_path="/foo",
      runner_type="litert-lm",
  )
  opted_out_runner = litert_lm.LiteRtLmRunner(opted_out_cfg)
  self.assertFalse(opted_out_runner.capabilities.text_scoring)
  self.assertFalse(opted_out_runner.capabilities.multimodal_scoring)


if __name__ == "__main__":
Expand Down
Loading