Skip to content
33 changes: 33 additions & 0 deletions app/db/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,39 @@ def __table_args__(cls): # noqa: D105, PLW3201
passive_deletes=False,
)

# View-only relationships to the Derivation rows that reference this entity.
# They are loaded on demand (see the `expand` query param) and read through the
# load-aware properties below, so an un-expanded direction serializes as null instead of
# tripping `raiseload`. Defined on Entity so every entity subclass inherits them.
derivations_as_generated: Mapped[list["Derivation"]] = relationship(
"Derivation",
primaryjoin=lambda: Entity.id == Derivation.generated_id,
foreign_keys=lambda: [Derivation.generated_id],
viewonly=True,
overlaps="generated",
)
derivations_as_used: Mapped[list["Derivation"]] = relationship(
"Derivation",
primaryjoin=lambda: Entity.id == Derivation.used_id,
foreign_keys=lambda: [Derivation.used_id],
viewonly=True,
overlaps="used",
)

@property
def generated_from_derivations(self) -> list["Derivation"] | None:
"""Derivations where this entity is the generated entity, or None if not expanded."""
if "derivations_as_generated" in sa.inspect(self).unloaded:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need that guard?
If the property is accessed without loading the derivations, it could be better to fail rather than returning None.
Removing the guard can be acceptable only if we are sure that the property is accessed only when required.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked better the code and these properties are accessed even when expand is empty, so the guard is needed.

return None
return self.derivations_as_generated

@property
def used_by_derivations(self) -> list["Derivation"] | None:
"""Derivations where this entity is the used entity, or None if not expanded."""
if "derivations_as_used" in sa.inspect(self).unloaded:
return None
return self.derivations_as_used

__mapper_args__ = { # noqa: RUF012
"polymorphic_identity": __tablename__,
"polymorphic_on": "type",
Expand Down
4 changes: 4 additions & 0 deletions app/dependencies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from app.errors import ApiError, ApiErrorCode
from app.filters.base import CustomFilter
from app.filters.brain_region import WithinBrainRegionDirection, filter_by_region
from app.queries.expand import EntityExpand
from app.queries.filter import filter_from_db
from app.queries.types import ApplyOperations
from app.schemas.types import Facet, Facets, PaginationRequest
Expand Down Expand Up @@ -233,3 +234,6 @@ class DerivationQuery(BaseModel):
SearchDep = Annotated[Search, Depends()]
InBrainRegionDep = Annotated[InBrainRegionQuery, Depends()]
DerivationQueryDep = Annotated[DerivationQuery, Depends()]
# `?expand=generated_from_derivations&expand=used_by_derivations` — available on every entity
# read endpoint.
ExpandDep = Annotated[set[EntityExpand] | None, Query()]
36 changes: 34 additions & 2 deletions app/filters/entity.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,36 @@
import uuid
from typing import Annotated

from fastapi_filter import with_prefix

from app.db.model import Entity
from app.db.types import EntityLifecycleStatus, EntityType
from app.db.model import Derivation, Entity
from app.db.types import DerivationType, EntityLifecycleStatus, EntityType
from app.dependencies.filter import FilterDepends
from app.filters.base import CustomFilter
from app.filters.common import AuthorizedFilterMixin, CreationFilterMixin, IdFilterMixin
from app.filters.person import CreatorFilterMixin


class NestedDerivationFilter(CustomFilter):
"""Filter entities by a related Derivation, on either the generated or used side.

Exposed on every entity filter as ``generated_derivation`` / ``used_derivation`` (see
EntityFilterMixin). Besides ``derivation_type``, the related-entity ids are filterable:
on ``generated_derivation`` the meaningful side is ``used_id`` ("derived from entity X"),
on ``used_derivation`` it is ``generated_id`` ("source of entity Y").
"""

derivation_type: DerivationType | None = None
derivation_type__in: list[DerivationType] | None = None
used_id: uuid.UUID | None = None
used_id__in: list[uuid.UUID] | None = None

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, this can be useful for selecting cell morphologies derived from one or more EMDenseReconstructionDataset with derivation_type em_dense_reconstruction_dataset_cell_morphology and support pagination.

See this comment

generated_id: uuid.UUID | None = None
generated_id__in: list[uuid.UUID] | None = None

class Constants(CustomFilter.Constants):
model = Derivation


class BasicEntityFilter(CustomFilter):
type: EntityType | None = None

Expand Down Expand Up @@ -45,3 +66,14 @@ class EntityFilterMixin(
ContributionFilterMixin,
):
lifecycle_status: EntityLifecycleStatus | None = None

# Derivations where this entity is the generated (derived) side: "how it was derived".
generated_derivation: Annotated[
NestedDerivationFilter | None,
FilterDepends(with_prefix("generated_derivation", NestedDerivationFilter)),
] = None
# Derivations where this entity is the used (source) side: "what was derived from it".
used_derivation: Annotated[
NestedDerivationFilter | None,
FilterDepends(with_prefix("used_derivation", NestedDerivationFilter)),
] = None
16 changes: 15 additions & 1 deletion app/queries/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
from collections.abc import Set as AbstractSet
from http import HTTPStatus

import sqlalchemy as sa
Expand Down Expand Up @@ -34,6 +35,7 @@
from app.filters.base import Aliases, CustomFilter
from app.queries import crud
from app.queries.constants import NESTED_RELATIONSHIPS_MAP
from app.queries.expand import apply_derivation_expand
from app.queries.filter import filter_from_db
from app.queries.types import ApplyOperations, SupportsModelValidate
from app.queries.utils import (
Expand All @@ -56,6 +58,7 @@ def router_read_one[T: Schema, I: Identifiable](
user_context: UserContext | None,
response_schema_class: SupportsModelValidate[T],
apply_operations: ApplyOperations[I] | None,
expand: AbstractSet[str] | None = None,
) -> T:
"""Read a model from the database.

Expand All @@ -66,6 +69,7 @@ def router_read_one[T: Schema, I: Identifiable](
user_context: the user context with project id and user information.
response_schema_class: Pydantic schema class for the returned data.
apply_operations: transformer function that modifies the select query.
expand: optional set of derivation directions to eager-load (entity models only).

Returns:
the model data as a Pydantic model.
Expand All @@ -81,6 +85,7 @@ def router_read_one[T: Schema, I: Identifiable](
)
if apply_operations:
query = apply_operations(query)
query = apply_derivation_expand(query, db_model_class, expand)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be applied only to entities?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It already is - apply_derivation_expand early-returns unless issubclass(db_model_class, Entity) (and when expand is empty), so it's a no-op for non-entity models.

with ensure_result(error_message=f"{db_model_class.__name__} not found"):
row = db.execute(query).unique().scalar_one()
return response_schema_class.model_validate(row)
Expand Down Expand Up @@ -253,7 +258,12 @@ def _with_subquery[I: Identifiable](
select_cols = [db_model_class.id] + [
element.label(label_name) for (label_name, element, _) in labeled_sort_columns
]
subq = data_query.with_only_columns(*select_cols).subquery()
# DISTINCT collapses the duplicate id rows that a one-to-many filter join produces
# (e.g. the derivation-type filters, where a circuit can match several Derivation rows).
# Without it, OFFSET/LIMIT would page over the duplicated rows while total_items uses
# count(distinct id), so a circuit could repeat across pages or be skipped. Every ORDER BY
# element is included in the select list above, so SELECT DISTINCT is always valid here.
subq = data_query.with_only_columns(*select_cols).distinct().subquery()

outer_order_bys = []
for label_name, _, modifier in labeled_sort_columns:
Expand Down Expand Up @@ -288,6 +298,7 @@ def router_read_many[T: Schema, I: Identifiable]( # noqa: PLR0913
filter_joins: dict[str, ApplyOperations] | None = None,
embedding: list[float] | None = None,
check_authorized_project: bool = True,
expand: AbstractSet[str] | None = None,
) -> ListResponse[T]:
"""Read multiple models from the database.

Expand All @@ -310,6 +321,7 @@ def router_read_many[T: Schema, I: Identifiable]( # noqa: PLR0913
- the keys in `name_to_facet_query_params`, for retrieving the facets.
embedding: optional list of floats representing an embedding vector for semantic search.
check_authorized_project: Whether to constrain or not to authorized entities
expand: optional set of derivation directions to eager-load (entity models only).

Returns:
the list of model data, pagination, and facets as a Pydantic model.
Expand Down Expand Up @@ -364,6 +376,8 @@ def router_read_many[T: Schema, I: Identifiable]( # noqa: PLR0913
data_query = _with_subquery(data_query=data_query, db_model_class=db_model_class)
data_query = apply_data_query_operations(data_query)

data_query = apply_derivation_expand(data_query, db_model_class, expand)

# unique is needed b/c it contains results that include joined eager loads against collections
data = db.execute(data_query).scalars().unique()

Expand Down
46 changes: 46 additions & 0 deletions app/queries/expand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Shared, entity-wide `expand` support for on-demand derivation lists.

Every entity read schema carries the load-aware ``generated_from_derivations`` /
``used_by_derivations`` fields (see app.schemas.entity.DerivationReadMixin); they serialize as
``null`` unless the matching direction was eagerly loaded. This module centralizes the enum the
read endpoints expose as the ``expand`` query param and the loader options that populate the
relationships.
"""

from collections.abc import Set as AbstractSet
from enum import StrEnum, auto

import sqlalchemy as sa
from sqlalchemy.orm import selectinload

from app.db.model import Derivation, Entity, Identifiable


class EntityExpand(StrEnum):
"""Derivation lists that any entity endpoint can load on demand via ``?expand=``."""

generated_from_derivations = auto()
used_by_derivations = auto()


def apply_derivation_expand(
query: sa.Select, db_model_class: type[Identifiable], expand: AbstractSet[str] | None
) -> sa.Select:
"""Eager-load the requested derivation directions onto an entity query.

A no-op for non-entity models and when nothing is requested. Adding the specific
``selectinload`` after a service's ``raiseload("*")`` is intentional: the more specific path
overrides the wildcard, so an un-expanded direction stays unloaded and its load-aware property
returns ``None`` instead of tripping ``raiseload``.
"""
if not expand or not issubclass(db_model_class, Entity):
return query
if EntityExpand.generated_from_derivations in expand:
query = query.options(
selectinload(db_model_class.derivations_as_generated).joinedload(Derivation.used)
)
if EntityExpand.used_by_derivations in expand:
query = query.options(
selectinload(db_model_class.derivations_as_used).joinedload(Derivation.generated)
)
return query
44 changes: 42 additions & 2 deletions app/queries/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Any, cast

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, aliased

from app.db.model import (
Agent,
Expand All @@ -10,6 +10,7 @@
CellMorphologyProtocol,
Circuit,
Contribution,
Derivation,
EMCellMesh,
EMDenseReconstructionDataset,
EModel,
Expand Down Expand Up @@ -47,6 +48,11 @@
from app.queries.types import ApplyOperations


def _is_entity_model(db_model_class: Any) -> bool:
"""Whether the model is an Entity subclass (kept separate to avoid narrowing the caller)."""
return isinstance(db_model_class, type) and issubclass(db_model_class, Entity)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it indeed the case that the type needs to be as broad as "Any"?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I've narrowed the param to object.

@GianlucaFicarelli GianlucaFicarelli Jun 30, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see advantages in using object instead of Any here.
db_model_class is defined as Any in the parameters of query_params_factory, so it should be narrowed there if possible, or left as Any in this PR.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now I've reverted that back to Any.



def query_params_factory[I: Identifiable](
db_model_class: Any, facet_keys: list[str], filter_keys: list[str], aliases: Aliases
) -> tuple[dict[str, FacetQueryParams], dict[str, ApplyOperations[I]]]:
Expand Down Expand Up @@ -74,6 +80,18 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
value = db_cls if name is None else value.get(name, db_cls)
return cast("T", value)

# The generated_derivation / used_derivation filters are inherited by every entity filter
# (see app.filters.entity.EntityFilterMixin). For entity models, register a distinct pair of
# Derivation aliases in `aliases` (mutating the dict the caller also hands to router_read_many)
# so the join lambdas below and `filter_model.filter(...)` resolve the *same* alias objects,
# and so the two directions don't collide on a single un-aliased Derivation table.
is_entity_model = _is_entity_model(db_model_class)
if is_entity_model and Derivation not in aliases:
aliases[Derivation] = {
"generated_derivation": aliased(Derivation, flat=True),
"used_derivation": aliased(Derivation, flat=True),
}

@GianlucaFicarelli GianlucaFicarelli Jun 26, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mutating the aliases dict seems a bit fragile, and unexpected by the caller.
An alternative could be to provide a function to build the aliases, that is called in each read_many instead of initializing the aliases dict directly.
However, it requires more changes and it would be easier to review if it's done in a separate PR (that can be also done after this one, I can have a look as well).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, Gianluca, I'll note this down - can look a bit later (in case you don't address it in the meantime).

morphology_alias = _get_alias(CellMorphology)
cell_morphology_protocol_alias = _get_alias(CellMorphologyProtocol)
emodel_alias = _get_alias(EModel)
Expand All @@ -94,6 +112,8 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
used_alias = _get_alias(Entity, "used")
generated_alias = _get_alias(Entity, "generated")
circuit_alias = _get_alias(Circuit)
generated_derivation_alias = _get_alias(Derivation, "generated_derivation")
used_derivation_alias = _get_alias(Derivation, "used_derivation")
ion_channel_alias = _get_alias(IonChannel)
em_dense_reconstruction_dataset_alias = _get_alias(EMDenseReconstructionDataset)
ion_channel_model_alias = _get_alias(IonChannelModel, "ion_channel_model")
Expand Down Expand Up @@ -274,6 +294,16 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
)
== circuit_alias.id,
),
# circuits filtered by derivation type on the generated (derived) side
"generated_derivation": lambda q: q.outerjoin(
generated_derivation_alias,
db_model_class.id == generated_derivation_alias.generated_id,
),
# circuits filtered by derivation type on the used (source) side
"used_derivation": lambda q: q.outerjoin(
used_derivation_alias,
db_model_class.id == used_derivation_alias.used_id,
),
"used": lambda q: q.outerjoin(
Usage, db_model_class.id == Usage.usage_activity_id
).outerjoin(used_alias, Usage.usage_entity_id == used_alias.id),
Expand Down Expand Up @@ -326,5 +356,15 @@ def _get_alias[T: type[DeclarativeBase]](db_cls: T, name: str | None = None) ->
),
}
name_to_facet_query_params = {k: name_to_facet_query_params[k] for k in facet_keys}
filter_joins = {k: filter_joins[k] for k in filter_keys}
# Every entity query gets the derivation join lambdas appended (as left joins, so order is
# safe), regardless of whether the service listed them. They are applied only when the
# corresponding nested filter is actually set (see app.queries.filter.filter_from_db).
selected_filter_keys = list(filter_keys)
if is_entity_model:
selected_filter_keys += [
key
for key in ("generated_derivation", "used_derivation")
if key not in selected_filter_keys
]
filter_joins = {k: filter_joins[k] for k in selected_filter_keys}
return name_to_facet_query_params, filter_joins
32 changes: 31 additions & 1 deletion app/schemas/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pydantic import RootModel

from app.db.types import EntityLifecycleStatus, EntityType
from app.db.types import DerivationType, EntityLifecycleStatus, EntityType
from app.schemas.asset import AssetsMixin
from app.schemas.base import AuthorizationOptionalPublicMixin, Schema
from app.schemas.identifiable import IdentifiableCreate, IdentifiableRead, NestedIdentifiableRead
Expand Down Expand Up @@ -35,13 +35,42 @@ class NestedEntityRead(NestedIdentifiableRead, EntityBaseReadMixin):
type: EntityType


class GeneratedDerivationRead(Schema):
"""A derivation where this entity is the generated (derived) entity."""

used: NestedEntityRead
derivation_type: DerivationType
label: str | None = None


class UsedDerivationRead(Schema):
"""A derivation where this entity is the used (source) entity."""

generated: NestedEntityRead
derivation_type: DerivationType
label: str | None = None


class DerivationReadMixin:
"""On-demand derivation lists, available on every entity read (see the `expand` query param).

A direction that was not expanded serializes as ``null`` (load-aware property on the Entity
model, so no extra query and `raiseload` is never tripped); an expanded-but-empty direction
serializes as ``[]``.
"""

generated_from_derivations: list[GeneratedDerivationRead] | None = None
used_by_derivations: list[UsedDerivationRead] | None = None


from app.schemas.contribution import ContributionReadWithoutEntityMixin # noqa: E402


class EntityReadWoutAssets(
IdentifiableRead,
EntityBaseReadMixin,
ContributionReadWithoutEntityMixin,
DerivationReadMixin,
):
"""Entity model that includes created_by and updated_by information."""

Expand All @@ -53,6 +82,7 @@ class EntityRead(
EntityBaseReadMixin,
AssetsMixin,
ContributionReadWithoutEntityMixin,
DerivationReadMixin,
):
"""Entity model that includes created_by and updated_by information."""

Expand Down
Loading