From 883df246ad2edf7b32b3f2a391cfca0cb3083f1d Mon Sep 17 00:00:00 2001 From: Pavlo Getta Date: Thu, 25 Jun 2026 08:50:07 +0200 Subject: [PATCH 1/3] Support querying and displaying circuits by derivation type Expose circuit derivations on GET /circuit, derived from the existing Derivation table (no migration, no denormalized column). - Filter (always available), both directions mirroring the read fields: generated_derivation__derivation_type (circuit is the generated/derived entity) and used_derivation__derivation_type (circuit is the used/source entity), each with an __in variant. No source-type restriction, so e.g. an emodel->circuit emodel_circuit derivation matches on the generated side. - Read columns (opt-in via ?expand=...): generated_derivations and used_derivations, expandable independently. Unexpanded directions are omitted; expanded but empty serialize as []. Load-aware properties keep this safe under raiseload("*") via two viewonly relationships on Circuit. --- app/db/model.py | 33 ++++++++ app/filters/circuit.py | 25 +++++- app/queries/factory.py | 13 +++ app/schemas/circuit.py | 31 +++++++- app/service/circuit.py | 51 ++++++++++-- tests/test_circuit.py | 175 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 316 insertions(+), 12 deletions(-) diff --git a/app/db/model.py b/app/db/model.py index 91931a44a..cd7b6e221 100644 --- a/app/db/model.py +++ b/app/db/model.py @@ -1919,6 +1919,39 @@ class Circuit(ScientificArtifact, NameDescriptionVectorMixin): # calibration_data (multiple entities): ... + # View-only relationships to the Derivation rows that reference this circuit. + # They are loaded on demand (see app/service/circuit.py `expand`) and read through the + # load-aware properties below, so an un-expanded direction serializes as null instead of + # tripping `raiseload`. + derivations_as_generated: Mapped[list["Derivation"]] = relationship( + "Derivation", + primaryjoin=lambda: Circuit.id == Derivation.generated_id, + foreign_keys=lambda: [Derivation.generated_id], + viewonly=True, + overlaps="generated", + ) + derivations_as_used: Mapped[list["Derivation"]] = relationship( + "Derivation", + primaryjoin=lambda: Circuit.id == Derivation.used_id, + foreign_keys=lambda: [Derivation.used_id], + viewonly=True, + overlaps="used", + ) + + @property + def generated_derivations(self) -> list["Derivation"] | None: + """Derivations where this circuit is the generated entity, or None if not expanded.""" + if "derivations_as_generated" in sa.inspect(self).unloaded: + return None + return self.derivations_as_generated + + @property + def used_derivations(self) -> list["Derivation"] | None: + """Derivations where this circuit is the used entity, or None if not expanded.""" + if "derivations_as_used" in sa.inspect(self).unloaded: + return None + return self.derivations_as_used + @declared_attr.directive @classmethod def __table_args__(cls): # noqa: D105, PLW3201 diff --git a/app/filters/circuit.py b/app/filters/circuit.py index 6eb3c93ba..2b28fc176 100644 --- a/app/filters/circuit.py +++ b/app/filters/circuit.py @@ -3,14 +3,24 @@ from fastapi_filter import with_prefix -from app.db.model import Circuit -from app.db.types import CircuitBuildCategory, CircuitScale, TargetSimulator +from app.db.model import Circuit, Derivation +from app.db.types import CircuitBuildCategory, CircuitScale, DerivationType, TargetSimulator from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import IdFilterMixin, ILikeSearchFilterMixin, NameFilterMixin from app.filters.scientific_artifact import ScientificArtifactFilter +class NestedDerivationFilter(CustomFilter): + """Filter circuits by derivation type, on either the generated or used side.""" + + derivation_type: DerivationType | None = None + derivation_type__in: list[DerivationType] | None = None + + class Constants(CustomFilter.Constants): + model = Derivation + + class CircuitFilterMixin: scale: CircuitScale | None = None scale__in: list[CircuitScale] | None = None @@ -56,6 +66,17 @@ class CircuitFilter( number_connections__lte: int | None = None number_connections__gte: int | None = None + # derivations where the circuit is the generated (derived) entity: "how it was derived" + generated_derivation: Annotated[ + NestedDerivationFilter | None, + FilterDepends(with_prefix("generated_derivation", NestedDerivationFilter)), + ] = None + # derivations where the circuit is the used (source) entity: "what was derived from it" + used_derivation: Annotated[ + NestedDerivationFilter | None, + FilterDepends(with_prefix("used_derivation", NestedDerivationFilter)), + ] = None + order_by: list[str] = ["-creation_date"] # noqa: RUF012 class Constants(ScientificArtifactFilter.Constants): diff --git a/app/queries/factory.py b/app/queries/factory.py index 0917857cb..bf9894723 100644 --- a/app/queries/factory.py +++ b/app/queries/factory.py @@ -10,6 +10,7 @@ CellMorphologyProtocol, Circuit, Contribution, + Derivation, EMCellMesh, EMDenseReconstructionDataset, EModel, @@ -94,6 +95,8 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) -> used_alias = _get_alias(Entity, "used") generated_alias = _get_alias(Entity, "generated") circuit_alias = _get_alias(Circuit) + generated_derivation_alias = _get_alias(Derivation, "generated_derivation") + used_derivation_alias = _get_alias(Derivation, "used_derivation") ion_channel_alias = _get_alias(IonChannel) em_dense_reconstruction_dataset_alias = _get_alias(EMDenseReconstructionDataset) ion_channel_model_alias = _get_alias(IonChannelModel, "ion_channel_model") @@ -274,6 +277,16 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) -> ) == circuit_alias.id, ), + # circuits filtered by derivation type on the generated (derived) side + "generated_derivation": lambda q: q.outerjoin( + generated_derivation_alias, + db_model_class.id == generated_derivation_alias.generated_id, + ), + # circuits filtered by derivation type on the used (source) side + "used_derivation": lambda q: q.outerjoin( + used_derivation_alias, + db_model_class.id == used_derivation_alias.used_id, + ), "used": lambda q: q.outerjoin( Usage, db_model_class.id == Usage.usage_activity_id ).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id), diff --git a/app/schemas/circuit.py b/app/schemas/circuit.py index e510f992f..61052a14b 100644 --- a/app/schemas/circuit.py +++ b/app/schemas/circuit.py @@ -1,7 +1,8 @@ import uuid -from app.db.types import CircuitBuildCategory, CircuitScale, TargetSimulator -from app.schemas.base import NameDescriptionMixin +from app.db.types import CircuitBuildCategory, CircuitScale, DerivationType, TargetSimulator +from app.schemas.base import NameDescriptionMixin, Schema +from app.schemas.entity import NestedEntityRead from app.schemas.scientific_artifact import ScientificArtifactCreate, ScientificArtifactRead from app.schemas.utils import make_update_schema @@ -28,6 +29,32 @@ class CircuitRead(CircuitBaseMixin, ScientificArtifactRead): pass +class CircuitGeneratedDerivationRead(Schema): + """A derivation where the circuit is the generated (derived) entity.""" + + used: NestedEntityRead + derivation_type: DerivationType + label: str | None = None + + +class CircuitUsedDerivationRead(Schema): + """A derivation where the circuit is the used (source) entity.""" + + generated: NestedEntityRead + derivation_type: DerivationType + label: str | None = None + + +class CircuitExpandedRead(CircuitRead): + """Circuit read schema with on-demand derivation lists (see `expand` query param). + + A direction that was not expanded serializes as ``null`` (load-aware property on the model). + """ + + generated_derivations: list[CircuitGeneratedDerivationRead] | None = None + used_derivations: list[CircuitUsedDerivationRead] | None = None + + class CircuitCreate(CircuitBaseMixin, ScientificArtifactCreate): pass diff --git a/app/service/circuit.py b/app/service/circuit.py index 3be0401dd..1fb7f9f27 100644 --- a/app/service/circuit.py +++ b/app/service/circuit.py @@ -1,13 +1,17 @@ import uuid -from typing import TYPE_CHECKING +from enum import StrEnum, auto +from functools import partial +from typing import TYPE_CHECKING, Annotated import sqlalchemy as sa +from fastapi import Query from sqlalchemy.orm import aliased, joinedload, raiseload, selectinload from app.db.model import ( Agent, Circuit, Contribution, + Derivation, Person, Subject, ) @@ -31,6 +35,7 @@ from app.schemas.circuit import ( CircuitAdminUpdate, CircuitCreate, + CircuitExpandedRead, CircuitRead, CircuitUserUpdate, ) @@ -41,8 +46,15 @@ from app.filters.base import Aliases -def _load(query: sa.Select): - return query.options( +class ExpandableAttribute(StrEnum): + """Derivation lists that can be loaded on demand via the `expand` query param.""" + + generated_derivations = auto() + used_derivations = auto() + + +def _load(query: sa.Select, *, expand: set[ExpandableAttribute] | None = None): + query = query.options( joinedload(Circuit.license), joinedload(Circuit.subject).options( joinedload(Subject.species), @@ -58,6 +70,15 @@ def _load(query: sa.Select): selectinload(Circuit.assets), raiseload("*"), ) + if expand and ExpandableAttribute.generated_derivations in expand: + query = query.options( + selectinload(Circuit.derivations_as_generated).joinedload(Derivation.used) + ) + if expand and ExpandableAttribute.used_derivations in expand: + query = query.options( + selectinload(Circuit.derivations_as_used).joinedload(Derivation.generated) + ) + return query def read_one( @@ -149,12 +170,15 @@ def _read_many( with_search: SearchDep, facets: FacetsDep, in_brain_region: InBrainRegionDep, + expand: set[ExpandableAttribute] | None, check_authorized_project: bool, -) -> ListResponse[CircuitRead]: +) -> ListResponse[CircuitRead | CircuitExpandedRead]: subject_alias = aliased(Subject, flat=True) agent_alias = aliased(Agent, flat=True) created_by_alias = aliased(Person, flat=True) updated_by_alias = aliased(Person, flat=True) + generated_derivation_alias = aliased(Derivation, flat=True) + used_derivation_alias = aliased(Derivation, flat=True) aliases: Aliases = { Subject: subject_alias, @@ -165,6 +189,10 @@ def _read_many( "created_by": created_by_alias, "updated_by": updated_by_alias, }, + Derivation: { + "generated_derivation": generated_derivation_alias, + "used_derivation": used_derivation_alias, + }, } facet_keys = [ "brain_region", @@ -177,6 +205,8 @@ def _read_many( filter_keys = [ "subject", *facet_keys, + "generated_derivation", + "used_derivation", ] name_to_facet_query_params, filter_joins = query_params_factory( db_model_class=Circuit, @@ -184,6 +214,7 @@ def _read_many( filter_keys=filter_keys, aliases=aliases, ) + response_schema_class = CircuitExpandedRead if expand else CircuitRead return router_read_many( db=db, filter_model=filter_model, @@ -193,10 +224,10 @@ def _read_many( facets=facets, name_to_facet_query_params=name_to_facet_query_params, apply_filter_query_operations=None, - apply_data_query_operations=_load, + apply_data_query_operations=partial(_load, expand=expand), aliases=aliases, pagination_request=pagination_request, - response_schema_class=CircuitRead, + response_schema_class=response_schema_class, authorized_project_id=user_context.project_id, filter_joins=filter_joins, check_authorized_project=check_authorized_project, @@ -211,7 +242,8 @@ def read_many( with_search: SearchDep, facets: FacetsDep, in_brain_region: InBrainRegionDep, -) -> ListResponse[CircuitRead]: + expand: Annotated[set[ExpandableAttribute] | None, Query()] = None, +) -> ListResponse[CircuitRead | CircuitExpandedRead]: return _read_many( user_context=user_context, db=db, @@ -220,6 +252,7 @@ def read_many( with_search=with_search, facets=facets, in_brain_region=in_brain_region, + expand=expand, check_authorized_project=True, ) @@ -232,7 +265,8 @@ def admin_read_many( with_search: SearchDep, facets: FacetsDep, in_brain_region: InBrainRegionDep, -) -> ListResponse[CircuitRead]: + expand: Annotated[set[ExpandableAttribute] | None, Query()] = None, +) -> ListResponse[CircuitRead | CircuitExpandedRead]: return _read_many( user_context=user_context, db=db, @@ -241,6 +275,7 @@ def admin_read_many( with_search=with_search, facets=facets, in_brain_region=in_brain_region, + expand=expand, check_authorized_project=False, ) diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 30c556676..2d45332ba 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -2,6 +2,7 @@ from app.db.model import ( Circuit, + Derivation, ExternalUrl, Publication, ScientificArtifactExternalUrlLink, @@ -10,6 +11,7 @@ from app.db.types import ( CircuitBuildCategory, CircuitScale, + DerivationType, EntityType, ExternalSource, PublicationType, @@ -389,3 +391,176 @@ def test_filtering(client, root_circuit, models): params={"lifecycle_status": "active", "root_circuit_id": str(root_circuit.id)}, ).json()["data"] assert len(data) == len(models) + + +def _add_derivation( + db, + *, + used_id, + generated_id, + person_id, + derivation_type=DerivationType.circuit_extraction, + label=None, +): + return add_db( + db, + Derivation( + used_id=used_id, + generated_id=generated_id, + derivation_type=derivation_type, + label=label, + created_by_id=person_id, + updated_by_id=person_id, + ), + ) + + +def test_filter_by_derivation_type(db, client, root_circuit, circuit, public_circuit, person_id): + """Filter circuits by derivation type on both the generated and used sides.""" + # circuit is derived from root_circuit via extraction; public_circuit via rewiring. + # => root_circuit is the `used` (source) side of both derivations. + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + ) + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=public_circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_rewiring, + ) + + def ids(params): + return { + d["id"] for d in assert_request(client.get, url=ROUTE, params=params).json()["data"] + } + + # --- generated side: "how this circuit was derived" + assert ids({"generated_derivation__derivation_type": "circuit_extraction"}) == {str(circuit.id)} + assert ids({"generated_derivation__derivation_type": "circuit_rewiring"}) == { + str(public_circuit.id) + } + matched = ids( + {"generated_derivation__derivation_type__in": ["circuit_extraction", "circuit_rewiring"]} + ) + assert matched == {str(circuit.id), str(public_circuit.id)} + # the underived root_circuit is not on the generated side of any derivation + assert str(root_circuit.id) not in matched + assert ids({"generated_derivation__derivation_type": "circuit_customization"}) == set() + + # --- used side: "what was derived from this circuit" + assert ids({"used_derivation__derivation_type": "circuit_extraction"}) == {str(root_circuit.id)} + assert ids({"used_derivation__derivation_type": "circuit_rewiring"}) == {str(root_circuit.id)} + assert ids( + {"used_derivation__derivation_type__in": ["circuit_extraction", "circuit_rewiring"]} + ) == {str(root_circuit.id)} + # the derived children are not on the used side of these derivations + assert ids({"used_derivation__derivation_type": "circuit_customization"}) == set() + + # the filter alone does not add derivation columns to the response + data = assert_request( + client.get, + url=ROUTE, + params={"generated_derivation__derivation_type": "circuit_extraction"}, + ).json()["data"] + assert "generated_derivations" not in data[0] + assert "used_derivations" not in data[0] + + +def test_filter_by_derivation_type_non_circuit_source(db, client, circuit, emodel_id, person_id): + """The filter has no source-type restriction: an emodel->circuit derivation matches.""" + _add_derivation( + db, + used_id=emodel_id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.emodel_circuit, + ) + data = assert_request( + client.get, url=ROUTE, params={"generated_derivation__derivation_type": "emodel_circuit"} + ).json()["data"] + assert {d["id"] for d in data} == {str(circuit.id)} + + +def test_expand_derivations(db, client, root_circuit, circuit, person_id): + """`expand` opts into derivation lists, per direction, independently.""" + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + label="extracted", + ) + + def get_by_id(entity_id, params=None): + data = assert_request(client.get, url=ROUTE, params=params or {}).json()["data"] + return next(d for d in data if d["id"] == str(entity_id)) + + # no expand -> fields are absent entirely (no extra query, no null columns) + child = get_by_id(circuit.id) + assert "generated_derivations" not in child + assert "used_derivations" not in child + + # expand=generated_derivations -> populated on the child; other direction is null + child = get_by_id(circuit.id, {"expand": "generated_derivations"}) + assert child["used_derivations"] is None + assert len(child["generated_derivations"]) == 1 + entry = child["generated_derivations"][0] + assert entry["used"]["id"] == str(root_circuit.id) + assert entry["used"]["type"] == EntityType.circuit + assert entry["derivation_type"] == DerivationType.circuit_extraction + assert entry["label"] == "extracted" + + # expand=used_derivations -> populated on the parent; other direction is null + parent = get_by_id(root_circuit.id, {"expand": "used_derivations"}) + assert parent["generated_derivations"] is None + assert len(parent["used_derivations"]) == 1 + entry = parent["used_derivations"][0] + assert entry["generated"]["id"] == str(circuit.id) + assert entry["generated"]["type"] == EntityType.circuit + assert entry["derivation_type"] == DerivationType.circuit_extraction + assert entry["label"] == "extracted" + + # expand both -> a direction with no derivations is an empty list (not null) + params = {"expand": ["generated_derivations", "used_derivations"]} + child = get_by_id(circuit.id, params) + parent = get_by_id(root_circuit.id, params) + assert len(child["generated_derivations"]) == 1 + assert child["used_derivations"] == [] + assert parent["generated_derivations"] == [] + assert len(parent["used_derivations"]) == 1 + + +def test_filter_and_expand_combined(db, client, root_circuit, circuit, public_circuit, person_id): + """Filtering and expanding compose: filtered rows carry the expanded columns.""" + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + ) + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=public_circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_rewiring, + ) + + data = assert_request( + client.get, + url=ROUTE, + params={ + "generated_derivation__derivation_type": "circuit_extraction", + "expand": "generated_derivations", + }, + ).json()["data"] + assert {d["id"] for d in data} == {str(circuit.id)} + assert data[0]["generated_derivations"][0]["used"]["id"] == str(root_circuit.id) + assert data[0]["used_derivations"] is None From 220104f49a1e506eabef799b9650f81012637bdb Mon Sep 17 00:00:00 2001 From: Pavlo Getta Date: Thu, 25 Jun 2026 09:40:29 +0200 Subject: [PATCH 2/3] Fix list pagination for one-to-many filter joins The paginated id subquery in `_with_subquery` selected ids without DISTINCT, so a one-to-many filter join (e.g. the circuit derivation-type filters, where a circuit can match several Derivation rows) duplicated the row. OFFSET/LIMIT then paged over the duplicated rows while total_items used count(distinct id), so a matching entity could repeat across pages or be skipped. Select DISTINCT ids in the subquery so the limit window operates on distinct entities. This also hardens the pre-existing contribution/mtype/used/generated one-to-many filters. DISTINCT is valid here because every ORDER BY element is already part of the subquery select list. Add a regression test asserting a circuit matched by multiple derivation rows is counted once and does not reappear on a second page. --- app/queries/common.py | 7 ++++++- tests/test_circuit.py | 41 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/app/queries/common.py b/app/queries/common.py index bdaaf4f7b..0bfbf29c0 100644 --- a/app/queries/common.py +++ b/app/queries/common.py @@ -253,7 +253,12 @@ def _with_subquery[I: Identifiable]( select_cols = [db_model_class.id] + [ element.label(label_name) for (label_name, element, _) in labeled_sort_columns ] - subq = data_query.with_only_columns(*select_cols).subquery() + # DISTINCT collapses the duplicate id rows that a one-to-many filter join produces + # (e.g. the derivation-type filters, where a circuit can match several Derivation rows). + # Without it, OFFSET/LIMIT would page over the duplicated rows while total_items uses + # count(distinct id), so a circuit could repeat across pages or be skipped. Every ORDER BY + # element is included in the select list above, so SELECT DISTINCT is always valid here. + subq = data_query.with_only_columns(*select_cols).distinct().subquery() outer_order_bys = [] for label_name, _, modifier in labeled_sort_columns: diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 2d45332ba..a2168ddde 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -564,3 +564,44 @@ def test_filter_and_expand_combined(db, client, root_circuit, circuit, public_ci assert {d["id"] for d in data} == {str(circuit.id)} assert data[0]["generated_derivations"][0]["used"]["id"] == str(root_circuit.id) assert data[0]["used_derivations"] is None + + +def test_derivation_filter_pagination_no_duplicates( + db, client, root_circuit, circuit, public_circuit, person_id +): + """A circuit matching multiple derivation rows must not repeat across pages. + + root_circuit is the `used` source of two extraction derivations, so the one-to-many + filter join yields two rows for it. Pagination must still treat it as a single circuit: + total_items counts distinct circuits and the duplicate join row must not spill it onto a + second page (regression for the derivation-filter pagination bug). + """ + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + ) + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=public_circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + ) + + params = {"used_derivation__derivation_type": "circuit_extraction"} + + # only root_circuit is the `used` side of an extraction derivation, despite two join rows + first = assert_request( + client.get, url=ROUTE, params={**params, "page": 1, "page_size": 1} + ).json() + assert first["pagination"]["total_items"] == 1 + assert [d["id"] for d in first["data"]] == [str(root_circuit.id)] + + # the duplicate join row must not place root_circuit on a second page too + second = assert_request( + client.get, url=ROUTE, params={**params, "page": 2, "page_size": 1} + ).json() + assert second["data"] == [] From 2abbe838125f4d61d4c6aebcf85a6355f3a0e468 Mon Sep 17 00:00:00 2001 From: Pavlo Getta Date: Fri, 26 Jun 2026 10:04:53 +0200 Subject: [PATCH 3/3] Support expand on circuit get-one endpoints Wire the `expand` query param into read_one and admin_read_one, mirroring read_many/admin_read_many: switch the response schema to CircuitExpandedRead when expand is set and pass partial(_load, expand=expand). Add a test exercising expand on both the user and admin get-one routes. --- app/service/circuit.py | 14 ++++++++------ tests/test_circuit.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/app/service/circuit.py b/app/service/circuit.py index 1fb7f9f27..83ad72b57 100644 --- a/app/service/circuit.py +++ b/app/service/circuit.py @@ -85,28 +85,30 @@ def read_one( user_context: UserContextDep, db: SessionDep, id_: uuid.UUID, -) -> CircuitRead: + expand: Annotated[set[ExpandableAttribute] | None, Query()] = None, +) -> CircuitRead | CircuitExpandedRead: return router_read_one( db=db, id_=id_, db_model_class=Circuit, user_context=user_context, - response_schema_class=CircuitRead, - apply_operations=_load, + response_schema_class=CircuitExpandedRead if expand else CircuitRead, + apply_operations=partial(_load, expand=expand), ) def admin_read_one( db: SessionDep, id_: uuid.UUID, -) -> CircuitRead: + expand: Annotated[set[ExpandableAttribute] | None, Query()] = None, +) -> CircuitRead | CircuitExpandedRead: return router_read_one( db=db, id_=id_, db_model_class=Circuit, user_context=None, - response_schema_class=CircuitRead, - apply_operations=_load, + response_schema_class=CircuitExpandedRead if expand else CircuitRead, + apply_operations=partial(_load, expand=expand), ) diff --git a/tests/test_circuit.py b/tests/test_circuit.py index a2168ddde..0598f7453 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -536,6 +536,45 @@ def get_by_id(entity_id, params=None): assert len(parent["used_derivations"]) == 1 +def test_expand_derivations_read_one(db, client, client_admin, root_circuit, circuit, person_id): + """`expand` opts into derivation lists on the get-one endpoint too (user + admin).""" + _add_derivation( + db, + used_id=root_circuit.id, + generated_id=circuit.id, + person_id=person_id, + derivation_type=DerivationType.circuit_extraction, + label="extracted", + ) + + # no expand -> derivation fields are absent entirely + data = assert_request(client.get, url=f"{ROUTE}/{circuit.id}").json() + assert "generated_derivations" not in data + assert "used_derivations" not in data + + # expand=generated_derivations on the child -> populated; other direction is null + data = assert_request( + client.get, url=f"{ROUTE}/{circuit.id}", params={"expand": "generated_derivations"} + ).json() + assert data["used_derivations"] is None + assert len(data["generated_derivations"]) == 1 + entry = data["generated_derivations"][0] + assert entry["used"]["id"] == str(root_circuit.id) + assert entry["used"]["type"] == EntityType.circuit + assert entry["derivation_type"] == DerivationType.circuit_extraction + assert entry["label"] == "extracted" + + # the admin get-one endpoint supports expand too + data = assert_request( + client_admin.get, + url=f"{ADMIN_ROUTE}/{root_circuit.id}", + params={"expand": "used_derivations"}, + ).json() + assert data["generated_derivations"] is None + assert len(data["used_derivations"]) == 1 + assert data["used_derivations"][0]["generated"]["id"] == str(circuit.id) + + def test_filter_and_expand_combined(db, client, root_circuit, circuit, public_circuit, person_id): """Filtering and expanding compose: filtered rows carry the expanded columns.""" _add_derivation(