Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/modelsdotdev/_internal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,14 +284,31 @@ def parse_model_id(model_id: str) -> ModelRef:


def get_model_by_id(model_id: str) -> Model | None:
"""Return a model by canonical ``provider:model`` ID."""
if ":" not in model_id:
raise ValueError("model_id must be in 'provider:model' format")
"""Return a model by ``provider:model`` or ``provider/model`` ID."""
invalid_message = (
"model_id must include provider and model IDs "
"as 'provider:model' or 'provider/model'"
)
colon_index = model_id.find(":")
slash_index = model_id.find("/")
if colon_index == -1 and slash_index == -1:
raise ValueError(invalid_message)

separator = (
"/"
if slash_index != -1
and (colon_index == -1 or slash_index < colon_index)
else ":"
)
provider_id, provider_model_id = model_id.split(separator, 1)

if not provider_id or not provider_model_id:
raise ValueError(invalid_message)

with closing(_connect()) as connection:
row = connection.execute(
f"SELECT {MODEL_COLUMNS} FROM models WHERE full_id = ?",
(model_id,),
(f"{provider_id}:{provider_model_id}",),
).fetchone()
return None if row is None else _model_from_row(connection, row)

Expand Down
1 change: 1 addition & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_model_iteration_and_lookup_use_real_database() -> None:
assert model.provider_id in providers_by_id
assert model.qualified_id == f"{model.provider_id}:{model.id}"
assert get_model_by_id(model.qualified_id) == model
assert get_model_by_id(f"{model.provider_id}/{model.id}") == model
assert (
providers_by_id[model.provider_id].get_model_by_id(model.id)
== model
Expand Down