Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions app/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1930,6 +1930,39 @@ class Circuit(ScientificArtifact, NameDescriptionVectorMixin):

# calibration_data (multiple entities): ...

# View-only relationships to the Derivation rows that reference this circuit.
# They are loaded on demand (see app/service/circuit.py `expand`) and read through the
# load-aware properties below, so an un-expanded direction serializes as null instead of
# tripping `raiseload`.
derivations_as_generated: Mapped[list["Derivation"]] = relationship(
"Derivation",
primaryjoin=lambda: Circuit.id == Derivation.generated_id,
foreign_keys=lambda: [Derivation.generated_id],
viewonly=True,
overlaps="generated",
)
derivations_as_used: Mapped[list["Derivation"]] = relationship(
"Derivation",
primaryjoin=lambda: Circuit.id == Derivation.used_id,
foreign_keys=lambda: [Derivation.used_id],
viewonly=True,
overlaps="used",
)

@property
def generated_derivations(self) -> list["Derivation"] | None:
"""Derivations where this circuit is the generated entity, or None if not expanded."""
if "derivations_as_generated" in sa.inspect(self).unloaded:
return None
return self.derivations_as_generated

@property
def used_derivations(self) -> list["Derivation"] | None:
"""Derivations where this circuit is the used entity, or None if not expanded."""
if "derivations_as_used" in sa.inspect(self).unloaded:
return None
return self.derivations_as_used

@declared_attr.directive
@classmethod
def __table_args__(cls): # noqa: D105, PLW3201
Expand Down
25 changes: 23 additions & 2 deletions app/filters/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,24 @@

from fastapi_filter import with_prefix

from app.db.model import Circuit
from app.db.types import CircuitBuildCategory, CircuitScale, TargetSimulator
from app.db.model import Circuit, Derivation
from app.db.types import CircuitBuildCategory, CircuitScale, DerivationType, TargetSimulator
from app.dependencies.filter import FilterDepends
from app.filters.base import CustomFilter
from app.filters.common import IdFilterMixin, ILikeSearchFilterMixin, NameFilterMixin
from app.filters.scientific_artifact import ScientificArtifactFilter


class NestedDerivationFilter(CustomFilter):
"""Filter circuits by derivation type, on either the generated or used side."""

derivation_type: DerivationType | None = None
derivation_type__in: list[DerivationType] | None = None

class Constants(CustomFilter.Constants):
model = Derivation


class CircuitFilterMixin:
scale: CircuitScale | None = None
scale__in: list[CircuitScale] | None = None
Expand Down Expand Up @@ -56,6 +66,17 @@ class CircuitFilter(
number_connections__lte: int | None = None
number_connections__gte: int | None = None

# derivations where the circuit is the generated (derived) entity: "how it was derived"
generated_derivation: Annotated[
NestedDerivationFilter | None,
FilterDepends(with_prefix("generated_derivation", NestedDerivationFilter)),
] = None
# derivations where the circuit is the used (source) entity: "what was derived from it"
used_derivation: Annotated[
NestedDerivationFilter | None,
FilterDepends(with_prefix("used_derivation", NestedDerivationFilter)),
] = None

order_by: list[str] = ["-creation_date"] # noqa: RUF012

class Constants(ScientificArtifactFilter.Constants):
Expand Down
7 changes: 6 additions & 1 deletion app/queries/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,12 @@ def _with_subquery[I: Identifiable](
select_cols = [db_model_class.id] + [
element.label(label_name) for (label_name, element, _) in labeled_sort_columns
]
subq = data_query.with_only_columns(*select_cols).subquery()
# DISTINCT collapses the duplicate id rows that a one-to-many filter join produces
# (e.g. the derivation-type filters, where a circuit can match several Derivation rows).
# Without it, OFFSET/LIMIT would page over the duplicated rows while total_items uses
# count(distinct id), so a circuit could repeat across pages or be skipped. Every ORDER BY
# element is included in the select list above, so SELECT DISTINCT is always valid here.
subq = data_query.with_only_columns(*select_cols).distinct().subquery()

outer_order_bys = []
for label_name, _, modifier in labeled_sort_columns:
Expand Down
13 changes: 13 additions & 0 deletions app/queries/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CellMorphologyProtocol,
Circuit,
Contribution,
Derivation,
EMCellMesh,
EMDenseReconstructionDataset,
EModel,
Expand Down Expand Up @@ -94,6 +95,8 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
used_alias = _get_alias(Entity, "used")
generated_alias = _get_alias(Entity, "generated")
circuit_alias = _get_alias(Circuit)
generated_derivation_alias = _get_alias(Derivation, "generated_derivation")
used_derivation_alias = _get_alias(Derivation, "used_derivation")
ion_channel_alias = _get_alias(IonChannel)
em_dense_reconstruction_dataset_alias = _get_alias(EMDenseReconstructionDataset)
ion_channel_model_alias = _get_alias(IonChannelModel, "ion_channel_model")
Expand Down Expand Up @@ -274,6 +277,16 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
)
== circuit_alias.id,
),
# circuits filtered by derivation type on the generated (derived) side
"generated_derivation": lambda q: q.outerjoin(
generated_derivation_alias,
db_model_class.id == generated_derivation_alias.generated_id,
),
# circuits filtered by derivation type on the used (source) side
"used_derivation": lambda q: q.outerjoin(
used_derivation_alias,
db_model_class.id == used_derivation_alias.used_id,
),
"used": lambda q: q.outerjoin(
Usage, db_model_class.id == Usage.usage_activity_id
).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id),
Expand Down
31 changes: 29 additions & 2 deletions app/schemas/circuit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import uuid

from app.db.types import CircuitBuildCategory, CircuitScale, TargetSimulator
from app.schemas.base import NameDescriptionMixin
from app.db.types import CircuitBuildCategory, CircuitScale, DerivationType, TargetSimulator
from app.schemas.base import NameDescriptionMixin, Schema
from app.schemas.entity import NestedEntityRead
from app.schemas.scientific_artifact import ScientificArtifactCreate, ScientificArtifactRead
from app.schemas.utils import make_update_schema

Expand All @@ -28,6 +29,32 @@ class CircuitRead(CircuitBaseMixin, ScientificArtifactRead):
pass


class CircuitGeneratedDerivationRead(Schema):
"""A derivation where the circuit is the generated (derived) entity."""

used: NestedEntityRead
derivation_type: DerivationType
label: str | None = None


class CircuitUsedDerivationRead(Schema):
"""A derivation where the circuit is the used (source) entity."""

generated: NestedEntityRead
derivation_type: DerivationType
label: str | None = None


class CircuitExpandedRead(CircuitRead):
"""Circuit read schema with on-demand derivation lists (see `expand` query param).

A direction that was not expanded serializes as ``null`` (load-aware property on the model).
"""

generated_derivations: list[CircuitGeneratedDerivationRead] | None = None
used_derivations: list[CircuitUsedDerivationRead] | None = None


class CircuitCreate(CircuitBaseMixin, ScientificArtifactCreate):
pass

Expand Down
65 changes: 51 additions & 14 deletions app/service/circuit.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import uuid
from typing import TYPE_CHECKING
from enum import StrEnum, auto
from functools import partial
from typing import TYPE_CHECKING, Annotated

import sqlalchemy as sa
from fastapi import Query
from sqlalchemy.orm import aliased, joinedload, raiseload, selectinload

from app.db.model import (
Agent,
Circuit,
Contribution,
Derivation,
Person,
Subject,
)
Expand All @@ -31,6 +35,7 @@
from app.schemas.circuit import (
CircuitAdminUpdate,
CircuitCreate,
CircuitExpandedRead,
CircuitRead,
CircuitUserUpdate,
)
Expand All @@ -41,8 +46,15 @@
from app.filters.base import Aliases


def _load(query: sa.Select):
return query.options(
class ExpandableAttribute(StrEnum):
"""Derivation lists that can be loaded on demand via the `expand` query param."""

generated_derivations = auto()
used_derivations = auto()


def _load(query: sa.Select, *, expand: set[ExpandableAttribute] | None = None):
query = query.options(
joinedload(Circuit.license),
joinedload(Circuit.subject).options(
joinedload(Subject.species),
Expand All @@ -58,34 +70,45 @@ def _load(query: sa.Select):
selectinload(Circuit.assets),
raiseload("*"),
)
if expand and ExpandableAttribute.generated_derivations in expand:
query = query.options(
selectinload(Circuit.derivations_as_generated).joinedload(Derivation.used)
)
if expand and ExpandableAttribute.used_derivations in expand:
query = query.options(
selectinload(Circuit.derivations_as_used).joinedload(Derivation.generated)
)
return query

@GianlucaFicarelli GianlucaFicarelli Jun 25, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For discussion: it would be nice if the possibility to filter by derivation type and get the list of derivations is made available to any entity.
However, it would require changes:

  • to the _load function in every entity (change to many files, with many repetitions, so I would avoid that)
  • or some changes in queries/common.py for read_many/read_one

Thoughts? Could it be useful already for other derivation types?
If the change adds too much complexity, the generalization can be postponed.



def read_one(
user_context: UserContextDep,
db: SessionDep,
id_: uuid.UUID,
) -> CircuitRead:
expand: Annotated[set[ExpandableAttribute] | None, Query()] = None,
) -> CircuitRead | CircuitExpandedRead:
return router_read_one(
db=db,
id_=id_,
db_model_class=Circuit,
user_context=user_context,
response_schema_class=CircuitRead,
apply_operations=_load,
response_schema_class=CircuitExpandedRead if expand else CircuitRead,
apply_operations=partial(_load, expand=expand),
)


def admin_read_one(
db: SessionDep,
id_: uuid.UUID,
) -> CircuitRead:
expand: Annotated[set[ExpandableAttribute] | None, Query()] = None,
) -> CircuitRead | CircuitExpandedRead:
return router_read_one(
db=db,
id_=id_,
db_model_class=Circuit,
user_context=None,
response_schema_class=CircuitRead,
apply_operations=_load,
response_schema_class=CircuitExpandedRead if expand else CircuitRead,
apply_operations=partial(_load, expand=expand),
)


Expand Down Expand Up @@ -149,12 +172,15 @@ def _read_many(
with_search: SearchDep,
facets: FacetsDep,
in_brain_region: InBrainRegionDep,
expand: set[ExpandableAttribute] | None,
check_authorized_project: bool,
) -> ListResponse[CircuitRead]:
) -> ListResponse[CircuitRead | CircuitExpandedRead]:
subject_alias = aliased(Subject, flat=True)
agent_alias = aliased(Agent, flat=True)
created_by_alias = aliased(Person, flat=True)
updated_by_alias = aliased(Person, flat=True)
generated_derivation_alias = aliased(Derivation, flat=True)
used_derivation_alias = aliased(Derivation, flat=True)

aliases: Aliases = {
Subject: subject_alias,
Expand All @@ -165,6 +191,10 @@ def _read_many(
"created_by": created_by_alias,
"updated_by": updated_by_alias,
},
Derivation: {
"generated_derivation": generated_derivation_alias,
"used_derivation": used_derivation_alias,
},
}
facet_keys = [
"brain_region",
Expand All @@ -177,13 +207,16 @@ def _read_many(
filter_keys = [
"subject",
*facet_keys,
"generated_derivation",
"used_derivation",
]
name_to_facet_query_params, filter_joins = query_params_factory(
db_model_class=Circuit,
facet_keys=facet_keys,
filter_keys=filter_keys,
aliases=aliases,
)
response_schema_class = CircuitExpandedRead if expand else CircuitRead
return router_read_many(
db=db,
filter_model=filter_model,
Expand All @@ -193,10 +226,10 @@ def _read_many(
facets=facets,
name_to_facet_query_params=name_to_facet_query_params,
apply_filter_query_operations=None,
apply_data_query_operations=_load,
apply_data_query_operations=partial(_load, expand=expand),
aliases=aliases,
pagination_request=pagination_request,
response_schema_class=CircuitRead,
response_schema_class=response_schema_class,
authorized_project_id=user_context.project_id,
filter_joins=filter_joins,
check_authorized_project=check_authorized_project,
Expand All @@ -211,7 +244,8 @@ def read_many(
with_search: SearchDep,
facets: FacetsDep,
in_brain_region: InBrainRegionDep,
) -> ListResponse[CircuitRead]:
expand: Annotated[set[ExpandableAttribute] | None, Query()] = None,
) -> ListResponse[CircuitRead | CircuitExpandedRead]:
return _read_many(
user_context=user_context,
db=db,
Expand All @@ -220,6 +254,7 @@ def read_many(
with_search=with_search,
facets=facets,
in_brain_region=in_brain_region,
expand=expand,
check_authorized_project=True,
)

Expand All @@ -232,7 +267,8 @@ def admin_read_many(
with_search: SearchDep,
facets: FacetsDep,
in_brain_region: InBrainRegionDep,
) -> ListResponse[CircuitRead]:
expand: Annotated[set[ExpandableAttribute] | None, Query()] = None,
) -> ListResponse[CircuitRead | CircuitExpandedRead]:
return _read_many(
user_context=user_context,
db=db,
Expand All @@ -241,6 +277,7 @@ def admin_read_many(
with_search=with_search,
facets=facets,
in_brain_region=in_brain_region,
expand=expand,
check_authorized_project=False,
)

Expand Down
Loading