Skip to content
Draft
4 changes: 2 additions & 2 deletions backend/app/features/assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from app.core.logger import _setup_custom_logger
from app.features.assistant.orchestrator import generate_simulation_summary
from app.features.assistant.schemas import SimulationSummaryResponse
from app.features.simulation.models import Simulation
from app.features.simulation.models import Case, Simulation
from app.features.user.manager import optional_current_user
from app.features.user.models import User

Expand Down Expand Up @@ -42,7 +42,7 @@ async def summarize_simulation(
stmt = (
select(Simulation)
.options(
joinedload(Simulation.case),
joinedload(Simulation.case).selectinload(Case.links),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
Expand Down
7 changes: 6 additions & 1 deletion backend/app/features/assistant/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel, Field

from app.core.config import settings
from app.features.simulation.link_utils import merge_simulation_and_case_links
from app.features.simulation.models import Artifact, ExternalLink, Simulation

SNAPSHOT_TRUNCATED_CAVEAT = (
Expand Down Expand Up @@ -201,6 +202,10 @@ def build_simulation_snapshot(
*,
max_chars: int | None = None,
) -> SimulationSnapshot:
merged_links = merge_simulation_and_case_links(
simulation.links,
simulation.case.links,
)
snapshot = SimulationSnapshot(
simulation=SnapshotSimulationFields(
id=str(simulation.id),
Expand Down Expand Up @@ -240,7 +245,7 @@ def build_simulation_snapshot(
else None
),
artifacts=_sorted_artifacts(simulation.artifacts),
links=_sorted_links(simulation.links),
links=_sorted_links(merged_links),
snapshot_caveats=[],
)

Expand Down
138 changes: 131 additions & 7 deletions backend/app/features/simulation/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,32 @@
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, joinedload, selectinload

from app.common.dependencies import get_database_session
from app.core.database import transaction
from app.features.assistant.orchestrator import is_summary_llm_available
from app.features.ingestion.enums import IngestionSourceType, IngestionStatus
from app.features.ingestion.models import Ingestion
from app.features.machine.models import Machine
from app.features.simulation.enums import ExternalLinkKind
from app.features.simulation.link_utils import merge_simulation_and_case_links
from app.features.simulation.models import Artifact, Case, ExternalLink, Simulation
from app.features.simulation.schemas import (
CaseOut,
DiagnosticsLinkRequest,
SimulationCreate,
SimulationOut,
SimulationSummaryCapabilitiesOut,
SimulationSummaryOut,
)
from app.features.user.manager import current_active_user
from app.features.user.models import User
from app.features.user.models import User, UserRole

simulation_router = APIRouter(prefix="/simulations", tags=["Simulations"])
case_router = APIRouter(prefix="/cases", tags=["Cases"])
diagnostics_router = APIRouter(prefix="/diagnostics", tags=["Diagnostics"])


@case_router.get(
Expand Down Expand Up @@ -116,20 +122,23 @@ def get_case(case_id: UUID, db: Session = Depends(get_database_session)) -> Case
"""
case = (
db.query(Case)
.options(selectinload(Case.simulations).selectinload(Simulation.machine))
.options(
selectinload(Case.simulations).selectinload(Simulation.machine),
selectinload(Case.links),
)
.filter(Case.id == case_id)
.first()
)

if not case:
raise HTTPException(status_code=404, detail="Case not found")

resp = _case_to_out(case)
resp = _case_to_out(case, include_links=True)

return resp


def _case_to_out(case: Case) -> CaseOut:
def _case_to_out(case: Case, *, include_links: bool = False) -> CaseOut:
"""Convert a Case ORM instance to CaseOut with nested SimulationSummaryOut.

Parameters
Expand Down Expand Up @@ -176,6 +185,7 @@ def _case_to_out(case: Case) -> CaseOut:
simulations=summaries,
machine_names=machine_names,
hpc_usernames=hpc_usernames,
links=case.links if include_links else [],
created_at=case.created_at,
updated_at=case.updated_at,
)
Expand Down Expand Up @@ -258,7 +268,7 @@ def create_simulation(
sim_loaded = (
db.query(Simulation)
.options(
joinedload(Simulation.case),
joinedload(Simulation.case).selectinload(Case.links),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
Expand All @@ -278,6 +288,43 @@ def create_simulation(
return result


@diagnostics_router.post(
"/link",
status_code=status.HTTP_204_NO_CONTENT,
responses={
204: {"description": "Diagnostics linked successfully."},
401: {"description": "Unauthorized."},
403: {"description": "Forbidden."},
404: {"description": "Matching case not found."},
409: {"description": "Ambiguous case match."},
422: {"description": "Validation error."},
},
)
def link_case_diagnostics(
payload: DiagnosticsLinkRequest,
db: Session = Depends(get_database_session),
user: User = Depends(current_active_user),
) -> None:
"""Resolve one case and upsert case-scoped diagnostic links."""
if user.role not in (UserRole.ADMIN, UserRole.SERVICE_ACCOUNT):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Only administrators and service accounts may link diagnostics.",
)

case_id = _resolve_case_id_for_diagnostics_link(
db=db,
case_name=payload.case_name,
machine_name=payload.machine,
hpc_username=payload.hpc_username,
)
_upsert_case_diagnostic_links(
db=db,
case_id=case_id,
diagnostics=payload.diagnostics,
)


@simulation_router.get(
"",
response_model=list[SimulationOut],
Expand Down Expand Up @@ -321,7 +368,7 @@ def list_simulations(
in descending order.
"""
query = db.query(Simulation).options(
joinedload(Simulation.case),
joinedload(Simulation.case).selectinload(Case.links),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
Expand All @@ -336,6 +383,81 @@ def list_simulations(
return [_simulation_to_out(s) for s in sims]


def _resolve_case_id_for_diagnostics_link(
*,
db: Session,
case_name: str,
machine_name: str,
hpc_username: str,
) -> UUID:
"""Resolve a unique case ID from case, machine, and HPC username."""
matches = (
db.query(Case.id)
.join(Simulation, Simulation.case_id == Case.id)
.join(Machine, Simulation.machine_id == Machine.id)
.filter(Case.name == case_name)
.filter(Machine.name == machine_name)
.filter(Simulation.hpc_username == hpc_username)
.distinct()
.all()
)

if not matches:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No case matched the provided case_name, machine, and hpc_username.",
)

# TODO(#193): This ambiguity branch is not reachable while Case.name remains
# globally unique. If case identity moves to (case_name, machine, hpc_username),
# keep this 409 path and replace patched coverage with a DB-backed test.
# https://github.com/E3SM-Project/simboard/issues/193
if len(matches) > 1:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Multiple cases matched the provided case_name, machine, and hpc_username.",
)

return matches[0][0]


def _upsert_case_diagnostic_links(
*,
db: Session,
case_id: UUID,
diagnostics: list,
) -> None:
"""Create or update case-owned diagnostic links idempotently."""
now = datetime.now(timezone.utc)

with transaction(db):
for diagnostic in diagnostics:
stmt = (
pg_insert(ExternalLink)
.values(
case_id=case_id,
kind=ExternalLinkKind.DIAGNOSTIC,
url=str(diagnostic.url),
label=diagnostic.name,
created_at=now,
updated_at=now,
)
.on_conflict_do_update(
index_elements=[
ExternalLink.case_id,
ExternalLink.kind,
ExternalLink.url,
],
index_where=ExternalLink.case_id.is_not(None),
set_={
"label": diagnostic.name,
"updated_at": now,
},
)
)
db.execute(stmt)


@simulation_router.get(
"/{sim_id}",
response_model=SimulationOut,
Expand Down Expand Up @@ -370,7 +492,7 @@ def get_simulation(sim_id: UUID, db: Session = Depends(get_database_session)):
sim = (
db.query(Simulation)
.options(
joinedload(Simulation.case),
joinedload(Simulation.case).selectinload(Case.links),
joinedload(Simulation.machine),
selectinload(Simulation.artifacts),
selectinload(Simulation.links),
Expand Down Expand Up @@ -403,12 +525,14 @@ def _simulation_to_out(sim: Simulation) -> SimulationOut:
"""
case = sim.case
llm_available = is_summary_llm_available()
merged_links = merge_simulation_and_case_links(sim.links, case.links)

result = SimulationOut.model_validate(
{
**{k: v for k, v in sim.__dict__.items() if not k.startswith("_")},
"case_name": case.name,
"case_group": case.case_group,
"links": merged_links,
"summary_capabilities": SimulationSummaryCapabilitiesOut(
llm_available=llm_available,
auto_generate_deterministic_on_load=not llm_available,
Expand Down
30 changes: 30 additions & 0 deletions backend/app/features/simulation/link_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from __future__ import annotations

from collections.abc import Iterable

from app.features.simulation.models import ExternalLink


def merge_simulation_and_case_links(
simulation_links: Iterable[ExternalLink],
case_links: Iterable[ExternalLink],
) -> list[ExternalLink]:
"""Merge simulation-owned and case-owned links with simulation precedence."""
merged: list[ExternalLink] = []
seen: set[tuple[str, str]] = set()

for link in simulation_links:
key = (str(link.kind), link.url)
if key in seen:
continue
seen.add(key)
merged.append(link)

for link in case_links:
key = (str(link.kind), link.url)
if key in seen:
continue
seen.add(key)
merged.append(link)

return merged
52 changes: 48 additions & 4 deletions backend/app/features/simulation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,16 @@
from typing import TYPE_CHECKING, Optional
from uuid import UUID

from sqlalchemy import DateTime, ForeignKey, Integer, String, Text
from sqlalchemy import (
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
String,
Text,
text,
)
from sqlalchemy import Enum as SAEnum
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.dialects.postgresql import UUID as PG_UUID
Expand Down Expand Up @@ -45,6 +54,13 @@ class Case(Base, IDMixin, TimestampMixin):
cascade="all, delete-orphan",
passive_deletes=True,
)
links: Mapped[list[ExternalLink]] = relationship(
"ExternalLink",
back_populates="case",
foreign_keys="ExternalLink.case_id",
cascade="all, delete-orphan",
passive_deletes=True,
)


class Simulation(Base, IDMixin, TimestampMixin):
Expand Down Expand Up @@ -194,9 +210,30 @@ class Artifact(Base, IDMixin, TimestampMixin):

class ExternalLink(Base, IDMixin, TimestampMixin):
__tablename__ = "external_links"
__table_args__ = (
CheckConstraint(
"(simulation_id IS NOT NULL) <> (case_id IS NOT NULL)",
name="exactly_one_owner",
),
Index(
"uq_external_links_case_id_kind_url",
"case_id",
"kind",
"url",
unique=True,
postgresql_where=text("case_id IS NOT NULL"),
),
)

simulation_id: Mapped[UUID] = mapped_column(
PG_UUID(as_uuid=True), ForeignKey("simulations.id", ondelete="CASCADE")
simulation_id: Mapped[UUID | None] = mapped_column(
PG_UUID(as_uuid=True),
ForeignKey("simulations.id", ondelete="CASCADE"),
nullable=True,
)
case_id: Mapped[UUID | None] = mapped_column(
PG_UUID(as_uuid=True),
ForeignKey("cases.id", ondelete="CASCADE"),
nullable=True,
)

kind: Mapped[ExternalLinkKind] = mapped_column(
Expand All @@ -212,8 +249,15 @@ class ExternalLink(Base, IDMixin, TimestampMixin):
url: Mapped[str] = mapped_column(String(1000))
label: Mapped[Optional[str]] = mapped_column(String(200))

simulation: Mapped[Simulation] = relationship(
simulation: Mapped[Simulation | None] = relationship(
back_populates="links",
primaryjoin="ExternalLink.simulation_id==Simulation.id",
foreign_keys=[simulation_id],
passive_deletes=True,
)
case: Mapped[Case | None] = relationship(
back_populates="links",
primaryjoin="ExternalLink.case_id==Case.id",
foreign_keys=[case_id],
passive_deletes=True,
)
Loading
Loading