From d422d59c614e6fd56ad873867f11dd6271392577 Mon Sep 17 00:00:00 2001 From: EfeDurmaz16 Date: Fri, 15 May 2026 13:40:51 +0300 Subject: [PATCH] fix: accept slash-qualified model IDs --- src/modelsdotdev/_internal/data.py | 25 +++++++++++++++++++++---- tests/test_api.py | 1 + 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/modelsdotdev/_internal/data.py b/src/modelsdotdev/_internal/data.py index 1d58837..dee6333 100644 --- a/src/modelsdotdev/_internal/data.py +++ b/src/modelsdotdev/_internal/data.py @@ -284,14 +284,31 @@ def parse_model_id(model_id: str) -> ModelRef: def get_model_by_id(model_id: str) -> Model | None: - """Return a model by canonical ``provider:model`` ID.""" - if ":" not in model_id: - raise ValueError("model_id must be in 'provider:model' format") + """Return a model by ``provider:model`` or ``provider/model`` ID.""" + invalid_message = ( + "model_id must include provider and model IDs " + "as 'provider:model' or 'provider/model'" + ) + colon_index = model_id.find(":") + slash_index = model_id.find("/") + if colon_index == -1 and slash_index == -1: + raise ValueError(invalid_message) + + separator = ( + "/" + if slash_index != -1 + and (colon_index == -1 or slash_index < colon_index) + else ":" + ) + provider_id, provider_model_id = model_id.split(separator, 1) + + if not provider_id or not provider_model_id: + raise ValueError(invalid_message) with closing(_connect()) as connection: row = connection.execute( f"SELECT {MODEL_COLUMNS} FROM models WHERE full_id = ?", - (model_id,), + (f"{provider_id}:{provider_model_id}",), ).fetchone() return None if row is None else _model_from_row(connection, row) diff --git a/tests/test_api.py b/tests/test_api.py index 9d4dbc6..ff2f12d 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -59,6 +59,7 @@ def test_model_iteration_and_lookup_use_real_database() -> None: assert model.provider_id in providers_by_id assert model.qualified_id == f"{model.provider_id}:{model.id}" assert get_model_by_id(model.qualified_id) == model + assert get_model_by_id(f"{model.provider_id}/{model.id}") == model assert ( providers_by_id[model.provider_id].get_model_by_id(model.id) == model