From 2da79282e4b1c1ab9e3d2d44601363fa2aebe09b Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 6 Mar 2026 14:51:32 +0000 Subject: [PATCH 1/6] feat: implement discovery domain with record search, feature catalog, and feature search Add the discovery bounded context providing read-only search and filtering over published records and hook-emitted feature data. Record search: metadata field filters with typed casts (Float for numbers, Date for dates), free-text search across text/URL fields, cursor-based keyset pagination, and configurable sort. Feature search: per-hook-table queries with schema-aware column access using local MetaData (no global cache pollution), operator validation against JSON Schema types, and row_id-based cursor encoding. Feature catalog: lists all feature tables with column schemas and record counts. Targeted get_feature_table_schema() for single-table lookup in the search path (avoids N+1 catalog scan). Includes defense-in-depth quoted_name for dynamic SQL identifiers, FeatureReader port for record-level feature enrichment, and discovery API surface documentation. 
Closes #77 --- .../versions/0d9fbacf8e58_initial_tables.py | 11 +- server/osa/application/api/rest/app.py | 2 + .../application/api/v1/routes/discovery.py | 134 ++++++ .../osa/application/api/v1/routes/records.py | 1 + server/osa/application/di.py | 2 + server/osa/domain/discovery/__init__.py | 1 + server/osa/domain/discovery/model/__init__.py | 0 server/osa/domain/discovery/model/value.py | 102 +++++ server/osa/domain/discovery/port/__init__.py | 0 .../discovery/port/field_definition_reader.py | 17 + .../osa/domain/discovery/port/read_store.py | 56 +++ server/osa/domain/discovery/query/__init__.py | 0 .../discovery/query/get_feature_catalog.py | 32 ++ .../domain/discovery/query/search_features.py | 51 +++ .../domain/discovery/query/search_records.py | 54 +++ .../osa/domain/discovery/service/__init__.py | 0 .../osa/domain/discovery/service/discovery.py | 206 +++++++++ server/osa/domain/discovery/util/__init__.py | 0 .../osa/domain/discovery/util/di/__init__.py | 3 + .../osa/domain/discovery/util/di/provider.py | 27 ++ .../osa/domain/record/port/feature_reader.py | 19 + server/osa/domain/record/query/get_record.py | 3 + server/osa/domain/record/service/record.py | 14 +- .../persistence/adapter/discovery.py | 395 ++++++++++++++++++ .../persistence/adapter/feature_reader.py | 61 +++ server/osa/infrastructure/persistence/di.py | 21 + .../osa/infrastructure/persistence/tables.py | 9 +- .../tests/unit/domain/discovery/__init__.py | 0 .../discovery/test_discovery_service.py | 328 +++++++++++++++ .../discovery/test_get_feature_catalog.py | 100 +++++ .../domain/discovery/test_search_features.py | 212 ++++++++++ .../domain/discovery/test_search_records.py | 69 +++ .../tests/unit/domain/discovery/test_value.py | 100 +++++ .../domain/record/test_get_record_handler.py | 1 + .../domain/record/test_record_features.py | 216 ++++++++++ .../unit/domain/record/test_record_service.py | 3 + .../test_field_definition_reader.py | 130 ++++++ 37 files changed, 2377 insertions(+), 3 
deletions(-) create mode 100644 server/osa/application/api/v1/routes/discovery.py create mode 100644 server/osa/domain/discovery/__init__.py create mode 100644 server/osa/domain/discovery/model/__init__.py create mode 100644 server/osa/domain/discovery/model/value.py create mode 100644 server/osa/domain/discovery/port/__init__.py create mode 100644 server/osa/domain/discovery/port/field_definition_reader.py create mode 100644 server/osa/domain/discovery/port/read_store.py create mode 100644 server/osa/domain/discovery/query/__init__.py create mode 100644 server/osa/domain/discovery/query/get_feature_catalog.py create mode 100644 server/osa/domain/discovery/query/search_features.py create mode 100644 server/osa/domain/discovery/query/search_records.py create mode 100644 server/osa/domain/discovery/service/__init__.py create mode 100644 server/osa/domain/discovery/service/discovery.py create mode 100644 server/osa/domain/discovery/util/__init__.py create mode 100644 server/osa/domain/discovery/util/di/__init__.py create mode 100644 server/osa/domain/discovery/util/di/provider.py create mode 100644 server/osa/domain/record/port/feature_reader.py create mode 100644 server/osa/infrastructure/persistence/adapter/discovery.py create mode 100644 server/osa/infrastructure/persistence/adapter/feature_reader.py create mode 100644 server/tests/unit/domain/discovery/__init__.py create mode 100644 server/tests/unit/domain/discovery/test_discovery_service.py create mode 100644 server/tests/unit/domain/discovery/test_get_feature_catalog.py create mode 100644 server/tests/unit/domain/discovery/test_search_features.py create mode 100644 server/tests/unit/domain/discovery/test_search_records.py create mode 100644 server/tests/unit/domain/discovery/test_value.py create mode 100644 server/tests/unit/domain/record/test_record_features.py create mode 100644 server/tests/unit/infrastructure/persistence/test_field_definition_reader.py diff --git 
a/server/migrations/versions/0d9fbacf8e58_initial_tables.py b/server/migrations/versions/0d9fbacf8e58_initial_tables.py index 3914472..697bdd5 100644 --- a/server/migrations/versions/0d9fbacf8e58_initial_tables.py +++ b/server/migrations/versions/0d9fbacf8e58_initial_tables.py @@ -10,6 +10,7 @@ import sqlalchemy as sa from alembic import op +from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision: str = "0d9fbacf8e58" @@ -53,13 +54,20 @@ def upgrade() -> None: "records", sa.Column("srn", sa.String(), nullable=False), sa.Column("deposition_srn", sa.String(), nullable=False), - sa.Column("metadata", sa.JSON(), nullable=False), + sa.Column("metadata", postgresql.JSONB(), nullable=False), sa.Column("indexes", sa.JSON(), nullable=False), sa.Column("published_at", sa.DateTime(timezone=True), nullable=False), sa.PrimaryKeyConstraint("srn"), ) op.create_index("idx_records_deposition_srn", "records", ["deposition_srn"]) op.create_index("idx_records_published_at", "records", ["published_at"]) + op.create_index( + "idx_records_metadata_gin", + "records", + ["metadata"], + postgresql_using="gin", + postgresql_ops={"metadata": "jsonb_path_ops"}, + ) # EVENTS (Outbox) op.create_table( @@ -89,6 +97,7 @@ def downgrade() -> None: op.drop_table("events") # RECORDS + op.drop_index("idx_records_metadata_gin", table_name="records") op.drop_index("idx_records_published_at", table_name="records") op.drop_index("idx_records_deposition_srn", table_name="records") op.drop_table("records") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 151c229..9924dd7 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -12,6 +12,7 @@ auth, conventions, depositions, + discovery, events, health, ontologies, @@ -89,6 +90,7 @@ def create_app() -> FastAPI: app_instance.include_router(conventions.router, prefix="/api/v1") app_instance.include_router(depositions.router, 
prefix="/api/v1") app_instance.include_router(validation.router, prefix="/api/v1") + app_instance.include_router(discovery.router, prefix="/api/v1") # Global OSA error handler - maps domain and infrastructure errors to HTTP responses @app_instance.exception_handler(OSAError) diff --git a/server/osa/application/api/v1/routes/discovery.py b/server/osa/application/api/v1/routes/discovery.py new file mode 100644 index 0000000..723fb01 --- /dev/null +++ b/server/osa/application/api/v1/routes/discovery.py @@ -0,0 +1,134 @@ +"""Discovery API routes — search and filter records and features.""" + +from typing import Any + +from dishka.integrations.fastapi import DishkaRoute, FromDishka +from fastapi import APIRouter +from pydantic import BaseModel, Field + +from osa.domain.discovery.model.value import ( + Filter, + SortOrder, +) +from osa.domain.discovery.query.get_feature_catalog import ( + GetFeatureCatalog, + GetFeatureCatalogHandler, + GetFeatureCatalogResult, +) +from osa.domain.discovery.query.search_features import ( + SearchFeatures, + SearchFeaturesHandler, + SearchFeaturesResult, +) +from osa.domain.discovery.query.search_records import ( + SearchRecords, + SearchRecordsHandler, + SearchRecordsResult, +) + +router = APIRouter( + prefix="/discovery", + tags=["discovery"], + route_class=DishkaRoute, +) + + +# ── Request / Response models ── + + +class RecordSearchRequest(BaseModel): + filters: list[Filter] = [] + q: str | None = None + sort: str = "published_at" + order: SortOrder = SortOrder.DESC + cursor: str | None = None + limit: int = Field(default=20, ge=1, le=100) + + +class RecordSearchResponse(BaseModel): + results: list[dict[str, Any]] + total: int + cursor: str | None + has_more: bool + + +class FeatureCatalogResponse(BaseModel): + tables: list[dict[str, Any]] + + +class FeatureSearchRequest(BaseModel): + filters: list[Filter] = [] + record_srn: str | None = None + sort: str = "id" + order: SortOrder = SortOrder.DESC + cursor: str | None = None + limit: 
int = Field(default=50, ge=1, le=100) + + +class FeatureSearchResponse(BaseModel): + rows: list[dict[str, Any]] + total: int + cursor: str | None + has_more: bool + + +# ── Routes ── + + +@router.post("/records") +async def search_records( + body: RecordSearchRequest, + handler: FromDishka[SearchRecordsHandler], +) -> RecordSearchResponse: + """Search and filter published records.""" + result: SearchRecordsResult = await handler.run( + SearchRecords( + filters=body.filters, + q=body.q, + sort=body.sort, + order=body.order, + cursor=body.cursor, + limit=body.limit, + ) + ) + return RecordSearchResponse( + results=result.results, + total=result.total, + cursor=result.cursor, + has_more=result.has_more, + ) + + +@router.get("/features") +async def get_feature_catalog( + handler: FromDishka[GetFeatureCatalogHandler], +) -> FeatureCatalogResponse: + """List available feature tables with column schemas and record counts.""" + result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) + return FeatureCatalogResponse(tables=result.tables) + + +@router.post("/features/{hook_name}") +async def search_features( + hook_name: str, + body: FeatureSearchRequest, + handler: FromDishka[SearchFeaturesHandler], +) -> FeatureSearchResponse: + """Query and filter rows in a specific feature table.""" + result: SearchFeaturesResult = await handler.run( + SearchFeatures( + hook_name=hook_name, + filters=body.filters, + record_srn=body.record_srn, + sort=body.sort, + order=body.order, + cursor=body.cursor, + limit=body.limit, + ) + ) + return FeatureSearchResponse( + rows=result.rows, + total=result.total, + cursor=result.cursor, + has_more=result.has_more, + ) diff --git a/server/osa/application/api/v1/routes/records.py b/server/osa/application/api/v1/routes/records.py index 06b57de..2e6b1b6 100644 --- a/server/osa/application/api/v1/routes/records.py +++ b/server/osa/application/api/v1/routes/records.py @@ -41,5 +41,6 @@ async def get_record( "deposition_srn": 
str(result.deposition_srn), "metadata": result.metadata, "published_at": result.published_at.isoformat(), + "features": result.features, } ) diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 321ce4c..a6d8759 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -3,6 +3,7 @@ from osa.cli.util.paths import OSAPaths from osa.config import Config from osa.domain.auth.util.di import AuthProvider +from osa.domain.discovery.util.di import DiscoveryProvider from osa.domain.deposition.util.di import DepositionProvider from osa.domain.feature.util.di import FeatureProvider from osa.domain.semantics.util.di.provider import SemanticsProvider @@ -37,6 +38,7 @@ def create_container() -> AsyncContainer: ValidationProvider(), AuthProvider(), AuthInfraProvider(), + DiscoveryProvider(), context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope class ) diff --git a/server/osa/domain/discovery/__init__.py b/server/osa/domain/discovery/__init__.py new file mode 100644 index 0000000..c6d5d02 --- /dev/null +++ b/server/osa/domain/discovery/__init__.py @@ -0,0 +1 @@ +"""Discovery domain — read-only search and filter API for records and features.""" diff --git a/server/osa/domain/discovery/model/__init__.py b/server/osa/domain/discovery/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/discovery/model/value.py b/server/osa/domain/discovery/model/value.py new file mode 100644 index 0000000..2a94c6d --- /dev/null +++ b/server/osa/domain/discovery/model/value.py @@ -0,0 +1,102 @@ +"""Discovery domain value objects — filters, cursors, result types.""" + +from __future__ import annotations + +import base64 +import json +from datetime import datetime +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel + +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.model.srn import RecordSRN + + +class 
def decode_cursor(cursor: str) -> dict[str, Any]:
    """Decode a base64 JSON cursor into its payload dict.

    The payload must be a JSON object containing the keys ``"s"`` (sort value)
    and ``"id"`` (tie-breaker id).

    Args:
        cursor: URL-safe base64 text wrapping a JSON object.

    Returns:
        The decoded payload dict.

    Raises:
        ValueError: If the cursor is not valid base64/JSON, or the decoded
            payload is not a dict with both ``"s"`` and ``"id"`` keys.
    """
    try:
        raw = base64.urlsafe_b64decode(cursor.encode())
        data = json.loads(raw)
    except (ValueError, RecursionError) as exc:
        # ValueError covers binascii.Error (bad padding), JSONDecodeError and
        # Unicode encode/decode errors; RecursionError covers a maliciously
        # deep JSON document. Anything else is a programming error and should
        # surface instead of being reported as a malformed cursor.
        raise ValueError(f"Malformed cursor: {exc}") from exc
    if not isinstance(data, dict) or "s" not in data or "id" not in data:
        raise ValueError("Cursor must contain 's' and 'id' keys")
    return data
total: int + cursor: str | None + has_more: bool diff --git a/server/osa/domain/discovery/port/__init__.py b/server/osa/domain/discovery/port/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/discovery/port/field_definition_reader.py b/server/osa/domain/discovery/port/field_definition_reader.py new file mode 100644 index 0000000..b763c8a --- /dev/null +++ b/server/osa/domain/discovery/port/field_definition_reader.py @@ -0,0 +1,17 @@ +"""FieldDefinitionReader port — cross-domain read port for schema field lookups.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from osa.domain.semantics.model.value import FieldType + + +class FieldDefinitionReader(Protocol): + async def get_all_field_types(self) -> dict[str, FieldType]: + """Return global field_name -> FieldType map across all schemas. + + Raises ValidationError if same field name has conflicting types across schemas. + """ + ... diff --git a/server/osa/domain/discovery/port/read_store.py b/server/osa/domain/discovery/port/read_store.py new file mode 100644 index 0000000..a0e351e --- /dev/null +++ b/server/osa/domain/discovery/port/read_store.py @@ -0,0 +1,56 @@ +"""DiscoveryReadStore port — read-only access to records and feature data.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: + from osa.domain.discovery.model.value import ( + FeatureCatalogEntry, + FeatureRow, + Filter, + RecordSummary, + SortOrder, + ) + from osa.domain.semantics.model.value import FieldType + from osa.domain.shared.model.srn import RecordSRN + + +class DiscoveryReadStore(Protocol): + async def search_records( + self, + filters: list[Filter], + text_fields: list[str], + q: str | None, + sort: str, + order: SortOrder, + cursor: dict | None, + limit: int, + field_types: dict[str, FieldType] | None = None, + ) -> tuple[list[RecordSummary], int]: + """Search and filter published records. 
Returns (results, total_count).""" + ... + + async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: + """List all feature tables with column schemas and record counts.""" + ... + + async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry | None: + """Look up a single feature table's schema by hook name. + + Returns None if the hook_name is not found. + """ + ... + + async def search_features( + self, + hook_name: str, + filters: list[Filter], + record_srn: RecordSRN | None, + sort: str, + order: SortOrder, + cursor: dict | None, + limit: int, + ) -> tuple[list[FeatureRow], int]: + """Search and filter feature rows. Returns (rows, total_count).""" + ... diff --git a/server/osa/domain/discovery/query/__init__.py b/server/osa/domain/discovery/query/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/osa/domain/discovery/query/get_feature_catalog.py b/server/osa/domain/discovery/query/get_feature_catalog.py new file mode 100644 index 0000000..30f891d --- /dev/null +++ b/server/osa/domain/discovery/query/get_feature_catalog.py @@ -0,0 +1,32 @@ +"""GetFeatureCatalog query — list available feature tables with schemas and counts.""" + +from osa.domain.discovery.model.value import FeatureCatalog +from osa.domain.discovery.service.discovery import DiscoveryService +from osa.domain.shared.authorization.gate import public +from osa.domain.shared.query import Query, QueryHandler, Result + + +class GetFeatureCatalog(Query): + pass + + +class GetFeatureCatalogResult(Result): + tables: list[dict] + + +class GetFeatureCatalogHandler(QueryHandler[GetFeatureCatalog, GetFeatureCatalogResult]): + __auth__ = public() + discovery_service: DiscoveryService + + async def run(self, cmd: GetFeatureCatalog) -> GetFeatureCatalogResult: + catalog: FeatureCatalog = await self.discovery_service.get_feature_catalog() + return GetFeatureCatalogResult( + tables=[ + { + "hook_name": entry.hook_name, + "columns": [c.model_dump() for c in 
class SearchFeaturesHandler(QueryHandler[SearchFeatures, SearchFeaturesResult]):
    """Public query handler that searches rows of a single feature table."""

    __auth__ = public()
    discovery_service: DiscoveryService

    async def run(self, cmd: SearchFeatures) -> SearchFeaturesResult:
        """Run the feature search and flatten each row into a plain dict."""
        # A missing or empty SRN string means "do not restrict by record".
        parsed_srn = None
        if cmd.record_srn:
            parsed_srn = RecordSRN.parse(cmd.record_srn)

        found: FeatureSearchResult = await self.discovery_service.search_features(
            hook_name=cmd.hook_name,
            filters=cmd.filters,
            record_srn=parsed_srn,
            sort=cmd.sort,
            order=cmd.order,
            cursor=cmd.cursor,
            limit=cmd.limit,
        )

        # Each output row is the feature data with the owning record's SRN
        # added; keys in the data dict take precedence on collision.
        flattened = []
        for feature_row in found.rows:
            flattened.append({"record_srn": str(feature_row.record_srn), **feature_row.data})

        return SearchFeaturesResult(
            rows=flattened,
            total=found.total,
            cursor=found.cursor,
            has_more=found.has_more,
        )
class SearchRecordsHandler(QueryHandler[SearchRecords, SearchRecordsResult]):
    """Public query handler that searches published records."""

    __auth__ = public()
    discovery_service: DiscoveryService

    async def run(self, cmd: SearchRecords) -> SearchRecordsResult:
        """Run the record search and serialise each summary to a plain dict."""
        page: RecordSearchResult = await self.discovery_service.search_records(
            filters=cmd.filters,
            q=cmd.q,
            sort=cmd.sort,
            order=cmd.order,
            cursor=cmd.cursor,
            limit=cmd.limit,
        )

        # SRNs and timestamps are rendered as strings for the result payload.
        serialised = []
        for summary in page.results:
            serialised.append(
                {
                    "srn": str(summary.srn),
                    "published_at": summary.published_at.isoformat(),
                    "metadata": summary.metadata,
                }
            )

        return SearchRecordsResult(
            results=serialised,
            total=page.total,
            cursor=page.cursor,
            has_more=page.has_more,
        )
class DiscoveryService(Service):
    """Orchestrates validation and delegation for discovery queries.

    Each search method validates the request (limit, filter fields and
    operators, sort key, cursor), delegates the query to the read store, and
    encodes a next-page cursor from the last returned item.
    """

    # Read-only store that executes the record and feature queries.
    read_store: DiscoveryReadStore
    # Supplies the global field_name -> FieldType map used for validation.
    field_reader: FieldDefinitionReader

    async def search_records(
        self,
        filters: list[Filter],
        q: str | None,
        sort: str,
        order: SortOrder,
        cursor: str | None,
        limit: int,
    ) -> RecordSearchResult:
        """Validate inputs and delegate record search to the read store.

        Args:
            filters: Metadata field filters; each field must be a known schema
                field and each operator must be valid for that field's type.
            q: Optional free-text query, passed through to the read store.
            sort: ``"published_at"`` or the name of a known schema field.
            order: Sort direction.
            cursor: Opaque cursor from a previous page, or ``None``.
            limit: Page size, between 1 and 100 inclusive.

        Returns:
            The page of results with the store-reported total, a next-page
            cursor (or ``None``) and a ``has_more`` flag.

        Raises:
            ValidationError: On an out-of-range limit, unknown filter or sort
                field, invalid operator, or malformed cursor.
        """
        if limit < 1 or limit > 100:
            raise ValidationError("limit must be between 1 and 100", field="limit")

        field_map = await self.field_reader.get_all_field_types()

        # Validate filter fields and operators
        for f in filters:
            if f.field not in field_map:
                raise ValidationError(
                    f"Unknown field '{f.field}': not defined in any registered schema",
                    field=f.field,
                )
            field_type = field_map[f.field]
            valid_ops = VALID_OPERATORS[field_type]
            if f.operator not in valid_ops:
                raise ValidationError(
                    f"Operator '{f.operator}' is not valid for field '{f.field}' "
                    f"(type '{field_type}'). Valid: {sorted(valid_ops)}",
                    field=f.field,
                )

        # Validate sort field
        if sort != "published_at" and sort not in field_map:
            raise ValidationError(
                f"Unknown sort field '{sort}': not defined in any registered schema",
                field="sort",
            )

        # Decode cursor
        decoded_cursor = None
        if cursor is not None:
            try:
                decoded_cursor = decode_cursor(cursor)
            except ValueError as exc:
                raise ValidationError(str(exc), field="cursor") from exc

        # Identify text-searchable fields for free-text q
        text_fields = [
            name for name, ft in field_map.items() if ft in (FieldType.TEXT, FieldType.URL)
        ]

        results, total = await self.read_store.search_records(
            filters=filters,
            text_fields=text_fields,
            q=q,
            sort=sort,
            order=order,
            cursor=decoded_cursor,
            limit=limit,
            field_types=field_map,
        )

        # A full page that is still smaller than the reported total means
        # there is at least one more row to fetch.
        has_more = len(results) == limit and len(results) < total
        next_cursor = None
        if has_more and results:
            last = results[-1]
            if sort == "published_at":
                sort_val = last.published_at.isoformat()
            else:
                sort_val = last.metadata.get(sort)
            next_cursor = encode_cursor(sort_val, str(last.srn))

        return RecordSearchResult(
            results=results,
            total=total,
            cursor=next_cursor,
            has_more=has_more,
        )

    async def get_feature_catalog(self) -> FeatureCatalog:
        """Delegate feature catalog listing to the read store."""
        entries = await self.read_store.get_feature_catalog()
        return FeatureCatalog(tables=entries)

    async def search_features(
        self,
        hook_name: str,
        filters: list[Filter],
        record_srn: RecordSRN | None,
        sort: str,
        order: SortOrder,
        cursor: str | None,
        limit: int,
    ) -> FeatureSearchResult:
        """Validate inputs and delegate feature search to the read store.

        Args:
            hook_name: Name of the feature table to query.
            filters: Column filters; each column must exist in the table
                schema (or be ``record_srn``) and each operator must be valid
                for the column's JSON type.
            record_srn: Optional record to restrict rows to.
            sort: ``"id"``, ``"record_srn"`` or a column of the table.
            order: Sort direction.
            cursor: Opaque cursor from a previous page, or ``None``/empty.
            limit: Page size, between 1 and 100 inclusive.

        Returns:
            The page of rows with the store-reported total, a next-page
            cursor (or ``None``) and a ``has_more`` flag.

        Raises:
            ValidationError: On an out-of-range limit, unknown column or sort
                key, invalid operator, or malformed cursor.
            NotFoundError: If no feature table exists for ``hook_name``.
        """
        if limit < 1 or limit > 100:
            raise ValidationError("limit must be between 1 and 100", field="limit")

        # Look up the feature table schema
        entry = await self.read_store.get_feature_table_schema(hook_name)
        if entry is None:
            raise NotFoundError(f"Feature table not found: {hook_name}")

        # Build column type map from catalog schema
        col_map: dict[str, str] = {col.name: col.type for col in entry.columns}
        # Also allow sort/filter on record_srn
        col_map["record_srn"] = "string"

        # Map JSON types to FieldType equivalents for operator validation
        json_type_to_ops: dict[str, set[FilterOperator]] = {
            "string": {FilterOperator.EQ, FilterOperator.CONTAINS},
            "number": {FilterOperator.EQ, FilterOperator.GTE, FilterOperator.LTE},
            "integer": {FilterOperator.EQ, FilterOperator.GTE, FilterOperator.LTE},
            "boolean": {FilterOperator.EQ},
            "array": {FilterOperator.EQ},
            "object": {FilterOperator.EQ},
        }

        # Validate filters
        for f in filters:
            if f.field not in col_map:
                raise ValidationError(
                    f"Unknown column '{f.field}' in feature table '{hook_name}'",
                    field=f.field,
                )
            json_type = col_map[f.field]
            # Unrecognised JSON types fall back to equality only.
            valid_ops = json_type_to_ops.get(json_type, {FilterOperator.EQ})
            if f.operator not in valid_ops:
                raise ValidationError(
                    f"Operator '{f.operator}' is not valid for column '{f.field}' "
                    f"(type '{json_type}'). Valid: {sorted(valid_ops)}",
                    field=f.field,
                )

        # Validate sort column
        if sort != "id" and sort not in col_map:
            raise ValidationError(
                f"Unknown sort column '{sort}' in feature table '{hook_name}'",
                field="sort",
            )

        # Decode cursor
        try:
            decoded_cursor = decode_cursor(cursor) if cursor else None
        except ValueError as exc:
            raise ValidationError(str(exc), field="cursor") from exc

        rows, total = await self.read_store.search_features(
            hook_name=hook_name,
            filters=filters,
            record_srn=record_srn,
            sort=sort,
            order=order,
            cursor=decoded_cursor,
            limit=limit,
        )

        has_more = len(rows) == limit and len(rows) < total
        next_cursor = None
        if has_more and rows:
            last = rows[-1]
            if sort == "id":
                sort_val = last.row_id
            elif sort == "record_srn":
                # record_srn is a dedicated attribute of the row rather than a
                # key of its data dict, so read it from there; looking it up
                # in data would put a null sort value into the cursor.
                sort_val = str(last.record_srn)
            else:
                sort_val = last.data.get(sort)
            next_cursor = encode_cursor(sort_val, last.row_id)

        return FeatureSearchResult(
            rows=rows,
            total=total,
            cursor=next_cursor,
            has_more=has_more,
        )
GetFeatureCatalogHandler +from osa.domain.discovery.query.search_features import SearchFeaturesHandler +from osa.domain.discovery.query.search_records import SearchRecordsHandler +from osa.domain.discovery.service.discovery import DiscoveryService +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + + +class DiscoveryProvider(Provider): + @provide(scope=Scope.UOW) + def get_discovery_service( + self, + read_store: DiscoveryReadStore, + field_reader: FieldDefinitionReader, + ) -> DiscoveryService: + return DiscoveryService(read_store=read_store, field_reader=field_reader) + + # Query Handlers + search_records_handler = provide(SearchRecordsHandler, scope=Scope.UOW) + get_feature_catalog_handler = provide(GetFeatureCatalogHandler, scope=Scope.UOW) + search_features_handler = provide(SearchFeaturesHandler, scope=Scope.UOW) diff --git a/server/osa/domain/record/port/feature_reader.py b/server/osa/domain/record/port/feature_reader.py new file mode 100644 index 0000000..4b0ee5d --- /dev/null +++ b/server/osa/domain/record/port/feature_reader.py @@ -0,0 +1,19 @@ +"""FeatureReader port — cross-domain read port for feature data enrichment.""" + +from __future__ import annotations + +from typing import Any, Protocol + +from osa.domain.shared.model.srn import RecordSRN + + +class FeatureReader(Protocol): + async def get_features_for_record( + self, record_srn: RecordSRN + ) -> dict[str, list[dict[str, Any]]]: + """Return {hook_name: [row_dicts]} for all feature tables. + + Returns {} when no feature tables exist or record has no feature data. + Excludes auto columns (id, created_at) from row dicts. + """ + ... 
diff --git a/server/osa/domain/record/query/get_record.py b/server/osa/domain/record/query/get_record.py index 0400e05..cb586c7 100644 --- a/server/osa/domain/record/query/get_record.py +++ b/server/osa/domain/record/query/get_record.py @@ -18,6 +18,7 @@ class RecordDetail(Result): deposition_srn: DepositionSRN metadata: dict[str, Any] published_at: datetime + features: dict[str, list[dict[str, Any]]] = {} class GetRecordHandler(QueryHandler[GetRecord, RecordDetail]): @@ -26,9 +27,11 @@ class GetRecordHandler(QueryHandler[GetRecord, RecordDetail]): async def run(self, cmd: GetRecord) -> RecordDetail: record = await self.record_service.get(cmd.srn) + features = await self.record_service.get_features_for_record(cmd.srn) return RecordDetail( srn=record.srn, deposition_srn=record.deposition_srn, metadata=record.metadata, published_at=record.published_at, + features=features, ) diff --git a/server/osa/domain/record/service/record.py b/server/osa/domain/record/service/record.py index ff4250b..2b51fde 100644 --- a/server/osa/domain/record/service/record.py +++ b/server/osa/domain/record/service/record.py @@ -1,8 +1,10 @@ """RecordService - orchestrates record creation from approved depositions.""" +from __future__ import annotations + import logging from datetime import UTC, datetime -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import uuid4 from osa.domain.record.event.record_published import RecordPublished @@ -22,6 +24,9 @@ from osa.domain.shared.outbox import Outbox from osa.domain.shared.service import Service +if TYPE_CHECKING: + from osa.domain.record.port.feature_reader import FeatureReader + logger = logging.getLogger(__name__) @@ -31,6 +36,13 @@ class RecordService(Service): record_repo: RecordRepository outbox: Outbox node_domain: Domain + feature_reader: FeatureReader + + async def get_features_for_record( + self, record_srn: RecordSRN + ) -> dict[str, list[dict[str, Any]]]: + """Fetch feature data for a record.""" + return await 
self.feature_reader.get_features_for_record(record_srn) async def get(self, srn: RecordSRN) -> Record: """Retrieve a published record by SRN.""" diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py new file mode 100644 index 0000000..52a6ffe --- /dev/null +++ b/server/osa/infrastructure/persistence/adapter/discovery.py @@ -0,0 +1,395 @@ +"""Infrastructure adapters for the discovery domain — read-only SQL queries.""" + +from __future__ import annotations + +import logging +from typing import Any + +from sqlalchemy import ( + Column, + Date, + Float, + Integer, + MetaData, + String, + Table, + and_, + cast, + func, + or_, + select, + text, +) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.elements import quoted_name + +from osa.domain.discovery.model.value import ( + ColumnInfo, + FeatureCatalogEntry, + FeatureRow, + Filter, + FilterOperator, + RecordSummary, + SortOrder, +) +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.srn import RecordSRN +from osa.infrastructure.persistence.tables import ( + feature_tables_table, + records_table, + schemas_table, +) + +logger = logging.getLogger(__name__) + + +class PostgresFieldDefinitionReader: + """Builds a global field_name -> FieldType map from all registered schemas.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def get_all_field_types(self) -> dict[str, FieldType]: + stmt = select(schemas_table.c.srn, schemas_table.c.fields) + result = await self.session.execute(stmt) + rows = result.mappings().all() + + field_map: dict[str, FieldType] = {} + for row in rows: + for field_def in row["fields"]: + name = field_def["name"] + field_type = FieldType(field_def["type"]) + if name in field_map and field_map[name] != field_type: + raise 
ValidationError( + f"Conflicting types for field '{name}': " + f"'{field_map[name]}' vs '{field_type}'", + field=name, + ) + field_map[name] = field_type + + return field_map + + +class PostgresDiscoveryReadStore: + """Direct SQL queries against records and feature tables for discovery.""" + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def search_records( + self, + filters: list[Filter], + text_fields: list[str], + q: str | None, + sort: str, + order: SortOrder, + cursor: dict[str, Any] | None, + limit: int, + field_types: dict[str, FieldType] | None = None, + ) -> tuple[list[RecordSummary], int]: + """Build and execute a dynamic SQL query for record search.""" + t = records_table + conditions: list[Any] = [] + ft = field_types or {} + + # Build filter conditions + for f in filters: + conditions.append(self._record_filter_clause(f, ft.get(f.field))) + + # Free-text search across text fields + if q and text_fields: + pattern = f"%{q}%" + text_clauses = [t.c.metadata[field].astext.ilike(pattern) for field in text_fields] + conditions.append(or_(*text_clauses)) + + # Determine sort expression + if sort == "published_at": + sort_expr = t.c.published_at + else: + sort_expr = t.c.metadata[sort].astext + + # Sort direction + if order == SortOrder.ASC: + order_clauses = [sort_expr.asc().nullslast(), t.c.srn.asc()] + else: + order_clauses = [sort_expr.desc().nullslast(), t.c.srn.desc()] + + # Keyset cursor + if cursor is not None: + cursor_sort = cursor["s"] + cursor_id = cursor["id"] + if order == SortOrder.ASC: + conditions.append( + or_( + sort_expr > cursor_sort, + and_(sort_expr == cursor_sort, t.c.srn > cursor_id), + ) + ) + else: + conditions.append( + or_( + sort_expr < cursor_sort, + and_(sort_expr == cursor_sort, t.c.srn < cursor_id), + ) + ) + + # Build query with COUNT(*) OVER() for total + where_clause = and_(*conditions) if conditions else text("TRUE") + total_col = func.count().over().label("_total") + + stmt = ( + 
select(t.c.srn, t.c.published_at, t.c.metadata, total_col) + .where(where_clause) + .order_by(*order_clauses) + .limit(limit) + ) + + result = await self.session.execute(stmt) + rows = result.mappings().all() + + total = rows[0]["_total"] if rows else 0 + results = [ + RecordSummary( + srn=RecordSRN.parse(row["srn"]), + published_at=row["published_at"], + metadata=row["metadata"], + ) + for row in rows + ] + + return results, total + + async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: + """List all feature tables with column schemas and record counts.""" + stmt = select( + feature_tables_table.c.hook_name, + feature_tables_table.c.pg_table, + feature_tables_table.c.feature_schema, + ) + result = await self.session.execute(stmt) + catalog_rows = result.mappings().all() + + entries: list[FeatureCatalogEntry] = [] + for row in catalog_rows: + schema_data = row["feature_schema"] + columns_raw = schema_data.get("columns", []) if isinstance(schema_data, dict) else [] + columns = [ + ColumnInfo( + name=col["name"], + type=col.get("json_type", "string"), + required=col.get("required", False), + ) + for col in columns_raw + ] + + pg_table = row["pg_table"] + safe_table = quoted_name(pg_table, quote=True) + count_stmt = select(func.count(func.distinct(text("record_srn")))).select_from( + text(f"features.{safe_table}") + ) + count_result = await self.session.execute(count_stmt) + record_count = count_result.scalar() or 0 + + entries.append( + FeatureCatalogEntry( + hook_name=row["hook_name"], + columns=columns, + record_count=record_count, + ) + ) + + return entries + + async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry | None: + """Look up a single feature table's schema by hook name.""" + stmt = select( + feature_tables_table.c.hook_name, + feature_tables_table.c.feature_schema, + ).where(feature_tables_table.c.hook_name == hook_name) + result = await self.session.execute(stmt) + row = result.mappings().first() + if row is None: 
+ return None + + schema_data = row["feature_schema"] + columns_raw = schema_data.get("columns", []) if isinstance(schema_data, dict) else [] + columns = [ + ColumnInfo( + name=col["name"], + type=col.get("json_type", "string"), + required=col.get("required", False), + ) + for col in columns_raw + ] + + return FeatureCatalogEntry( + hook_name=row["hook_name"], + columns=columns, + record_count=0, + ) + + async def search_features( + self, + hook_name: str, + filters: list[Filter], + record_srn: RecordSRN | None, + sort: str, + order: SortOrder, + cursor: dict[str, Any] | None, + limit: int, + ) -> tuple[list[FeatureRow], int]: + """Build and execute a dynamic SQL query for feature row search.""" + # Look up pg_table and feature_schema from catalog + pg_table_stmt = select( + feature_tables_table.c.pg_table, + feature_tables_table.c.feature_schema, + ).where(feature_tables_table.c.hook_name == hook_name) + pg_result = await self.session.execute(pg_table_stmt) + pg_row = pg_result.mappings().first() + if pg_row is None: + return [], 0 + pg_table: str = pg_row["pg_table"] + feature_schema: dict = pg_row["feature_schema"] + + # Build Table with full column list from schema using local MetaData + from osa.domain.shared.model.hook import ColumnDef + from osa.infrastructure.persistence.column_mapper import map_column + + schema_columns = ( + feature_schema.get("columns", []) if isinstance(feature_schema, dict) else [] + ) + data_columns = [ + map_column( + ColumnDef( + name=col["name"], + json_type=col.get("json_type", "string"), + format=col.get("format"), + required=col.get("required", False), + ) + ) + for col in schema_columns + ] + + local_meta = MetaData() + ft = Table( + pg_table, + local_meta, + Column("id", Integer, primary_key=True), + Column("record_srn", String), + Column("created_at", String), + *data_columns, + schema="features", + ) + + conditions: list[Any] = [] + + # Record SRN filter + if record_srn is not None: + conditions.append(ft.c.record_srn == 
str(record_srn)) + + # Column filters — all columns are known from schema + for f in filters: + col = ft.c[f.field] + if f.operator == FilterOperator.EQ: + conditions.append(col == f.value) + elif f.operator == FilterOperator.CONTAINS: + conditions.append(cast(col, String).ilike(f"%{f.value}%")) + elif f.operator == FilterOperator.GTE: + conditions.append(col >= f.value) + elif f.operator == FilterOperator.LTE: + conditions.append(col <= f.value) + + # Sort expression + if sort == "id": + sort_expr = ft.c.id + else: + sort_expr = ft.c[sort] + + if order == SortOrder.ASC: + order_clauses = [sort_expr.asc(), ft.c.id.asc()] + else: + order_clauses = [sort_expr.desc(), ft.c.id.desc()] + + # Keyset cursor + if cursor is not None: + cursor_sort = cursor["s"] + cursor_id = cursor["id"] + if order == SortOrder.ASC: + conditions.append( + or_( + sort_expr > cursor_sort, + and_(sort_expr == cursor_sort, ft.c.id > cursor_id), + ) + ) + else: + conditions.append( + or_( + sort_expr < cursor_sort, + and_(sort_expr == cursor_sort, ft.c.id < cursor_id), + ) + ) + + where_clause = and_(*conditions) if conditions else text("TRUE") + total_col = func.count().over().label("_total") + + # Select all columns except auto ones, plus total + auto_cols = {"id", "created_at"} + stmt = ( + select( + ft.c.id, + ft.c.record_srn, + total_col, + *[ + c + for c in ft.columns + if c.key not in auto_cols and c.key not in ("id", "record_srn") + ], + ) + .where(where_clause) + .order_by(*order_clauses) + .limit(limit) + ) + + result = await self.session.execute(stmt) + rows = result.mappings().all() + + total = rows[0]["_total"] if rows else 0 + feature_rows: list[FeatureRow] = [] + for row in rows: + row_dict = dict(row) + row_dict.pop("_total", None) + row_id = row_dict.pop("id") + rsrn = RecordSRN.parse(row_dict.pop("record_srn")) + row_dict.pop("created_at", None) + feature_rows.append(FeatureRow(row_id=row_id, record_srn=rsrn, data=row_dict)) + + return feature_rows, total + + @staticmethod + 
def _record_filter_clause(f: Filter, field_type: FieldType | None = None) -> Any: + """Build a SQL clause for a single record metadata filter.""" + t = records_table + if f.operator == FilterOperator.EQ: + # Use JSONB @> containment (GIN-indexed) + return t.c.metadata.op("@>")(cast(func.json_build_object(f.field, f.value), JSONB)) + elif f.operator == FilterOperator.CONTAINS: + return t.c.metadata[f.field].astext.ilike(f"%{f.value}%") + elif f.operator in (FilterOperator.GTE, FilterOperator.LTE): + # Use typed casts: numeric for NUMBER, date for DATE, string fallback + if field_type == FieldType.NUMBER: + col_expr = cast(t.c.metadata[f.field].astext, Float) + val = float(f.value) + elif field_type == FieldType.DATE: + col_expr = cast(t.c.metadata[f.field].astext, Date) + val = str(f.value) + else: + col_expr = cast(t.c.metadata[f.field].astext, String) + val = str(f.value) + if f.operator == FilterOperator.GTE: + return col_expr >= val + else: + return col_expr <= val + else: + raise ValueError(f"Unknown operator: {f.operator}") # pragma: no cover diff --git a/server/osa/infrastructure/persistence/adapter/feature_reader.py b/server/osa/infrastructure/persistence/adapter/feature_reader.py new file mode 100644 index 0000000..a118ca8 --- /dev/null +++ b/server/osa/infrastructure/persistence/adapter/feature_reader.py @@ -0,0 +1,61 @@ +"""PostgresFeatureReader — reads feature data for record enrichment.""" + +from __future__ import annotations + +from typing import Any + +from sqlalchemy import select, text +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.elements import quoted_name + +from osa.domain.shared.model.srn import RecordSRN +from osa.infrastructure.persistence.tables import feature_tables_table + + +class PostgresFeatureReader: + """Queries feature_tables catalog and dynamic feature tables for a record.""" + + AUTO_COLUMNS = {"id", "created_at"} + + def __init__(self, session: AsyncSession) -> None: + self.session = session + + async def 
get_features_for_record( + self, record_srn: RecordSRN + ) -> dict[str, list[dict[str, Any]]]: + # Get all feature tables from catalog + stmt = select( + feature_tables_table.c.hook_name, + feature_tables_table.c.pg_table, + ) + result = await self.session.execute(stmt) + catalog_rows = result.mappings().all() + + if not catalog_rows: + return {} + + features: dict[str, list[dict[str, Any]]] = {} + for row in catalog_rows: + hook_name = row["hook_name"] + pg_table = row["pg_table"] + + # Query the dynamic feature table for this record + safe_table = quoted_name(pg_table, quote=True) + query = text( + f"SELECT * FROM features.{safe_table} WHERE record_srn = :srn" # noqa: S608 + ) + feat_result = await self.session.execute(query, {"srn": str(record_srn)}) + feat_rows = feat_result.mappings().all() + + if feat_rows: + rows_list: list[dict[str, Any]] = [] + for feat_row in feat_rows: + row_dict = { + k: v + for k, v in dict(feat_row).items() + if k not in self.AUTO_COLUMNS and k != "record_srn" + } + rows_list.append(row_dict) + features[hook_name] = rows_list + + return features diff --git a/server/osa/infrastructure/persistence/di.py b/server/osa/infrastructure/persistence/di.py index 8815144..bf180c6 100644 --- a/server/osa/infrastructure/persistence/di.py +++ b/server/osa/infrastructure/persistence/di.py @@ -10,9 +10,11 @@ from osa.domain.deposition.port.repository import DepositionRepository from osa.domain.deposition.port.schema_reader import SchemaReader from osa.domain.deposition.port.storage import FileStoragePort +from osa.domain.record.port.feature_reader import FeatureReader from osa.domain.record.port.repository import RecordRepository from osa.domain.record.query.get_record import GetRecordHandler from osa.domain.record.service import RecordService +from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader from osa.domain.source.port.storage import SourceStoragePort from osa.domain.feature.port.storage import 
FeatureStoragePort from osa.domain.validation.port.storage import HookStoragePort @@ -23,6 +25,12 @@ from osa.domain.shared.port.event_repository import EventRepository from osa.domain.feature.port.feature_store import FeatureStore from osa.domain.validation.port.repository import ValidationRunRepository +from osa.domain.discovery.port.field_definition_reader import FieldDefinitionReader +from osa.domain.discovery.port.read_store import DiscoveryReadStore +from osa.infrastructure.persistence.adapter.discovery import ( + PostgresDiscoveryReadStore, + PostgresFieldDefinitionReader, +) from osa.infrastructure.persistence.adapter.readers import ( OntologyReaderAdapter, SchemaReaderAdapter, @@ -126,12 +134,16 @@ def get_hook_storage(self, file_storage: FileStoragePort) -> HookStoragePort: def get_feature_storage(self, file_storage: FileStoragePort) -> FeatureStoragePort: return file_storage # type: ignore[return-value] + # Feature reader + feature_reader = provide(PostgresFeatureReader, scope=Scope.UOW, provides=FeatureReader) + @provide(scope=Scope.UOW) def get_record_service( self, record_repo: RecordRepository, outbox: Outbox, config: Config, + feature_reader: FeatureReader, ) -> RecordService: """Provide RecordService for UOW scope. 
@@ -141,7 +153,16 @@ def get_record_service( record_repo=record_repo, outbox=outbox, node_domain=Domain(config.server.domain), + feature_reader=feature_reader, ) + # Discovery adapters + discovery_read_store = provide( + PostgresDiscoveryReadStore, scope=Scope.UOW, provides=DiscoveryReadStore + ) + field_definition_reader = provide( + PostgresFieldDefinitionReader, scope=Scope.UOW, provides=FieldDefinitionReader + ) + # Record query handlers get_record_handler = provide(GetRecordHandler, scope=Scope.UOW) diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index bd4af14..d667071 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -14,6 +14,7 @@ UniqueConstraint, text, ) +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.types import JSON # Metadata object for all tables @@ -65,13 +66,19 @@ metadata, Column("srn", String, primary_key=True), Column("deposition_srn", String, nullable=False), - Column("metadata", JSON, nullable=False), + Column("metadata", JSONB, nullable=False), Column("indexes", JSON, nullable=False), Column("published_at", DateTime(timezone=True), nullable=False), ) Index("idx_records_deposition_srn", records_table.c.deposition_srn) Index("idx_records_published_at", records_table.c.published_at) +Index( + "idx_records_metadata_gin", + records_table.c.metadata, + postgresql_using="gin", + postgresql_ops={"metadata": "jsonb_path_ops"}, +) # ============================================================================ diff --git a/server/tests/unit/domain/discovery/__init__.py b/server/tests/unit/domain/discovery/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/server/tests/unit/domain/discovery/test_discovery_service.py b/server/tests/unit/domain/discovery/test_discovery_service.py new file mode 100644 index 0000000..0be662a --- /dev/null +++ 
b/server/tests/unit/domain/discovery/test_discovery_service.py @@ -0,0 +1,328 @@ +"""Tests for DiscoveryService — filter validation, operator validation, delegation.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.discovery.model.value import ( + ColumnInfo, + FeatureCatalogEntry, + FeatureRow, + Filter, + FilterOperator, + RecordSummary, + SortOrder, +) +from osa.domain.discovery.service.discovery import DiscoveryService +from osa.domain.semantics.model.value import FieldType +from osa.domain.shared.error import ValidationError +from osa.domain.shared.model.srn import RecordSRN + + +@pytest.fixture +def mock_read_store() -> AsyncMock: + store = AsyncMock() + store.search_records.return_value = ([], 0) + return store + + +@pytest.fixture +def mock_field_reader() -> AsyncMock: + reader = AsyncMock() + reader.get_all_field_types.return_value = { + "title": FieldType.TEXT, + "resolution": FieldType.NUMBER, + "method": FieldType.TERM, + "published_date": FieldType.DATE, + "is_public": FieldType.BOOLEAN, + "homepage": FieldType.URL, + } + return reader + + +@pytest.fixture +def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: + return DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) + + +class TestSearchRecordsValidation: + async def test_rejects_unknown_filter_field(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="Unknown field 'bogus'"): + await service.search_records( + filters=[Filter(field="bogus", operator=FilterOperator.EQ, value="x")], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + async def test_rejects_invalid_operator_for_type(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="contains"): + await service.search_records( + filters=[Filter(field="resolution", operator=FilterOperator.CONTAINS, value="x")], + q=None, 
+ sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + async def test_rejects_unknown_sort_field(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="Unknown sort field"): + await service.search_records( + filters=[], + q=None, + sort="nonexistent", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + async def test_accepts_published_at_sort(self, service: DiscoveryService) -> None: + result = await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + assert result.total == 0 + + async def test_accepts_metadata_field_sort(self, service: DiscoveryService) -> None: + result = await service.search_records( + filters=[], + q=None, + sort="resolution", + order=SortOrder.ASC, + cursor=None, + limit=20, + ) + assert result.total == 0 + + async def test_rejects_limit_too_low(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="limit"): + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=0, + ) + + async def test_rejects_limit_too_high(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="limit"): + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=101, + ) + + +class TestSearchRecordsDelegation: + async def test_delegates_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + await service.search_records( + filters=[Filter(field="method", operator=FilterOperator.EQ, value="X-ray")], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + mock_read_store.search_records.assert_called_once() + call_kwargs = mock_read_store.search_records.call_args + assert len(call_kwargs.kwargs["filters"]) == 1 + assert call_kwargs.kwargs["q"] is None + assert 
call_kwargs.kwargs["sort"] == "published_at" + assert call_kwargs.kwargs["limit"] == 20 + + async def test_extracts_text_fields_for_q( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + await service.search_records( + filters=[], + q="kinase", + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + call_kwargs = mock_read_store.search_records.call_args + text_fields = call_kwargs.kwargs["text_fields"] + # title (TEXT) and homepage (URL) are text-searchable + assert "title" in text_fields + assert "homepage" in text_fields + assert "resolution" not in text_fields + + async def test_decodes_cursor( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + from osa.domain.discovery.model.value import encode_cursor + + cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=cursor, + limit=20, + ) + + call_kwargs = mock_read_store.search_records.call_args + decoded = call_kwargs.kwargs["cursor"] + assert decoded["s"] == "2026-01-01" + assert decoded["id"] == "urn:osa:localhost:rec:abc@1" + + async def test_invalid_cursor_raises(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="cursor"): + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor="not-a-cursor!!!", + limit=20, + ) + + async def test_encodes_next_cursor_from_results( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + ts = datetime(2026, 1, 1, tzinfo=UTC) + mock_read_store.search_records.return_value = ( + [RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], + 5, + ) + + result = await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=1, + ) + + assert 
result.has_more is True + assert result.cursor is not None + + from osa.domain.discovery.model.value import decode_cursor + + decoded = decode_cursor(result.cursor) + assert decoded["id"] == str(srn) + + async def test_no_cursor_when_no_more_results( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + mock_read_store.search_records.return_value = ([], 0) + + result = await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + assert result.cursor is None + assert result.has_more is False + + +class TestSearchRecordsFieldTypes: + async def test_passes_field_types_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + call_kwargs = mock_read_store.search_records.call_args + field_types = call_kwargs.kwargs["field_types"] + assert field_types["resolution"] == FieldType.NUMBER + assert field_types["title"] == FieldType.TEXT + + +class TestFeatureCursorEncoding: + async def test_cursor_encodes_row_id( + self, mock_read_store: AsyncMock, mock_field_reader: AsyncMock + ) -> None: + from osa.domain.discovery.model.value import decode_cursor + + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( + hook_name="detect_pockets", + columns=[ColumnInfo(name="score", type="number", required=True)], + record_count=0, + ) + mock_read_store.search_features.return_value = ( + [FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66})], + 5, + ) + + service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) + result = await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="score", + order=SortOrder.DESC, + cursor=None, + limit=1, + ) + + assert result.has_more is 
True + assert result.cursor is not None + decoded = decode_cursor(result.cursor) + assert decoded["id"] == 42 + assert decoded["s"] == 7.66 + + async def test_cursor_uses_row_id_for_id_sort( + self, mock_read_store: AsyncMock, mock_field_reader: AsyncMock + ) -> None: + from osa.domain.discovery.model.value import decode_cursor + + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + mock_read_store.get_feature_table_schema.return_value = FeatureCatalogEntry( + hook_name="detect_pockets", + columns=[ColumnInfo(name="score", type="number", required=True)], + record_count=0, + ) + mock_read_store.search_features.return_value = ( + [FeatureRow(row_id=99, record_srn=srn, data={"score": 5.0})], + 3, + ) + + service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) + result = await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=1, + ) + + assert result.has_more is True + assert result.cursor is not None + decoded = decode_cursor(result.cursor) + # When sort is "id", sort_val should be the row_id itself + assert decoded["s"] == 99 + assert decoded["id"] == 99 diff --git a/server/tests/unit/domain/discovery/test_get_feature_catalog.py b/server/tests/unit/domain/discovery/test_get_feature_catalog.py new file mode 100644 index 0000000..6cfb1fe --- /dev/null +++ b/server/tests/unit/domain/discovery/test_get_feature_catalog.py @@ -0,0 +1,100 @@ +"""Tests for GetFeatureCatalogHandler and DiscoveryService.get_feature_catalog().""" + +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.discovery.model.value import ColumnInfo, FeatureCatalogEntry +from osa.domain.discovery.query.get_feature_catalog import ( + GetFeatureCatalog, + GetFeatureCatalogHandler, + GetFeatureCatalogResult, +) +from osa.domain.discovery.service.discovery import DiscoveryService + + +@pytest.fixture +def mock_read_store() -> AsyncMock: + return AsyncMock() + 
+ +@pytest.fixture +def mock_field_reader() -> AsyncMock: + reader = AsyncMock() + reader.get_all_field_types.return_value = {} + return reader + + +@pytest.fixture +def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: + return DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) + + +class TestGetFeatureCatalogHandler: + async def test_public_auth_gate(self) -> None: + from osa.domain.shared.authorization.gate import Public + + assert isinstance(GetFeatureCatalogHandler.__auth__, Public) + + async def test_delegates_to_service(self) -> None: + mock_service = AsyncMock() + from osa.domain.discovery.model.value import FeatureCatalog + + mock_service.get_feature_catalog.return_value = FeatureCatalog(tables=[]) + + handler = GetFeatureCatalogHandler(discovery_service=mock_service) + result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) + + assert result.tables == [] + mock_service.get_feature_catalog.assert_called_once() + + async def test_returns_correct_structure(self) -> None: + mock_service = AsyncMock() + from osa.domain.discovery.model.value import FeatureCatalog + + mock_service.get_feature_catalog.return_value = FeatureCatalog( + tables=[ + FeatureCatalogEntry( + hook_name="detect_pockets", + columns=[ + ColumnInfo(name="score", type="number", required=True), + ColumnInfo(name="volume", type="number", required=True), + ], + record_count=142, + ) + ] + ) + + handler = GetFeatureCatalogHandler(discovery_service=mock_service) + result: GetFeatureCatalogResult = await handler.run(GetFeatureCatalog()) + + assert len(result.tables) == 1 + assert result.tables[0]["hook_name"] == "detect_pockets" + assert result.tables[0]["record_count"] == 142 + assert len(result.tables[0]["columns"]) == 2 + + +class TestDiscoveryServiceGetFeatureCatalog: + async def test_delegates_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + 
mock_read_store.get_feature_catalog.return_value = [] + result = await service.get_feature_catalog() + assert result.tables == [] + mock_read_store.get_feature_catalog.assert_called_once() + + async def test_returns_entries_from_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + entries = [ + FeatureCatalogEntry( + hook_name="test_hook", + columns=[ColumnInfo(name="x", type="number", required=True)], + record_count=10, + ) + ] + mock_read_store.get_feature_catalog.return_value = entries + + result = await service.get_feature_catalog() + assert len(result.tables) == 1 + assert result.tables[0].hook_name == "test_hook" diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py new file mode 100644 index 0000000..fc7a537 --- /dev/null +++ b/server/tests/unit/domain/discovery/test_search_features.py @@ -0,0 +1,212 @@ +"""Tests for SearchFeaturesHandler and DiscoveryService.search_features().""" + +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.discovery.model.value import ( + ColumnInfo, + FeatureCatalogEntry, + FeatureRow, + FeatureSearchResult, + Filter, + FilterOperator, + SortOrder, +) +from osa.domain.discovery.query.search_features import ( + SearchFeatures, + SearchFeaturesHandler, + SearchFeaturesResult, +) +from osa.domain.discovery.service.discovery import DiscoveryService +from osa.domain.shared.error import NotFoundError, ValidationError +from osa.domain.shared.model.srn import RecordSRN + + +def _make_catalog_entry() -> FeatureCatalogEntry: + return FeatureCatalogEntry( + hook_name="detect_pockets", + columns=[ + ColumnInfo(name="score", type="number", required=True), + ColumnInfo(name="volume", type="number", required=True), + ColumnInfo(name="label", type="string", required=False), + ColumnInfo(name="is_active", type="boolean", required=False), + ], + record_count=10, + ) + + +@pytest.fixture +def mock_read_store() -> 
AsyncMock: + store = AsyncMock() + store.get_feature_table_schema.return_value = _make_catalog_entry() + store.search_features.return_value = ([], 0) + return store + + +@pytest.fixture +def mock_field_reader() -> AsyncMock: + reader = AsyncMock() + reader.get_all_field_types.return_value = {} + return reader + + +@pytest.fixture +def service(mock_read_store: AsyncMock, mock_field_reader: AsyncMock) -> DiscoveryService: + return DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) + + +class TestSearchFeaturesHandler: + async def test_public_auth_gate(self) -> None: + from osa.domain.shared.authorization.gate import Public + + assert isinstance(SearchFeaturesHandler.__auth__, Public) + + async def test_delegates_to_service(self) -> None: + mock_service = AsyncMock() + mock_service.search_features.return_value = FeatureSearchResult( + rows=[], total=0, cursor=None, has_more=False + ) + + handler = SearchFeaturesHandler(discovery_service=mock_service) + result: SearchFeaturesResult = await handler.run(SearchFeatures(hook_name="detect_pockets")) + + assert result.total == 0 + mock_service.search_features.assert_called_once() + + async def test_maps_rows_with_record_srn(self) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + mock_service = AsyncMock() + mock_service.search_features.return_value = FeatureSearchResult( + rows=[FeatureRow(row_id=1, record_srn=srn, data={"score": 7.66})], + total=1, + cursor=None, + has_more=False, + ) + + handler = SearchFeaturesHandler(discovery_service=mock_service) + result = await handler.run(SearchFeatures(hook_name="detect_pockets")) + + assert result.rows[0]["record_srn"] == str(srn) + assert result.rows[0]["score"] == 7.66 + + +class TestDiscoveryServiceSearchFeatures: + async def test_raises_not_found_for_unknown_hook( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + mock_read_store.get_feature_table_schema.return_value = None + + with 
pytest.raises(NotFoundError, match="unknown_hook"): + await service.search_features( + hook_name="unknown_hook", + filters=[], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + async def test_rejects_unknown_column(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="bogus"): + await service.search_features( + hook_name="detect_pockets", + filters=[Filter(field="bogus", operator=FilterOperator.EQ, value=1)], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + async def test_validates_operator_for_number_column(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="contains"): + await service.search_features( + hook_name="detect_pockets", + filters=[Filter(field="score", operator=FilterOperator.CONTAINS, value="x")], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + async def test_validates_operator_for_boolean_column(self, service: DiscoveryService) -> None: + with pytest.raises(ValidationError, match="gte"): + await service.search_features( + hook_name="detect_pockets", + filters=[Filter(field="is_active", operator=FilterOperator.GTE, value=True)], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + async def test_accepts_string_contains_operator(self, service: DiscoveryService) -> None: + await service.search_features( + hook_name="detect_pockets", + filters=[Filter(field="label", operator=FilterOperator.CONTAINS, value="test")], + record_srn=None, + sort="id", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + async def test_passes_record_srn_filter( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=srn, + sort="id", + order=SortOrder.DESC, + cursor=None, + 
limit=50, + ) + + call_kwargs = mock_read_store.search_features.call_args + assert call_kwargs.kwargs["record_srn"] == srn + + async def test_decodes_cursor( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + from osa.domain.discovery.model.value import encode_cursor + + cursor = encode_cursor(7.66, 42) + await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="score", + order=SortOrder.DESC, + cursor=cursor, + limit=50, + ) + + call_kwargs = mock_read_store.search_features.call_args + decoded = call_kwargs.kwargs["cursor"] + assert decoded["s"] == 7.66 + assert decoded["id"] == 42 + + async def test_delegates_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + await service.search_features( + hook_name="detect_pockets", + filters=[Filter(field="score", operator=FilterOperator.GTE, value=6.0)], + record_srn=None, + sort="score", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + mock_read_store.search_features.assert_called_once() + call_kwargs = mock_read_store.search_features.call_args + assert call_kwargs.kwargs["hook_name"] == "detect_pockets" + assert len(call_kwargs.kwargs["filters"]) == 1 diff --git a/server/tests/unit/domain/discovery/test_search_records.py b/server/tests/unit/domain/discovery/test_search_records.py new file mode 100644 index 0000000..fb48a98 --- /dev/null +++ b/server/tests/unit/domain/discovery/test_search_records.py @@ -0,0 +1,69 @@ +"""Tests for SearchRecordsHandler — auth gate, delegation, result mapping.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest + +from osa.domain.discovery.model.value import RecordSearchResult, RecordSummary, SortOrder +from osa.domain.discovery.query.search_records import ( + SearchRecords, + SearchRecordsHandler, + SearchRecordsResult, +) +from osa.domain.shared.model.srn import RecordSRN + + +@pytest.fixture +def mock_service() -> AsyncMock: + 
return AsyncMock() + + +@pytest.fixture +def handler(mock_service: AsyncMock) -> SearchRecordsHandler: + return SearchRecordsHandler(discovery_service=mock_service) + + +class TestSearchRecordsHandler: + async def test_public_auth_gate(self, handler: SearchRecordsHandler) -> None: + from osa.domain.shared.authorization.gate import Public + + assert isinstance(handler.__auth__, Public) + + async def test_delegates_to_service( + self, handler: SearchRecordsHandler, mock_service: AsyncMock + ) -> None: + mock_service.search_records.return_value = RecordSearchResult( + results=[], total=0, cursor=None, has_more=False + ) + cmd = SearchRecords() + await handler.run(cmd) + + mock_service.search_records.assert_called_once_with( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + async def test_maps_results( + self, handler: SearchRecordsHandler, mock_service: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + ts = datetime(2026, 1, 1, tzinfo=UTC) + mock_service.search_records.return_value = RecordSearchResult( + results=[RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], + total=1, + cursor="abc123", + has_more=False, + ) + + result: SearchRecordsResult = await handler.run(SearchRecords()) + + assert result.total == 1 + assert result.cursor == "abc123" + assert result.has_more is False + assert result.results[0]["srn"] == str(srn) + assert result.results[0]["metadata"] == {"title": "Test"} diff --git a/server/tests/unit/domain/discovery/test_value.py b/server/tests/unit/domain/discovery/test_value.py new file mode 100644 index 0000000..0accb29 --- /dev/null +++ b/server/tests/unit/domain/discovery/test_value.py @@ -0,0 +1,100 @@ +"""Tests for discovery domain value objects — cursor helpers and VALID_OPERATORS.""" + +import pytest + +from osa.domain.discovery.model.value import ( + VALID_OPERATORS, + FilterOperator, + decode_cursor, + encode_cursor, +) +from 
osa.domain.semantics.model.value import FieldType + + +class TestCursorRoundTrip: + def test_round_trip_string_values(self) -> None: + cursor = encode_cursor("2026-01-01", "urn:osa:localhost:rec:abc@1") + decoded = decode_cursor(cursor) + assert decoded["s"] == "2026-01-01" + assert decoded["id"] == "urn:osa:localhost:rec:abc@1" + + def test_round_trip_numeric_values(self) -> None: + cursor = encode_cursor(7.66, 123) + decoded = decode_cursor(cursor) + assert decoded["s"] == 7.66 + assert decoded["id"] == 123 + + def test_round_trip_none_sort_value(self) -> None: + cursor = encode_cursor(None, "urn:osa:localhost:rec:abc@1") + decoded = decode_cursor(cursor) + assert decoded["s"] is None + assert decoded["id"] == "urn:osa:localhost:rec:abc@1" + + +class TestDecodeCursorErrors: + def test_malformed_base64(self) -> None: + with pytest.raises(ValueError, match="Malformed cursor"): + decode_cursor("not-valid-base64!!!") + + def test_invalid_json(self) -> None: + import base64 + + bad = base64.urlsafe_b64encode(b"not json").decode() + with pytest.raises(ValueError, match="Malformed cursor"): + decode_cursor(bad) + + def test_missing_s_key(self) -> None: + import base64 + import json + + bad = base64.urlsafe_b64encode(json.dumps({"id": "x"}).encode()).decode() + with pytest.raises(ValueError, match="'s' and 'id'"): + decode_cursor(bad) + + def test_missing_id_key(self) -> None: + import base64 + import json + + bad = base64.urlsafe_b64encode(json.dumps({"s": 1}).encode()).decode() + with pytest.raises(ValueError, match="'s' and 'id'"): + decode_cursor(bad) + + def test_non_dict_payload(self) -> None: + import base64 + import json + + bad = base64.urlsafe_b64encode(json.dumps([1, 2]).encode()).decode() + with pytest.raises(ValueError, match="'s' and 'id'"): + decode_cursor(bad) + + +class TestValidOperators: + def test_text_operators(self) -> None: + assert VALID_OPERATORS[FieldType.TEXT] == {FilterOperator.EQ, FilterOperator.CONTAINS} + + def test_url_operators(self) -> 
None: + assert VALID_OPERATORS[FieldType.URL] == {FilterOperator.EQ, FilterOperator.CONTAINS} + + def test_number_operators(self) -> None: + assert VALID_OPERATORS[FieldType.NUMBER] == { + FilterOperator.EQ, + FilterOperator.GTE, + FilterOperator.LTE, + } + + def test_date_operators(self) -> None: + assert VALID_OPERATORS[FieldType.DATE] == { + FilterOperator.EQ, + FilterOperator.GTE, + FilterOperator.LTE, + } + + def test_boolean_operators(self) -> None: + assert VALID_OPERATORS[FieldType.BOOLEAN] == {FilterOperator.EQ} + + def test_term_operators(self) -> None: + assert VALID_OPERATORS[FieldType.TERM] == {FilterOperator.EQ} + + def test_all_field_types_have_operators(self) -> None: + for ft in FieldType: + assert ft in VALID_OPERATORS, f"Missing operators for {ft}" diff --git a/server/tests/unit/domain/record/test_get_record_handler.py b/server/tests/unit/domain/record/test_get_record_handler.py index 0d903b9..a1e175b 100644 --- a/server/tests/unit/domain/record/test_get_record_handler.py +++ b/server/tests/unit/domain/record/test_get_record_handler.py @@ -35,6 +35,7 @@ async def test_returns_record_detail(self): record = _make_record() service = AsyncMock() service.get.return_value = record + service.get_features_for_record.return_value = {} handler = GetRecordHandler(record_service=service) result = await handler.run(GetRecord(srn=record.srn)) diff --git a/server/tests/unit/domain/record/test_record_features.py b/server/tests/unit/domain/record/test_record_features.py new file mode 100644 index 0000000..6df74ca --- /dev/null +++ b/server/tests/unit/domain/record/test_record_features.py @@ -0,0 +1,216 @@ +"""Tests for feature enrichment — PostgresFeatureReader and RecordService integration.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.record.model.aggregate import Record +from osa.domain.record.query.get_record import GetRecord, GetRecordHandler, RecordDetail +from 
osa.domain.record.service.record import RecordService +from osa.domain.shared.model.srn import DepositionSRN, Domain, RecordSRN +from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader + + +def _make_catalog_row(hook_name: str, pg_table: str) -> dict: + return {"hook_name": hook_name, "pg_table": pg_table} + + +class TestPostgresFeatureReader: + @pytest.fixture + def mock_session(self) -> AsyncMock: + return AsyncMock() + + @pytest.fixture + def reader(self, mock_session: AsyncMock) -> PostgresFeatureReader: + return PostgresFeatureReader(session=mock_session) + + async def test_returns_dict_keyed_by_hook_name( + self, reader: PostgresFeatureReader, mock_session: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + + # First call: catalog query + catalog_result = MagicMock() + catalog_result.mappings.return_value.all.return_value = [ + _make_catalog_row("detect_pockets", "detect_pockets_v1") + ] + + # Second call: feature table query + feature_result = MagicMock() + feature_result.mappings.return_value.all.return_value = [ + { + "id": 1, + "record_srn": str(srn), + "created_at": datetime.now(UTC), + "score": 7.66, + "volume": 1750.0, + } + ] + + mock_session.execute.side_effect = [catalog_result, feature_result] + + result = await reader.get_features_for_record(srn) + + assert "detect_pockets" in result + assert len(result["detect_pockets"]) == 1 + assert result["detect_pockets"][0]["score"] == 7.66 + assert result["detect_pockets"][0]["volume"] == 1750.0 + + async def test_excludes_auto_columns( + self, reader: PostgresFeatureReader, mock_session: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + + catalog_result = MagicMock() + catalog_result.mappings.return_value.all.return_value = [ + _make_catalog_row("test_hook", "test_hook_v1") + ] + + feature_result = MagicMock() + feature_result.mappings.return_value.all.return_value = [ + { + "id": 42, + "record_srn": str(srn), + 
"created_at": datetime.now(UTC), + "metric": 3.14, + } + ] + + mock_session.execute.side_effect = [catalog_result, feature_result] + + result = await reader.get_features_for_record(srn) + + row = result["test_hook"][0] + assert "id" not in row + assert "created_at" not in row + assert "record_srn" not in row + assert row["metric"] == 3.14 + + async def test_returns_empty_when_no_feature_tables( + self, reader: PostgresFeatureReader, mock_session: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + + catalog_result = MagicMock() + catalog_result.mappings.return_value.all.return_value = [] + mock_session.execute.return_value = catalog_result + + result = await reader.get_features_for_record(srn) + assert result == {} + + async def test_returns_empty_when_record_has_no_data( + self, reader: PostgresFeatureReader, mock_session: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + + catalog_result = MagicMock() + catalog_result.mappings.return_value.all.return_value = [ + _make_catalog_row("detect_pockets", "detect_pockets_v1") + ] + + feature_result = MagicMock() + feature_result.mappings.return_value.all.return_value = [] + + mock_session.execute.side_effect = [catalog_result, feature_result] + + result = await reader.get_features_for_record(srn) + assert result == {} + + async def test_includes_data_from_multiple_tables( + self, reader: PostgresFeatureReader, mock_session: AsyncMock + ) -> None: + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + + catalog_result = MagicMock() + catalog_result.mappings.return_value.all.return_value = [ + _make_catalog_row("hook_a", "hook_a_v1"), + _make_catalog_row("hook_b", "hook_b_v1"), + ] + + feat_a = MagicMock() + feat_a.mappings.return_value.all.return_value = [ + {"id": 1, "record_srn": str(srn), "created_at": datetime.now(UTC), "x": 1} + ] + + feat_b = MagicMock() + feat_b.mappings.return_value.all.return_value = [ + {"id": 2, "record_srn": str(srn), "created_at": 
datetime.now(UTC), "y": 2} + ] + + mock_session.execute.side_effect = [catalog_result, feat_a, feat_b] + + result = await reader.get_features_for_record(srn) + + assert "hook_a" in result + assert "hook_b" in result + assert result["hook_a"][0]["x"] == 1 + assert result["hook_b"][0]["y"] == 2 + + +def _make_record() -> Record: + return Record( + srn=RecordSRN.parse("urn:osa:localhost:rec:abc@1"), + deposition_srn=DepositionSRN.parse("urn:osa:localhost:dep:dep1"), + metadata={"title": "Test"}, + published_at=datetime.now(UTC), + ) + + +class TestRecordServiceFeatureEnrichment: + async def test_get_features_delegates_to_reader(self) -> None: + mock_repo = AsyncMock() + mock_outbox = AsyncMock() + mock_reader = AsyncMock() + mock_reader.get_features_for_record.return_value = {"hook_a": [{"score": 1.0}]} + + service = RecordService( + record_repo=mock_repo, + outbox=mock_outbox, + node_domain=Domain("localhost"), + feature_reader=mock_reader, + ) + + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + result = await service.get_features_for_record(srn) + + assert result == {"hook_a": [{"score": 1.0}]} + mock_reader.get_features_for_record.assert_called_once_with(srn) + + +class TestGetRecordHandlerFeatureEnrichment: + async def test_record_detail_includes_features(self) -> None: + record = _make_record() + mock_service = AsyncMock() + mock_service.get.return_value = record + mock_service.get_features_for_record.return_value = {"detect_pockets": [{"score": 7.66}]} + + handler = GetRecordHandler(record_service=mock_service) + result: RecordDetail = await handler.run(GetRecord(srn=record.srn)) + + assert result.features == {"detect_pockets": [{"score": 7.66}]} + + async def test_record_detail_features_empty_when_none(self) -> None: + record = _make_record() + mock_service = AsyncMock() + mock_service.get.return_value = record + mock_service.get_features_for_record.return_value = {} + + handler = GetRecordHandler(record_service=mock_service) + result: RecordDetail = 
await handler.run(GetRecord(srn=record.srn)) + + assert result.features == {} + + async def test_existing_behavior_preserved(self) -> None: + record = _make_record() + mock_service = AsyncMock() + mock_service.get.return_value = record + mock_service.get_features_for_record.return_value = {} + + handler = GetRecordHandler(record_service=mock_service) + result: RecordDetail = await handler.run(GetRecord(srn=record.srn)) + + assert result.srn == record.srn + assert result.deposition_srn == record.deposition_srn + assert result.metadata == record.metadata + mock_service.get.assert_called_once_with(record.srn) diff --git a/server/tests/unit/domain/record/test_record_service.py b/server/tests/unit/domain/record/test_record_service.py index 38c9924..c78b778 100644 --- a/server/tests/unit/domain/record/test_record_service.py +++ b/server/tests/unit/domain/record/test_record_service.py @@ -71,6 +71,7 @@ async def test_publish_record_creates_record( record_repo=mock_record_repo, outbox=mock_outbox, node_domain=node_domain, + feature_reader=AsyncMock(), ) # Act @@ -100,6 +101,7 @@ async def test_publish_record_emits_record_published_event( record_repo=mock_record_repo, outbox=mock_outbox, node_domain=node_domain, + feature_reader=AsyncMock(), ) # Act @@ -131,6 +133,7 @@ async def test_publish_record_creates_version_1( record_repo=mock_record_repo, outbox=mock_outbox, node_domain=node_domain, + feature_reader=AsyncMock(), ) # Act diff --git a/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py b/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py new file mode 100644 index 0000000..25974c2 --- /dev/null +++ b/server/tests/unit/infrastructure/persistence/test_field_definition_reader.py @@ -0,0 +1,130 @@ +"""Tests for PostgresFieldDefinitionReader — field type map construction.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from osa.domain.semantics.model.value import FieldType +from 
osa.domain.shared.error import ValidationError +from osa.infrastructure.persistence.adapter.discovery import PostgresFieldDefinitionReader + + +def _make_schema_row(srn: str, fields: list[dict]) -> dict: + return {"srn": srn, "fields": fields} + + +@pytest.fixture +def mock_session() -> AsyncMock: + return AsyncMock() + + +def _setup_session_result(session: AsyncMock, rows: list[dict]) -> None: + """Configure mock session to return rows from a SELECT on schemas_table.""" + result_mock = MagicMock() + result_mock.mappings.return_value.all.return_value = rows + session.execute.return_value = result_mock + + +class TestGetAllFieldTypes: + async def test_builds_field_map_from_multiple_schemas(self, mock_session: AsyncMock) -> None: + rows = [ + _make_schema_row( + "urn:osa:localhost:schema:a@1", + [ + { + "name": "title", + "type": "text", + "required": True, + "cardinality": "exactly_one", + }, + { + "name": "resolution", + "type": "number", + "required": False, + "cardinality": "exactly_one", + }, + ], + ), + _make_schema_row( + "urn:osa:localhost:schema:b@1", + [ + { + "name": "method", + "type": "term", + "required": True, + "cardinality": "exactly_one", + "constraints": { + "type": "term", + "ontology_srn": "urn:osa:localhost:onto:methods@1", + }, + }, + ], + ), + ] + _setup_session_result(mock_session, rows) + + reader = PostgresFieldDefinitionReader(session=mock_session) + result = await reader.get_all_field_types() + + assert result == { + "title": FieldType.TEXT, + "resolution": FieldType.NUMBER, + "method": FieldType.TERM, + } + + async def test_raises_on_conflicting_types(self, mock_session: AsyncMock) -> None: + rows = [ + _make_schema_row( + "urn:osa:localhost:schema:a@1", + [{"name": "value", "type": "text", "required": True, "cardinality": "exactly_one"}], + ), + _make_schema_row( + "urn:osa:localhost:schema:b@1", + [ + { + "name": "value", + "type": "number", + "required": True, + "cardinality": "exactly_one", + } + ], + ), + ] + 
_setup_session_result(mock_session, rows) + + reader = PostgresFieldDefinitionReader(session=mock_session) + with pytest.raises(ValidationError, match="value"): + await reader.get_all_field_types() + + async def test_returns_empty_map_when_no_schemas(self, mock_session: AsyncMock) -> None: + _setup_session_result(mock_session, []) + + reader = PostgresFieldDefinitionReader(session=mock_session) + result = await reader.get_all_field_types() + + assert result == {} + + async def test_same_field_same_type_across_schemas_ok(self, mock_session: AsyncMock) -> None: + rows = [ + _make_schema_row( + "urn:osa:localhost:schema:a@1", + [{"name": "title", "type": "text", "required": True, "cardinality": "exactly_one"}], + ), + _make_schema_row( + "urn:osa:localhost:schema:b@1", + [ + { + "name": "title", + "type": "text", + "required": False, + "cardinality": "exactly_one", + } + ], + ), + ] + _setup_session_result(mock_session, rows) + + reader = PostgresFieldDefinitionReader(session=mock_session) + result = await reader.get_all_field_types() + + assert result == {"title": FieldType.TEXT} From 958d004063617d899d35ead0f4d37a65942a4c45 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 6 Mar 2026 18:44:16 +0000 Subject: [PATCH 2/6] =?UTF-8?q?fix:=20harden=20discovery=20queries=20?= =?UTF-8?q?=E2=80=94=20LIKE=20escaping,=20typed=20sort,=20N+1,=20paginatio?= =?UTF-8?q?n=20total,=20q=20validation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Escape LIKE metacharacters (%, _, \) in free-text and CONTAINS filters to prevent user input from being interpreted as wildcard patterns - Cast sort expressions to Float/Date for NUMBER/DATE fields so keyset cursor comparison uses correct ordering instead of lexicographic text - Consolidate N+1 COUNT queries in get_feature_catalog into a single UNION ALL, and N+1 per-table SELECTs in get_features_for_record into a single UNION ALL with to_jsonb - Remove broken total count from paginated 
responses — COUNT(*) OVER() was evaluated after the cursor WHERE clause, producing a shrinking total across pages; has_more now uses len(results) == limit - Raise ValidationError when q is provided but no text/URL fields exist, instead of silently discarding the search term --- .../application/api/v1/routes/discovery.py | 4 - server/osa/domain/discovery/model/value.py | 2 - .../osa/domain/discovery/port/read_store.py | 8 +- .../domain/discovery/query/search_features.py | 2 - .../domain/discovery/query/search_records.py | 2 - .../osa/domain/discovery/service/discovery.py | 15 ++-- .../persistence/adapter/discovery.py | 81 ++++++++++--------- .../persistence/adapter/feature_reader.py | 47 ++++++----- .../discovery/test_discovery_service.py | 47 +++++++---- .../domain/discovery/test_search_features.py | 10 +-- .../domain/discovery/test_search_records.py | 4 +- .../domain/record/test_record_features.py | 64 ++++++++++----- 12 files changed, 160 insertions(+), 126 deletions(-) diff --git a/server/osa/application/api/v1/routes/discovery.py b/server/osa/application/api/v1/routes/discovery.py index 723fb01..cdf511b 100644 --- a/server/osa/application/api/v1/routes/discovery.py +++ b/server/osa/application/api/v1/routes/discovery.py @@ -47,7 +47,6 @@ class RecordSearchRequest(BaseModel): class RecordSearchResponse(BaseModel): results: list[dict[str, Any]] - total: int cursor: str | None has_more: bool @@ -67,7 +66,6 @@ class FeatureSearchRequest(BaseModel): class FeatureSearchResponse(BaseModel): rows: list[dict[str, Any]] - total: int cursor: str | None has_more: bool @@ -93,7 +91,6 @@ async def search_records( ) return RecordSearchResponse( results=result.results, - total=result.total, cursor=result.cursor, has_more=result.has_more, ) @@ -128,7 +125,6 @@ async def search_features( ) return FeatureSearchResponse( rows=result.rows, - total=result.total, cursor=result.cursor, has_more=result.has_more, ) diff --git a/server/osa/domain/discovery/model/value.py 
b/server/osa/domain/discovery/model/value.py index 2a94c6d..1abab81 100644 --- a/server/osa/domain/discovery/model/value.py +++ b/server/osa/domain/discovery/model/value.py @@ -68,7 +68,6 @@ class RecordSummary(BaseModel): class RecordSearchResult(BaseModel): results: list[RecordSummary] - total: int cursor: str | None has_more: bool @@ -97,6 +96,5 @@ class FeatureRow(BaseModel): class FeatureSearchResult(BaseModel): rows: list[FeatureRow] - total: int cursor: str | None has_more: bool diff --git a/server/osa/domain/discovery/port/read_store.py b/server/osa/domain/discovery/port/read_store.py index a0e351e..6ac054d 100644 --- a/server/osa/domain/discovery/port/read_store.py +++ b/server/osa/domain/discovery/port/read_store.py @@ -27,8 +27,8 @@ async def search_records( cursor: dict | None, limit: int, field_types: dict[str, FieldType] | None = None, - ) -> tuple[list[RecordSummary], int]: - """Search and filter published records. Returns (results, total_count).""" + ) -> list[RecordSummary]: + """Search and filter published records.""" ... async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: @@ -51,6 +51,6 @@ async def search_features( order: SortOrder, cursor: dict | None, limit: int, - ) -> tuple[list[FeatureRow], int]: - """Search and filter feature rows. Returns (rows, total_count).""" + ) -> list[FeatureRow]: + """Search and filter feature rows.""" ... 
diff --git a/server/osa/domain/discovery/query/search_features.py b/server/osa/domain/discovery/query/search_features.py index c61bd60..8401ff6 100644 --- a/server/osa/domain/discovery/query/search_features.py +++ b/server/osa/domain/discovery/query/search_features.py @@ -23,7 +23,6 @@ class SearchFeatures(Query): class SearchFeaturesResult(Result): rows: list[dict] - total: int cursor: str | None has_more: bool @@ -45,7 +44,6 @@ async def run(self, cmd: SearchFeatures) -> SearchFeaturesResult: ) return SearchFeaturesResult( rows=[{"record_srn": str(r.record_srn), **r.data} for r in result.rows], - total=result.total, cursor=result.cursor, has_more=result.has_more, ) diff --git a/server/osa/domain/discovery/query/search_records.py b/server/osa/domain/discovery/query/search_records.py index 2812685..eed8957 100644 --- a/server/osa/domain/discovery/query/search_records.py +++ b/server/osa/domain/discovery/query/search_records.py @@ -21,7 +21,6 @@ class SearchRecords(Query): class SearchRecordsResult(Result): results: list[dict] - total: int cursor: str | None has_more: bool @@ -48,7 +47,6 @@ async def run(self, cmd: SearchRecords) -> SearchRecordsResult: } for r in result.results ], - total=result.total, cursor=result.cursor, has_more=result.has_more, ) diff --git a/server/osa/domain/discovery/service/discovery.py b/server/osa/domain/discovery/service/discovery.py index ad2b508..0edd638 100644 --- a/server/osa/domain/discovery/service/discovery.py +++ b/server/osa/domain/discovery/service/discovery.py @@ -81,8 +81,13 @@ async def search_records( text_fields = [ name for name, ft in field_map.items() if ft in (FieldType.TEXT, FieldType.URL) ] + if q and not text_fields: + raise ValidationError( + "Free-text search is unavailable: no text or URL fields are registered", + field="q", + ) - results, total = await self.read_store.search_records( + results = await self.read_store.search_records( filters=filters, text_fields=text_fields, q=q, @@ -93,7 +98,7 @@ async def 
search_records( field_types=field_map, ) - has_more = len(results) == limit and len(results) < total + has_more = len(results) == limit next_cursor = None if has_more and results: last = results[-1] @@ -105,7 +110,6 @@ async def search_records( return RecordSearchResult( results=results, - total=total, cursor=next_cursor, has_more=has_more, ) @@ -178,7 +182,7 @@ async def search_features( except ValueError as exc: raise ValidationError(str(exc), field="cursor") from exc - rows, total = await self.read_store.search_features( + rows = await self.read_store.search_features( hook_name=hook_name, filters=filters, record_srn=record_srn, @@ -188,7 +192,7 @@ async def search_features( limit=limit, ) - has_more = len(rows) == limit and len(rows) < total + has_more = len(rows) == limit next_cursor = None if has_more and rows: last = rows[-1] @@ -200,7 +204,6 @@ async def search_features( return FeatureSearchResult( rows=rows, - total=total, cursor=next_cursor, has_more=has_more, ) diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py index 52a6ffe..e3be2fe 100644 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ b/server/osa/infrastructure/persistence/adapter/discovery.py @@ -16,9 +16,11 @@ and_, cast, func, + literal, or_, select, text, + union_all, ) from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession @@ -45,6 +47,11 @@ logger = logging.getLogger(__name__) +def _escape_like(value: str) -> str: + """Escape LIKE metacharacters so user input is matched literally.""" + return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + class PostgresFieldDefinitionReader: """Builds a global field_name -> FieldType map from all registered schemas.""" @@ -88,7 +95,7 @@ async def search_records( cursor: dict[str, Any] | None, limit: int, field_types: dict[str, FieldType] | None = None, - ) -> tuple[list[RecordSummary], int]: + ) 
-> list[RecordSummary]: """Build and execute a dynamic SQL query for record search.""" t = records_table conditions: list[Any] = [] @@ -100,13 +107,19 @@ async def search_records( # Free-text search across text fields if q and text_fields: - pattern = f"%{q}%" - text_clauses = [t.c.metadata[field].astext.ilike(pattern) for field in text_fields] + pattern = f"%{_escape_like(q)}%" + text_clauses = [ + t.c.metadata[field].astext.ilike(pattern, escape="\\") for field in text_fields + ] conditions.append(or_(*text_clauses)) - # Determine sort expression + # Determine sort expression (cast to match field type for correct ordering) if sort == "published_at": sort_expr = t.c.published_at + elif ft.get(sort) == FieldType.NUMBER: + sort_expr = cast(t.c.metadata[sort].astext, Float) + elif ft.get(sort) == FieldType.DATE: + sort_expr = cast(t.c.metadata[sort].astext, Date) else: sort_expr = t.c.metadata[sort].astext @@ -135,32 +148,25 @@ async def search_records( ) ) - # Build query with COUNT(*) OVER() for total where_clause = and_(*conditions) if conditions else text("TRUE") - total_col = func.count().over().label("_total") stmt = ( - select(t.c.srn, t.c.published_at, t.c.metadata, total_col) + select(t.c.srn, t.c.published_at, t.c.metadata) .where(where_clause) .order_by(*order_clauses) .limit(limit) ) result = await self.session.execute(stmt) - rows = result.mappings().all() - - total = rows[0]["_total"] if rows else 0 - results = [ + return [ RecordSummary( srn=RecordSRN.parse(row["srn"]), published_at=row["published_at"], metadata=row["metadata"], ) - for row in rows + for row in result.mappings() ] - return results, total - async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: """List all feature tables with column schemas and record counts.""" stmt = select( @@ -171,6 +177,20 @@ async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: result = await self.session.execute(stmt) catalog_rows = result.mappings().all() + if not catalog_rows: + return [] 
+ + # Fetch all record counts in a single UNION ALL query (avoid N+1) + count_parts = [ + select( + literal(row["hook_name"]).label("hook_name"), + func.count(func.distinct(text("record_srn"))).label("cnt"), + ).select_from(text(f"features.{quoted_name(row['pg_table'], quote=True)}")) + for row in catalog_rows + ] + counts_result = await self.session.execute(union_all(*count_parts)) + counts_by_hook = {r["hook_name"]: r["cnt"] for r in counts_result.mappings()} + entries: list[FeatureCatalogEntry] = [] for row in catalog_rows: schema_data = row["feature_schema"] @@ -184,19 +204,11 @@ async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: for col in columns_raw ] - pg_table = row["pg_table"] - safe_table = quoted_name(pg_table, quote=True) - count_stmt = select(func.count(func.distinct(text("record_srn")))).select_from( - text(f"features.{safe_table}") - ) - count_result = await self.session.execute(count_stmt) - record_count = count_result.scalar() or 0 - entries.append( FeatureCatalogEntry( hook_name=row["hook_name"], columns=columns, - record_count=record_count, + record_count=counts_by_hook.get(row["hook_name"], 0), ) ) @@ -239,7 +251,7 @@ async def search_features( order: SortOrder, cursor: dict[str, Any] | None, limit: int, - ) -> tuple[list[FeatureRow], int]: + ) -> list[FeatureRow]: """Build and execute a dynamic SQL query for feature row search.""" # Look up pg_table and feature_schema from catalog pg_table_stmt = select( @@ -249,7 +261,7 @@ async def search_features( pg_result = await self.session.execute(pg_table_stmt) pg_row = pg_result.mappings().first() if pg_row is None: - return [], 0 + return [] pg_table: str = pg_row["pg_table"] feature_schema: dict = pg_row["feature_schema"] @@ -295,7 +307,9 @@ async def search_features( if f.operator == FilterOperator.EQ: conditions.append(col == f.value) elif f.operator == FilterOperator.CONTAINS: - conditions.append(cast(col, String).ilike(f"%{f.value}%")) + conditions.append( + cast(col, 
String).ilike(f"%{_escape_like(str(f.value))}%", escape="\\") + ) elif f.operator == FilterOperator.GTE: conditions.append(col >= f.value) elif f.operator == FilterOperator.LTE: @@ -332,15 +346,12 @@ async def search_features( ) where_clause = and_(*conditions) if conditions else text("TRUE") - total_col = func.count().over().label("_total") - # Select all columns except auto ones, plus total auto_cols = {"id", "created_at"} stmt = ( select( ft.c.id, ft.c.record_srn, - total_col, *[ c for c in ft.columns @@ -353,19 +364,15 @@ async def search_features( ) result = await self.session.execute(stmt) - rows = result.mappings().all() - - total = rows[0]["_total"] if rows else 0 feature_rows: list[FeatureRow] = [] - for row in rows: + for row in result.mappings(): row_dict = dict(row) - row_dict.pop("_total", None) row_id = row_dict.pop("id") rsrn = RecordSRN.parse(row_dict.pop("record_srn")) row_dict.pop("created_at", None) feature_rows.append(FeatureRow(row_id=row_id, record_srn=rsrn, data=row_dict)) - return feature_rows, total + return feature_rows @staticmethod def _record_filter_clause(f: Filter, field_type: FieldType | None = None) -> Any: @@ -375,7 +382,9 @@ def _record_filter_clause(f: Filter, field_type: FieldType | None = None) -> Any # Use JSONB @> containment (GIN-indexed) return t.c.metadata.op("@>")(cast(func.json_build_object(f.field, f.value), JSONB)) elif f.operator == FilterOperator.CONTAINS: - return t.c.metadata[f.field].astext.ilike(f"%{f.value}%") + return t.c.metadata[f.field].astext.ilike( + f"%{_escape_like(str(f.value))}%", escape="\\" + ) elif f.operator in (FilterOperator.GTE, FilterOperator.LTE): # Use typed casts: numeric for NUMBER, date for DATE, string fallback if field_type == FieldType.NUMBER: diff --git a/server/osa/infrastructure/persistence/adapter/feature_reader.py b/server/osa/infrastructure/persistence/adapter/feature_reader.py index a118ca8..4f77848 100644 --- a/server/osa/infrastructure/persistence/adapter/feature_reader.py +++ 
b/server/osa/infrastructure/persistence/adapter/feature_reader.py @@ -34,28 +34,31 @@ async def get_features_for_record( if not catalog_rows: return {} - features: dict[str, list[dict[str, Any]]] = {} - for row in catalog_rows: - hook_name = row["hook_name"] - pg_table = row["pg_table"] - - # Query the dynamic feature table for this record - safe_table = quoted_name(pg_table, quote=True) - query = text( - f"SELECT * FROM features.{safe_table} WHERE record_srn = :srn" # noqa: S608 + # Build a single UNION ALL query across all feature tables (avoid N+1). + # to_jsonb serialises each heterogeneous row into a uniform shape. + parts: list[str] = [] + params: dict[str, Any] = {"srn": str(record_srn)} + for i, row in enumerate(catalog_rows): + safe_table = quoted_name(row["pg_table"], quote=True) + hook_param = f"hook_{i}" + params[hook_param] = row["hook_name"] + parts.append( # noqa: S608 + f"SELECT :{hook_param} AS hook_name, to_jsonb(t) AS row_data " + f"FROM features.{safe_table} t " + f"WHERE t.record_srn = :srn" ) - feat_result = await self.session.execute(query, {"srn": str(record_srn)}) - feat_rows = feat_result.mappings().all() - - if feat_rows: - rows_list: list[dict[str, Any]] = [] - for feat_row in feat_rows: - row_dict = { - k: v - for k, v in dict(feat_row).items() - if k not in self.AUTO_COLUMNS and k != "record_srn" - } - rows_list.append(row_dict) - features[hook_name] = rows_list + combined = text(" UNION ALL ".join(parts)) + feat_result = await self.session.execute(combined, params) + + features: dict[str, list[dict[str, Any]]] = {} + for feat_row in feat_result.mappings(): + hook_name: str = feat_row["hook_name"] + row_data: dict[str, Any] = feat_row["row_data"] + filtered = { + k: v + for k, v in row_data.items() + if k not in self.AUTO_COLUMNS and k != "record_srn" + } + features.setdefault(hook_name, []).append(filtered) return features diff --git a/server/tests/unit/domain/discovery/test_discovery_service.py 
b/server/tests/unit/domain/discovery/test_discovery_service.py index 0be662a..5a2d6b8 100644 --- a/server/tests/unit/domain/discovery/test_discovery_service.py +++ b/server/tests/unit/domain/discovery/test_discovery_service.py @@ -23,7 +23,7 @@ @pytest.fixture def mock_read_store() -> AsyncMock: store = AsyncMock() - store.search_records.return_value = ([], 0) + store.search_records.return_value = [] return store @@ -89,7 +89,7 @@ async def test_accepts_published_at_sort(self, service: DiscoveryService) -> Non cursor=None, limit=20, ) - assert result.total == 0 + assert result.results == [] async def test_accepts_metadata_field_sort(self, service: DiscoveryService) -> None: result = await service.search_records( @@ -100,7 +100,7 @@ async def test_accepts_metadata_field_sort(self, service: DiscoveryService) -> N cursor=None, limit=20, ) - assert result.total == 0 + assert result.results == [] async def test_rejects_limit_too_low(self, service: DiscoveryService) -> None: with pytest.raises(ValidationError, match="limit"): @@ -124,6 +124,24 @@ async def test_rejects_limit_too_high(self, service: DiscoveryService) -> None: limit=101, ) + async def test_rejects_q_when_no_text_fields(self, mock_read_store: AsyncMock) -> None: + """q should raise when no TEXT/URL fields exist to search against.""" + no_text_reader = AsyncMock() + no_text_reader.get_all_field_types.return_value = { + "resolution": FieldType.NUMBER, + } + svc = DiscoveryService(read_store=mock_read_store, field_reader=no_text_reader) + + with pytest.raises(ValidationError, match="Free-text search is unavailable"): + await svc.search_records( + filters=[], + q="kinase", + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + class TestSearchRecordsDelegation: async def test_delegates_to_read_store( @@ -200,10 +218,9 @@ async def test_encodes_next_cursor_from_results( ) -> None: srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") ts = datetime(2026, 1, 1, tzinfo=UTC) - 
mock_read_store.search_records.return_value = ( - [RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], - 5, - ) + mock_read_store.search_records.return_value = [ + RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"}) + ] result = await service.search_records( filters=[], @@ -225,7 +242,7 @@ async def test_encodes_next_cursor_from_results( async def test_no_cursor_when_no_more_results( self, service: DiscoveryService, mock_read_store: AsyncMock ) -> None: - mock_read_store.search_records.return_value = ([], 0) + mock_read_store.search_records.return_value = [] result = await service.search_records( filters=[], @@ -271,10 +288,9 @@ async def test_cursor_encodes_row_id( columns=[ColumnInfo(name="score", type="number", required=True)], record_count=0, ) - mock_read_store.search_features.return_value = ( - [FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66})], - 5, - ) + mock_read_store.search_features.return_value = [ + FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66}) + ] service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) result = await service.search_features( @@ -304,10 +320,9 @@ async def test_cursor_uses_row_id_for_id_sort( columns=[ColumnInfo(name="score", type="number", required=True)], record_count=0, ) - mock_read_store.search_features.return_value = ( - [FeatureRow(row_id=99, record_srn=srn, data={"score": 5.0})], - 3, - ) + mock_read_store.search_features.return_value = [ + FeatureRow(row_id=99, record_srn=srn, data={"score": 5.0}) + ] service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) result = await service.search_features( diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py index fc7a537..55eb836 100644 --- a/server/tests/unit/domain/discovery/test_search_features.py +++ b/server/tests/unit/domain/discovery/test_search_features.py @@ -16,7 +16,6 @@ from 
osa.domain.discovery.query.search_features import ( SearchFeatures, SearchFeaturesHandler, - SearchFeaturesResult, ) from osa.domain.discovery.service.discovery import DiscoveryService from osa.domain.shared.error import NotFoundError, ValidationError @@ -40,7 +39,7 @@ def _make_catalog_entry() -> FeatureCatalogEntry: def mock_read_store() -> AsyncMock: store = AsyncMock() store.get_feature_table_schema.return_value = _make_catalog_entry() - store.search_features.return_value = ([], 0) + store.search_features.return_value = [] return store @@ -65,13 +64,11 @@ async def test_public_auth_gate(self) -> None: async def test_delegates_to_service(self) -> None: mock_service = AsyncMock() mock_service.search_features.return_value = FeatureSearchResult( - rows=[], total=0, cursor=None, has_more=False + rows=[], cursor=None, has_more=False ) handler = SearchFeaturesHandler(discovery_service=mock_service) - result: SearchFeaturesResult = await handler.run(SearchFeatures(hook_name="detect_pockets")) - - assert result.total == 0 + await handler.run(SearchFeatures(hook_name="detect_pockets")) mock_service.search_features.assert_called_once() async def test_maps_rows_with_record_srn(self) -> None: @@ -79,7 +76,6 @@ async def test_maps_rows_with_record_srn(self) -> None: mock_service = AsyncMock() mock_service.search_features.return_value = FeatureSearchResult( rows=[FeatureRow(row_id=1, record_srn=srn, data={"score": 7.66})], - total=1, cursor=None, has_more=False, ) diff --git a/server/tests/unit/domain/discovery/test_search_records.py b/server/tests/unit/domain/discovery/test_search_records.py index fb48a98..c713a5b 100644 --- a/server/tests/unit/domain/discovery/test_search_records.py +++ b/server/tests/unit/domain/discovery/test_search_records.py @@ -34,7 +34,7 @@ async def test_delegates_to_service( self, handler: SearchRecordsHandler, mock_service: AsyncMock ) -> None: mock_service.search_records.return_value = RecordSearchResult( - results=[], total=0, cursor=None, 
has_more=False + results=[], cursor=None, has_more=False ) cmd = SearchRecords() await handler.run(cmd) @@ -55,14 +55,12 @@ async def test_maps_results( ts = datetime(2026, 1, 1, tzinfo=UTC) mock_service.search_records.return_value = RecordSearchResult( results=[RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"})], - total=1, cursor="abc123", has_more=False, ) result: SearchRecordsResult = await handler.run(SearchRecords()) - assert result.total == 1 assert result.cursor == "abc123" assert result.has_more is False assert result.results[0]["srn"] == str(srn) diff --git a/server/tests/unit/domain/record/test_record_features.py b/server/tests/unit/domain/record/test_record_features.py index 6df74ca..791b7d4 100644 --- a/server/tests/unit/domain/record/test_record_features.py +++ b/server/tests/unit/domain/record/test_record_features.py @@ -36,15 +36,18 @@ async def test_returns_dict_keyed_by_hook_name( _make_catalog_row("detect_pockets", "detect_pockets_v1") ] - # Second call: feature table query + # Second call: UNION ALL query returning {hook_name, row_data} mappings feature_result = MagicMock() - feature_result.mappings.return_value.all.return_value = [ + feature_result.mappings.return_value = [ { - "id": 1, - "record_srn": str(srn), - "created_at": datetime.now(UTC), - "score": 7.66, - "volume": 1750.0, + "hook_name": "detect_pockets", + "row_data": { + "id": 1, + "record_srn": str(srn), + "created_at": "2026-01-01T00:00:00", + "score": 7.66, + "volume": 1750.0, + }, } ] @@ -68,12 +71,15 @@ async def test_excludes_auto_columns( ] feature_result = MagicMock() - feature_result.mappings.return_value.all.return_value = [ + feature_result.mappings.return_value = [ { - "id": 42, - "record_srn": str(srn), - "created_at": datetime.now(UTC), - "metric": 3.14, + "hook_name": "test_hook", + "row_data": { + "id": 42, + "record_srn": str(srn), + "created_at": "2026-01-01T00:00:00", + "metric": 3.14, + }, } ] @@ -109,8 +115,9 @@ async def 
test_returns_empty_when_record_has_no_data( _make_catalog_row("detect_pockets", "detect_pockets_v1") ] + # UNION ALL returns no rows when record has no feature data feature_result = MagicMock() - feature_result.mappings.return_value.all.return_value = [] + feature_result.mappings.return_value = [] mock_session.execute.side_effect = [catalog_result, feature_result] @@ -128,17 +135,30 @@ async def test_includes_data_from_multiple_tables( _make_catalog_row("hook_b", "hook_b_v1"), ] - feat_a = MagicMock() - feat_a.mappings.return_value.all.return_value = [ - {"id": 1, "record_srn": str(srn), "created_at": datetime.now(UTC), "x": 1} - ] - - feat_b = MagicMock() - feat_b.mappings.return_value.all.return_value = [ - {"id": 2, "record_srn": str(srn), "created_at": datetime.now(UTC), "y": 2} + # Single UNION ALL result containing rows from both tables + feature_result = MagicMock() + feature_result.mappings.return_value = [ + { + "hook_name": "hook_a", + "row_data": { + "id": 1, + "record_srn": str(srn), + "created_at": "2026-01-01T00:00:00", + "x": 1, + }, + }, + { + "hook_name": "hook_b", + "row_data": { + "id": 2, + "record_srn": str(srn), + "created_at": "2026-01-01T00:00:00", + "y": 2, + }, + }, ] - mock_session.execute.side_effect = [catalog_result, feat_a, feat_b] + mock_session.execute.side_effect = [catalog_result, feature_result] result = await reader.get_features_for_record(srn) From 95c4a41ca186931b7da0ec8c72484aaa3593415c Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 6 Mar 2026 21:49:02 +0000 Subject: [PATCH 3/6] fix: replace ineffective quoted_name with actual SQL identifier quoting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit quoted_name only instructs SQLAlchemy's compiler to quote — when interpolated via f-string into text(), it emits the bare string unquoted. Replace with _quote_ident() that properly double-quotes identifiers and escapes embedded double-quotes. 
--- .../infrastructure/persistence/adapter/discovery.py | 8 ++++++-- .../persistence/adapter/feature_reader.py | 10 +++++++--- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py index e3be2fe..a690fc1 100644 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ b/server/osa/infrastructure/persistence/adapter/discovery.py @@ -24,7 +24,6 @@ ) from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql.elements import quoted_name from osa.domain.discovery.model.value import ( ColumnInfo, @@ -52,6 +51,11 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") +def _quote_ident(name: str) -> str: + """Double-quote a SQL identifier, escaping embedded double-quotes.""" + return '"' + name.replace('"', '""') + '"' + + class PostgresFieldDefinitionReader: """Builds a global field_name -> FieldType map from all registered schemas.""" @@ -185,7 +189,7 @@ async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: select( literal(row["hook_name"]).label("hook_name"), func.count(func.distinct(text("record_srn"))).label("cnt"), - ).select_from(text(f"features.{quoted_name(row['pg_table'], quote=True)}")) + ).select_from(text(f"features.{_quote_ident(row['pg_table'])}")) for row in catalog_rows ] counts_result = await self.session.execute(union_all(*count_parts)) diff --git a/server/osa/infrastructure/persistence/adapter/feature_reader.py b/server/osa/infrastructure/persistence/adapter/feature_reader.py index 4f77848..7ce999c 100644 --- a/server/osa/infrastructure/persistence/adapter/feature_reader.py +++ b/server/osa/infrastructure/persistence/adapter/feature_reader.py @@ -6,12 +6,16 @@ from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql.elements import 
quoted_name from osa.domain.shared.model.srn import RecordSRN from osa.infrastructure.persistence.tables import feature_tables_table +def _quote_ident(name: str) -> str: + """Double-quote a SQL identifier, escaping embedded double-quotes.""" + return '"' + name.replace('"', '""') + '"' + + class PostgresFeatureReader: """Queries feature_tables catalog and dynamic feature tables for a record.""" @@ -39,12 +43,12 @@ async def get_features_for_record( parts: list[str] = [] params: dict[str, Any] = {"srn": str(record_srn)} for i, row in enumerate(catalog_rows): - safe_table = quoted_name(row["pg_table"], quote=True) + quoted = _quote_ident(row["pg_table"]) hook_param = f"hook_{i}" params[hook_param] = row["hook_name"] parts.append( # noqa: S608 f"SELECT :{hook_param} AS hook_name, to_jsonb(t) AS row_data " - f"FROM features.{safe_table} t " + f"FROM features.{quoted} t " f"WHERE t.record_srn = :srn" ) combined = text(" UNION ALL ".join(parts)) From cdd4b42003bed0d73c37c4b6f945f2c0a7467762 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Mar 2026 00:11:53 +0000 Subject: [PATCH 4/6] refactor: extract typed FeatureSchema and shared table-building helpers Replace raw dict handling of feature_schema JSON with a typed FeatureSchema Pydantic model. Parse at the DB boundary, then pass typed objects everywhere downstream. 
- Add FeatureSchema, build_feature_table, and data_columns to feature_table.py as the single source of truth for feature table construction - Eliminate raw SQL (text() f-strings, manual _quote_ident) from get_feature_catalog and get_features_for_record - Remove duplicated auto-column definitions from feature_store.py - Extract _to_column_info helper, remove dead pop("created_at") - Consolidate redundant DDL tests into test_feature_table.py --- .../persistence/adapter/discovery.py | 127 +++++----------- .../persistence/adapter/feature_reader.py | 62 ++++---- .../persistence/feature_store.py | 34 ++--- .../persistence/feature_table.py | 56 +++++++ .../persistence/test_feature_store.py | 6 +- .../domain/record/test_record_features.py | 59 ++++---- .../infrastructure/test_feature_queries.py | 139 +++--------------- .../unit/infrastructure/test_feature_table.py | 117 +++++++++++++++ .../test_postgres_feature_store.py | 6 +- 9 files changed, 310 insertions(+), 296 deletions(-) create mode 100644 server/osa/infrastructure/persistence/feature_table.py create mode 100644 server/tests/unit/infrastructure/test_feature_table.py diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py index a690fc1..f3cf0ba 100644 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ b/server/osa/infrastructure/persistence/adapter/discovery.py @@ -6,20 +6,16 @@ from typing import Any from sqlalchemy import ( - Column, Date, Float, - Integer, - MetaData, String, - Table, and_, cast, func, literal, or_, select, - text, + true, union_all, ) from sqlalchemy.dialects.postgresql import JSONB @@ -37,6 +33,11 @@ from osa.domain.semantics.model.value import FieldType from osa.domain.shared.error import ValidationError from osa.domain.shared.model.srn import RecordSRN +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, + data_columns, +) from 
osa.infrastructure.persistence.tables import ( feature_tables_table, records_table, @@ -51,9 +52,9 @@ def _escape_like(value: str) -> str: return value.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") -def _quote_ident(name: str) -> str: - """Double-quote a SQL identifier, escaping embedded double-quotes.""" - return '"' + name.replace('"', '""') + '"' +def _to_column_info(schema: FeatureSchema) -> list[ColumnInfo]: + """Map typed FeatureSchema columns to API-facing ColumnInfo list.""" + return [ColumnInfo(name=c.name, type=c.json_type, required=c.required) for c in schema.columns] class PostgresFieldDefinitionReader: @@ -152,7 +153,7 @@ async def search_records( ) ) - where_clause = and_(*conditions) if conditions else text("TRUE") + where_clause = and_(*conditions) if conditions else true() stmt = ( select(t.c.srn, t.c.published_at, t.c.metadata) @@ -184,39 +185,33 @@ async def get_feature_catalog(self) -> list[FeatureCatalogEntry]: if not catalog_rows: return [] - # Fetch all record counts in a single UNION ALL query (avoid N+1) - count_parts = [ - select( - literal(row["hook_name"]).label("hook_name"), - func.count(func.distinct(text("record_srn"))).label("cnt"), - ).select_from(text(f"features.{_quote_ident(row['pg_table'])}")) + # Parse schemas at the boundary + parsed = [ + (row["hook_name"], FeatureSchema.model_validate(row["feature_schema"]), row["pg_table"]) for row in catalog_rows ] + + # Fetch all record counts in a single UNION ALL query (avoid N+1) + count_parts = [] + for hook_name, schema, pg_table in parsed: + ft = build_feature_table(pg_table, schema) + count_parts.append( + select( + literal(hook_name).label("hook_name"), + func.count(func.distinct(ft.c.record_srn)).label("cnt"), + ).select_from(ft) + ) counts_result = await self.session.execute(union_all(*count_parts)) counts_by_hook = {r["hook_name"]: r["cnt"] for r in counts_result.mappings()} - entries: list[FeatureCatalogEntry] = [] - for row in catalog_rows: - schema_data = 
row["feature_schema"] - columns_raw = schema_data.get("columns", []) if isinstance(schema_data, dict) else [] - columns = [ - ColumnInfo( - name=col["name"], - type=col.get("json_type", "string"), - required=col.get("required", False), - ) - for col in columns_raw - ] - - entries.append( - FeatureCatalogEntry( - hook_name=row["hook_name"], - columns=columns, - record_count=counts_by_hook.get(row["hook_name"], 0), - ) + return [ + FeatureCatalogEntry( + hook_name=hook_name, + columns=_to_column_info(schema), + record_count=counts_by_hook.get(hook_name, 0), ) - - return entries + for hook_name, schema, _pg_table in parsed + ] async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry | None: """Look up a single feature table's schema by hook name.""" @@ -229,20 +224,10 @@ async def get_feature_table_schema(self, hook_name: str) -> FeatureCatalogEntry if row is None: return None - schema_data = row["feature_schema"] - columns_raw = schema_data.get("columns", []) if isinstance(schema_data, dict) else [] - columns = [ - ColumnInfo( - name=col["name"], - type=col.get("json_type", "string"), - required=col.get("required", False), - ) - for col in columns_raw - ] - + schema = FeatureSchema.model_validate(row["feature_schema"]) return FeatureCatalogEntry( hook_name=row["hook_name"], - columns=columns, + columns=_to_column_info(schema), record_count=0, ) @@ -267,37 +252,9 @@ async def search_features( if pg_row is None: return [] pg_table: str = pg_row["pg_table"] - feature_schema: dict = pg_row["feature_schema"] + schema = FeatureSchema.model_validate(pg_row["feature_schema"]) - # Build Table with full column list from schema using local MetaData - from osa.domain.shared.model.hook import ColumnDef - from osa.infrastructure.persistence.column_mapper import map_column - - schema_columns = ( - feature_schema.get("columns", []) if isinstance(feature_schema, dict) else [] - ) - data_columns = [ - map_column( - ColumnDef( - name=col["name"], - 
json_type=col.get("json_type", "string"), - format=col.get("format"), - required=col.get("required", False), - ) - ) - for col in schema_columns - ] - - local_meta = MetaData() - ft = Table( - pg_table, - local_meta, - Column("id", Integer, primary_key=True), - Column("record_srn", String), - Column("created_at", String), - *data_columns, - schema="features", - ) + ft = build_feature_table(pg_table, schema) conditions: list[Any] = [] @@ -349,19 +306,10 @@ async def search_features( ) ) - where_clause = and_(*conditions) if conditions else text("TRUE") + where_clause = and_(*conditions) if conditions else true() - auto_cols = {"id", "created_at"} stmt = ( - select( - ft.c.id, - ft.c.record_srn, - *[ - c - for c in ft.columns - if c.key not in auto_cols and c.key not in ("id", "record_srn") - ], - ) + select(ft.c.id, ft.c.record_srn, *data_columns(ft)) .where(where_clause) .order_by(*order_clauses) .limit(limit) @@ -373,7 +321,6 @@ async def search_features( row_dict = dict(row) row_id = row_dict.pop("id") rsrn = RecordSRN.parse(row_dict.pop("record_srn")) - row_dict.pop("created_at", None) feature_rows.append(FeatureRow(row_id=row_id, record_srn=rsrn, data=row_dict)) return feature_rows diff --git a/server/osa/infrastructure/persistence/adapter/feature_reader.py b/server/osa/infrastructure/persistence/adapter/feature_reader.py index 7ce999c..beeb07c 100644 --- a/server/osa/infrastructure/persistence/adapter/feature_reader.py +++ b/server/osa/infrastructure/persistence/adapter/feature_reader.py @@ -4,23 +4,21 @@ from typing import Any -from sqlalchemy import select, text +from sqlalchemy import func, literal, select, union_all from sqlalchemy.ext.asyncio import AsyncSession from osa.domain.shared.model.srn import RecordSRN +from osa.infrastructure.persistence.feature_table import ( + FeatureSchema, + build_feature_table, + data_columns, +) from osa.infrastructure.persistence.tables import feature_tables_table -def _quote_ident(name: str) -> str: - """Double-quote a 
SQL identifier, escaping embedded double-quotes.""" - return '"' + name.replace('"', '""') + '"' - - class PostgresFeatureReader: """Queries feature_tables catalog and dynamic feature tables for a record.""" - AUTO_COLUMNS = {"id", "created_at"} - def __init__(self, session: AsyncSession) -> None: self.session = session @@ -31,6 +29,7 @@ async def get_features_for_record( stmt = select( feature_tables_table.c.hook_name, feature_tables_table.c.pg_table, + feature_tables_table.c.feature_schema, ) result = await self.session.execute(stmt) catalog_rows = result.mappings().all() @@ -39,30 +38,39 @@ async def get_features_for_record( return {} # Build a single UNION ALL query across all feature tables (avoid N+1). - # to_jsonb serialises each heterogeneous row into a uniform shape. - parts: list[str] = [] - params: dict[str, Any] = {"srn": str(record_srn)} - for i, row in enumerate(catalog_rows): - quoted = _quote_ident(row["pg_table"]) - hook_param = f"hook_{i}" - params[hook_param] = row["hook_name"] - parts.append( # noqa: S608 - f"SELECT :{hook_param} AS hook_name, to_jsonb(t) AS row_data " - f"FROM features.{quoted} t " - f"WHERE t.record_srn = :srn" + # Use jsonb_build_object with explicit data columns to exclude auto columns + # at the SQL level. + parts = [] + for row in catalog_rows: + schema = FeatureSchema.model_validate(row["feature_schema"]) + ft = build_feature_table(row["pg_table"], schema) + dcols = data_columns(ft) + + # Build jsonb_build_object('col1', col1, 'col2', col2, ...) 
+ jsonb_args: list[Any] = [] + for col in dcols: + jsonb_args.extend([literal(col.key), col]) + + row_data_expr = ( + func.jsonb_build_object(*jsonb_args) if jsonb_args else func.jsonb_build_object() + ) + + parts.append( + select( + literal(row["hook_name"]).label("hook_name"), + row_data_expr.label("row_data"), + ) + .select_from(ft) + .where(ft.c.record_srn == str(record_srn)) ) - combined = text(" UNION ALL ".join(parts)) - feat_result = await self.session.execute(combined, params) + + combined = union_all(*parts) + feat_result = await self.session.execute(combined) features: dict[str, list[dict[str, Any]]] = {} for feat_row in feat_result.mappings(): hook_name: str = feat_row["hook_name"] row_data: dict[str, Any] = feat_row["row_data"] - filtered = { - k: v - for k, v in row_data.items() - if k not in self.AUTO_COLUMNS and k != "record_srn" - } - features.setdefault(hook_name, []).append(filtered) + features.setdefault(hook_name, []).append(row_data) return features diff --git a/server/osa/infrastructure/persistence/feature_store.py b/server/osa/infrastructure/persistence/feature_store.py index a12708b..b43c729 100644 --- a/server/osa/infrastructure/persistence/feature_store.py +++ b/server/osa/infrastructure/persistence/feature_store.py @@ -12,11 +12,13 @@ from osa.domain.feature.port.feature_store import FeatureStore from osa.domain.shared.error import ConflictError, ValidationError from osa.domain.shared.model.hook import ColumnDef -from osa.infrastructure.persistence.column_mapper import map_column +from osa.infrastructure.persistence.feature_table import ( + FEATURES_SCHEMA, + FeatureSchema, + build_feature_table, +) from osa.infrastructure.persistence.tables import feature_tables_table -FEATURES_SCHEMA = "features" - _PG_IDENTIFIER = re.compile(r"^[a-z][a-z0-9_]{0,62}$") @@ -58,34 +60,16 @@ async def create_table(self, hook_name: str, columns: list[ColumnDef]) -> None: raise ConflictError(f"Feature table already exists: {hook_name}") # Build dynamic 
table - metadata = sa.MetaData(schema=FEATURES_SCHEMA) - sa_columns: list[sa.Column] = [ - sa.Column("id", sa.BigInteger, primary_key=True, autoincrement=True), - sa.Column("record_srn", sa.Text, nullable=False, index=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.func.now(), - ), - ] - - for col_def in columns: - sa_columns.append(map_column(col_def)) - - # sa.Table registers itself on the metadata object - sa.Table(hook_name, metadata, *sa_columns) + schema = FeatureSchema(columns=columns) + table = build_feature_table(hook_name, schema) # Create table - await conn.run_sync(metadata.create_all, checkfirst=False) - - # Register in catalog - feature_schema = {"columns": [c.model_dump() for c in columns]} + await conn.run_sync(table.metadata.create_all, checkfirst=False) await conn.execute( feature_tables_table.insert().values( hook_name=hook_name, pg_table=hook_name, - feature_schema=feature_schema, + feature_schema=schema.model_dump(), schema_version=1, created_at=datetime.now(UTC), ) diff --git a/server/osa/infrastructure/persistence/feature_table.py b/server/osa/infrastructure/persistence/feature_table.py new file mode 100644 index 0000000..48eb214 --- /dev/null +++ b/server/osa/infrastructure/persistence/feature_table.py @@ -0,0 +1,56 @@ +"""Shared helpers for building dynamic feature Table objects from catalog schema.""" + +from __future__ import annotations + +import sqlalchemy as sa + +from osa.domain.shared.model.hook import ColumnDef +from osa.domain.shared.model.value import ValueObject +from osa.infrastructure.persistence.column_mapper import map_column + +FEATURES_SCHEMA = "features" + +AUTO_COLUMN_NAMES = frozenset({"id", "record_srn", "created_at"}) + + +class FeatureSchema(ValueObject): + """Typed representation of the ``feature_tables.feature_schema`` JSON column. + + Serialised with :meth:`model_dump`, deserialised with :meth:`model_validate`. 
+ """ + + columns: list[ColumnDef] = [] + + +def build_feature_table(pg_table: str, schema: FeatureSchema) -> sa.Table: + """Build a SQLAlchemy ``Table`` for a dynamic feature table. + + Returns a ``Table`` with auto columns (``id``, ``record_srn``, ``created_at``) + plus data columns derived from *schema* via :func:`map_column`, in the + ``features`` PG schema. + + Each call creates a disposable ``MetaData`` — these Tables are used for + query building only, not for DDL lifecycle management. + """ + data_columns = [map_column(col_def) for col_def in schema.columns] + + metadata = sa.MetaData() + return sa.Table( + pg_table, + metadata, + sa.Column("id", sa.BigInteger, primary_key=True, autoincrement=True), + sa.Column("record_srn", sa.Text, nullable=False, index=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + *data_columns, + schema=FEATURES_SCHEMA, + ) + + +def data_columns(table: sa.Table) -> list[sa.Column]: + """Return only the user-defined data columns, excluding auto columns.""" + return [c for c in table.columns if c.key not in AUTO_COLUMN_NAMES] diff --git a/server/tests/integration/persistence/test_feature_store.py b/server/tests/integration/persistence/test_feature_store.py index a0f82e8..0c81bc8 100644 --- a/server/tests/integration/persistence/test_feature_store.py +++ b/server/tests/integration/persistence/test_feature_store.py @@ -10,10 +10,8 @@ OciConfig, TableFeatureSpec, ) -from osa.infrastructure.persistence.feature_store import ( - FEATURES_SCHEMA, - PostgresFeatureStore, -) +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.feature_table import FEATURES_SCHEMA def _make_hook( diff --git a/server/tests/unit/domain/record/test_record_features.py b/server/tests/unit/domain/record/test_record_features.py index 791b7d4..700f967 100644 --- a/server/tests/unit/domain/record/test_record_features.py +++ 
b/server/tests/unit/domain/record/test_record_features.py @@ -12,8 +12,12 @@ from osa.infrastructure.persistence.adapter.feature_reader import PostgresFeatureReader -def _make_catalog_row(hook_name: str, pg_table: str) -> dict: - return {"hook_name": hook_name, "pg_table": pg_table} +def _make_catalog_row(hook_name: str, pg_table: str, columns: list[dict] | None = None) -> dict: + return { + "hook_name": hook_name, + "pg_table": pg_table, + "feature_schema": {"columns": columns or []}, + } class TestPostgresFeatureReader: @@ -33,18 +37,23 @@ async def test_returns_dict_keyed_by_hook_name( # First call: catalog query catalog_result = MagicMock() catalog_result.mappings.return_value.all.return_value = [ - _make_catalog_row("detect_pockets", "detect_pockets_v1") + _make_catalog_row( + "detect_pockets", + "detect_pockets_v1", + [ + {"name": "score", "json_type": "number", "required": True}, + {"name": "volume", "json_type": "number", "required": False}, + ], + ) ] # Second call: UNION ALL query returning {hook_name, row_data} mappings + # row_data now excludes auto columns (jsonb_build_object only includes data cols) feature_result = MagicMock() feature_result.mappings.return_value = [ { "hook_name": "detect_pockets", "row_data": { - "id": 1, - "record_srn": str(srn), - "created_at": "2026-01-01T00:00:00", "score": 7.66, "volume": 1750.0, }, @@ -67,7 +76,11 @@ async def test_excludes_auto_columns( catalog_result = MagicMock() catalog_result.mappings.return_value.all.return_value = [ - _make_catalog_row("test_hook", "test_hook_v1") + _make_catalog_row( + "test_hook", + "test_hook_v1", + [{"name": "metric", "json_type": "number", "required": True}], + ) ] feature_result = MagicMock() @@ -75,9 +88,6 @@ async def test_excludes_auto_columns( { "hook_name": "test_hook", "row_data": { - "id": 42, - "record_srn": str(srn), - "created_at": "2026-01-01T00:00:00", "metric": 3.14, }, } @@ -112,7 +122,11 @@ async def test_returns_empty_when_record_has_no_data( catalog_result = 
MagicMock() catalog_result.mappings.return_value.all.return_value = [ - _make_catalog_row("detect_pockets", "detect_pockets_v1") + _make_catalog_row( + "detect_pockets", + "detect_pockets_v1", + [{"name": "score", "json_type": "number", "required": True}], + ) ] # UNION ALL returns no rows when record has no feature data @@ -131,30 +145,25 @@ async def test_includes_data_from_multiple_tables( catalog_result = MagicMock() catalog_result.mappings.return_value.all.return_value = [ - _make_catalog_row("hook_a", "hook_a_v1"), - _make_catalog_row("hook_b", "hook_b_v1"), + _make_catalog_row( + "hook_a", "hook_a_v1", [{"name": "x", "json_type": "integer", "required": True}] + ), + _make_catalog_row( + "hook_b", "hook_b_v1", [{"name": "y", "json_type": "integer", "required": True}] + ), ] # Single UNION ALL result containing rows from both tables + # row_data now excludes auto columns feature_result = MagicMock() feature_result.mappings.return_value = [ { "hook_name": "hook_a", - "row_data": { - "id": 1, - "record_srn": str(srn), - "created_at": "2026-01-01T00:00:00", - "x": 1, - }, + "row_data": {"x": 1}, }, { "hook_name": "hook_b", - "row_data": { - "id": 2, - "record_srn": str(srn), - "created_at": "2026-01-01T00:00:00", - "y": 2, - }, + "row_data": {"y": 2}, }, ] diff --git a/server/tests/unit/infrastructure/test_feature_queries.py b/server/tests/unit/infrastructure/test_feature_queries.py index 6687ec4..89a8ed5 100644 --- a/server/tests/unit/infrastructure/test_feature_queries.py +++ b/server/tests/unit/infrastructure/test_feature_queries.py @@ -1,126 +1,24 @@ -"""Tests for feature table DDL correctness — verifies tables support SQL JOINs and typed queries. +"""Tests for feature table SQL generation — verifies tables support SQL JOINs and typed queries. -These tests validate that the dynamically generated SQLAlchemy table metadata -produces correct DDL: proper column types, nullable constraints, record_srn FK -column for JOINs, and a single ``features`` PG schema. 
+DDL correctness tests live in test_feature_table.py. These tests verify that +SQLAlchemy can generate valid SELECT/WHERE/JOIN expressions from the built tables. """ import sqlalchemy as sa -from sqlalchemy.dialects.postgresql import JSONB from osa.domain.shared.model.hook import ColumnDef -from osa.infrastructure.persistence.column_mapper import map_column -from osa.infrastructure.persistence.feature_store import FEATURES_SCHEMA +from osa.infrastructure.persistence.feature_table import ( + FEATURES_SCHEMA, + FeatureSchema, + build_feature_table, +) -def _make_columns(*col_defs: tuple[str, str, bool]) -> list[ColumnDef]: - """Create a list of ColumnDef from (name, json_type, required) tuples.""" - return [ColumnDef(name=n, json_type=t, required=r) for n, t, r in col_defs] - - -def _build_feature_table(table_name: str, columns: list[ColumnDef]) -> sa.Table: - """Build a dynamic feature table exactly as PostgresFeatureStore does.""" - metadata = sa.MetaData(schema=FEATURES_SCHEMA) - sa_columns: list[sa.Column] = [ - sa.Column("id", sa.BigInteger, primary_key=True, autoincrement=True), - sa.Column("record_srn", sa.Text, nullable=False, index=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - nullable=False, - server_default=sa.func.now(), - ), - ] - for col_def in columns: - sa_columns.append(map_column(col_def)) - - return sa.Table(table_name, metadata, *sa_columns) - - -class TestFeatureTableDDL: - def test_record_srn_column_exists(self): - """Feature tables must have record_srn for JOINing with records table.""" - columns = _make_columns(("score", "number", True)) - table = _build_feature_table("pocket_detect", columns) - - assert "record_srn" in table.c - assert not table.c.record_srn.nullable - - def test_record_srn_is_indexed(self): - """record_srn must be indexed for efficient JOINs.""" - columns = _make_columns(("score", "number", True)) - table = _build_feature_table("pocket_detect", columns) - - assert table.c.record_srn.index is True - - 
def test_number_columns_are_float(self): - """Number columns map to Float(53) for double-precision queries.""" - columns = _make_columns(("score", "number", True), ("volume", "number", False)) - table = _build_feature_table("detect", columns) - - assert isinstance(table.c.score.type, sa.Float) - assert isinstance(table.c.volume.type, sa.Float) - - def test_integer_columns_are_bigint(self): - columns = _make_columns(("n_atoms", "integer", True)) - table = _build_feature_table("check", columns) - - assert isinstance(table.c.n_atoms.type, sa.BigInteger) - - def test_string_columns_are_text(self): - columns = _make_columns(("pocket_id", "string", True)) - table = _build_feature_table("detect", columns) - - assert isinstance(table.c.pocket_id.type, sa.Text) - - def test_boolean_columns_are_boolean(self): - columns = _make_columns(("is_valid", "boolean", True)) - table = _build_feature_table("check", columns) - - assert isinstance(table.c.is_valid.type, sa.Boolean) - - def test_array_columns_are_jsonb(self): - columns = _make_columns(("residues", "array", True)) - table = _build_feature_table("detect", columns) - - assert isinstance(table.c.residues.type, JSONB) - - def test_object_columns_are_jsonb(self): - columns = _make_columns(("metadata", "object", False)) - table = _build_feature_table("detect", columns) - - assert isinstance(table.c.metadata.type, JSONB) - - def test_nullable_respects_required_field(self): - columns = _make_columns( - ("score", "number", True), - ("notes", "string", False), - ) - table = _build_feature_table("detect", columns) - - assert not table.c.score.nullable # required -> NOT NULL - assert table.c.notes.nullable # optional -> nullable - - def test_has_primary_key(self): - columns = _make_columns(("score", "number", True)) - table = _build_feature_table("detect", columns) - - pk_cols = [c.name for c in table.primary_key.columns] - assert pk_cols == ["id"] - - def test_has_created_at(self): - columns = _make_columns(("score", "number", True)) 
- table = _build_feature_table("detect", columns) - - assert "created_at" in table.c - assert isinstance(table.c.created_at.type, sa.DateTime) - - def test_table_uses_features_schema(self): - """All feature tables live in the single 'features' PG schema.""" - columns = _make_columns(("score", "number", True)) - table = _build_feature_table("detect", columns) - - assert table.schema == FEATURES_SCHEMA +def _schema(*col_defs: tuple[str, str, bool]) -> FeatureSchema: + """Create a FeatureSchema from (name, json_type, required) tuples.""" + return FeatureSchema( + columns=[ColumnDef(name=n, json_type=t, required=r) for n, t, r in col_defs] + ) class TestFeatureTableSQLGeneration: @@ -128,8 +26,7 @@ class TestFeatureTableSQLGeneration: def test_where_on_typed_column(self): """Can build WHERE score > 0.5 on a number column.""" - columns = _make_columns(("score", "number", True)) - table = _build_feature_table("detect", columns) + table = build_feature_table("detect", _schema(("score", "number", True))) stmt = sa.select(table.c.record_srn).where(table.c.score > 0.5) compiled = str(stmt.compile(compile_kwargs={"literal_binds": True})) @@ -139,8 +36,7 @@ def test_where_on_typed_column(self): def test_join_on_record_srn(self): """Can JOIN feature table with records table on record_srn.""" - columns = _make_columns(("score", "number", True)) - feature_table = _build_feature_table("detect", columns) + feature_table = build_feature_table("detect", _schema(("score", "number", True))) # Simulate a records table records_meta = sa.MetaData() @@ -169,8 +65,9 @@ def test_join_on_record_srn(self): def test_aggregate_on_typed_columns(self): """Can compute aggregates (AVG, COUNT) on typed columns.""" - columns = _make_columns(("score", "number", True), ("volume", "number", True)) - table = _build_feature_table("detect", columns) + table = build_feature_table( + "detect", _schema(("score", "number", True), ("volume", "number", True)) + ) stmt = sa.select( table.c.record_srn, diff --git 
a/server/tests/unit/infrastructure/test_feature_table.py b/server/tests/unit/infrastructure/test_feature_table.py new file mode 100644 index 0000000..d76b702 --- /dev/null +++ b/server/tests/unit/infrastructure/test_feature_table.py @@ -0,0 +1,117 @@ +"""Tests for the shared build_feature_table helper and FeatureSchema model.""" + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +from osa.domain.shared.model.hook import ColumnDef +from osa.infrastructure.persistence.feature_table import ( + FEATURES_SCHEMA, + FeatureSchema, + build_feature_table, + data_columns, +) + + +class TestFeatureSchema: + def test_round_trips_through_json(self) -> None: + schema = FeatureSchema( + columns=[ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ] + ) + raw = schema.model_dump() + restored = FeatureSchema.model_validate(raw) + assert restored == schema + + def test_defaults_to_empty_columns(self) -> None: + schema = FeatureSchema() + assert schema.columns == [] + + def test_validates_from_catalog_json(self) -> None: + raw = { + "columns": [ + {"name": "score", "json_type": "number", "required": True}, + {"name": "label", "json_type": "string", "required": False}, + ] + } + schema = FeatureSchema.model_validate(raw) + assert len(schema.columns) == 2 + assert schema.columns[0].name == "score" + assert schema.columns[0].json_type == "number" + + +class TestBuildFeatureTable: + def test_uses_features_schema(self) -> None: + table = build_feature_table("my_hook", FeatureSchema()) + assert table.schema == FEATURES_SCHEMA + + def test_has_auto_columns(self) -> None: + table = build_feature_table("my_hook", FeatureSchema()) + + assert "id" in table.c + assert "record_srn" in table.c + assert "created_at" in table.c + + def test_id_is_primary_key(self) -> None: + table = build_feature_table("my_hook", FeatureSchema()) + pk_cols = [c.name for c in table.primary_key.columns] + assert 
pk_cols == ["id"] + + def test_record_srn_not_nullable(self) -> None: + table = build_feature_table("my_hook", FeatureSchema()) + assert not table.c.record_srn.nullable + + def test_data_columns_from_schema(self) -> None: + schema = FeatureSchema( + columns=[ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ] + ) + table = build_feature_table("detect", schema) + + assert "score" in table.c + assert isinstance(table.c.score.type, sa.Float) + assert not table.c.score.nullable + + assert "label" in table.c + assert isinstance(table.c.label.type, sa.Text) + assert table.c.label.nullable + + def test_empty_schema_has_only_auto_columns(self) -> None: + table = build_feature_table("empty", FeatureSchema()) + assert set(c.key for c in table.columns) == {"id", "record_srn", "created_at"} + + def test_array_column_is_jsonb(self) -> None: + schema = FeatureSchema( + columns=[ColumnDef(name="residues", json_type="array", required=True)] + ) + table = build_feature_table("detect", schema) + assert isinstance(table.c.residues.type, JSONB) + + def test_table_name_matches(self) -> None: + table = build_feature_table("my_table_name", FeatureSchema()) + assert table.name == "my_table_name" + + +class TestDataColumns: + def test_excludes_auto_columns(self) -> None: + schema = FeatureSchema( + columns=[ + ColumnDef(name="score", json_type="number", required=True), + ColumnDef(name="label", json_type="string", required=False), + ] + ) + table = build_feature_table("detect", schema) + dcols = data_columns(table) + col_names = [c.key for c in dcols] + + assert col_names == ["score", "label"] + assert "id" not in col_names + assert "record_srn" not in col_names + assert "created_at" not in col_names + + def test_empty_for_schema_with_no_data_columns(self) -> None: + table = build_feature_table("empty", FeatureSchema()) + assert data_columns(table) == [] diff --git 
a/server/tests/unit/infrastructure/test_postgres_feature_store.py b/server/tests/unit/infrastructure/test_postgres_feature_store.py index d0a8fa2..6f45d20 100644 --- a/server/tests/unit/infrastructure/test_postgres_feature_store.py +++ b/server/tests/unit/infrastructure/test_postgres_feature_store.py @@ -8,10 +8,8 @@ from osa.domain.shared.error import ConflictError, ValidationError from osa.domain.shared.model.hook import ColumnDef -from osa.infrastructure.persistence.feature_store import ( - FEATURES_SCHEMA, - PostgresFeatureStore, -) +from osa.infrastructure.persistence.feature_store import PostgresFeatureStore +from osa.infrastructure.persistence.feature_table import FEATURES_SCHEMA def _make_columns() -> list[ColumnDef]: From 3f1a07ec816358be67e3266b635fc5e918c7daf8 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Mar 2026 01:34:49 +0000 Subject: [PATCH 5/6] fix: keyset pagination NULL handling, SRN validation, and jsonb param typing - Add KeysetPage abstraction that derives ORDER BY and WHERE from a single sort spec with correct NULL semantics (NULL cursor no longer produces `sort > NULL` which is always false in PostgreSQL) - Wrap RecordSRN.parse in SearchFeaturesHandler with try/except to return 422 instead of unhandled 500 - Use type_coerce for jsonb_build_object string keys so asyncpg can determine parameter types --- .../domain/discovery/query/search_features.py | 8 +- .../persistence/adapter/discovery.py | 64 ++---- .../persistence/adapter/feature_reader.py | 4 +- .../osa/infrastructure/persistence/keyset.py | 114 +++++++++++ .../domain/discovery/test_search_features.py | 10 + .../infrastructure/persistence/test_keyset.py | 192 ++++++++++++++++++ 6 files changed, 346 insertions(+), 46 deletions(-) create mode 100644 server/osa/infrastructure/persistence/keyset.py create mode 100644 server/tests/unit/infrastructure/persistence/test_keyset.py diff --git a/server/osa/domain/discovery/query/search_features.py 
b/server/osa/domain/discovery/query/search_features.py index 8401ff6..4019dcf 100644 --- a/server/osa/domain/discovery/query/search_features.py +++ b/server/osa/domain/discovery/query/search_features.py @@ -7,6 +7,7 @@ ) from osa.domain.discovery.service.discovery import DiscoveryService from osa.domain.shared.authorization.gate import public +from osa.domain.shared.error import ValidationError from osa.domain.shared.model.srn import RecordSRN from osa.domain.shared.query import Query, QueryHandler, Result @@ -32,7 +33,12 @@ class SearchFeaturesHandler(QueryHandler[SearchFeatures, SearchFeaturesResult]): discovery_service: DiscoveryService async def run(self, cmd: SearchFeatures) -> SearchFeaturesResult: - record_srn = RecordSRN.parse(cmd.record_srn) if cmd.record_srn else None + record_srn: RecordSRN | None = None + if cmd.record_srn: + try: + record_srn = RecordSRN.parse(cmd.record_srn) + except ValueError as exc: + raise ValidationError(str(exc), field="record_srn") from exc result: FeatureSearchResult = await self.discovery_service.search_features( hook_name=cmd.hook_name, filters=cmd.filters, diff --git a/server/osa/infrastructure/persistence/adapter/discovery.py b/server/osa/infrastructure/persistence/adapter/discovery.py index f3cf0ba..23af2bc 100644 --- a/server/osa/infrastructure/persistence/adapter/discovery.py +++ b/server/osa/infrastructure/persistence/adapter/discovery.py @@ -38,6 +38,7 @@ build_feature_table, data_columns, ) +from osa.infrastructure.persistence.keyset import KeysetPage, SortKey from osa.infrastructure.persistence.tables import ( feature_tables_table, records_table, @@ -128,30 +129,18 @@ async def search_records( else: sort_expr = t.c.metadata[sort].astext - # Sort direction - if order == SortOrder.ASC: - order_clauses = [sort_expr.asc().nullslast(), t.c.srn.asc()] - else: - order_clauses = [sort_expr.desc().nullslast(), t.c.srn.desc()] + # Keyset pagination with correct NULL handling + is_desc = order == SortOrder.DESC + page = 
KeysetPage( + [ + SortKey(sort_expr, descending=is_desc, nulls_last=True), + SortKey(t.c.srn, descending=is_desc), + ] + ) + order_clauses = page.order_by() - # Keyset cursor if cursor is not None: - cursor_sort = cursor["s"] - cursor_id = cursor["id"] - if order == SortOrder.ASC: - conditions.append( - or_( - sort_expr > cursor_sort, - and_(sort_expr == cursor_sort, t.c.srn > cursor_id), - ) - ) - else: - conditions.append( - or_( - sort_expr < cursor_sort, - and_(sort_expr == cursor_sort, t.c.srn < cursor_id), - ) - ) + conditions.append(page.after((cursor["s"], cursor["id"]))) where_clause = and_(*conditions) if conditions else true() @@ -282,29 +271,18 @@ async def search_features( else: sort_expr = ft.c[sort] - if order == SortOrder.ASC: - order_clauses = [sort_expr.asc(), ft.c.id.asc()] - else: - order_clauses = [sort_expr.desc(), ft.c.id.desc()] + # Keyset pagination with correct NULL handling + is_desc = order == SortOrder.DESC + page = KeysetPage( + [ + SortKey(sort_expr, descending=is_desc, nulls_last=True), + SortKey(ft.c.id, descending=is_desc), + ] + ) + order_clauses = page.order_by() - # Keyset cursor if cursor is not None: - cursor_sort = cursor["s"] - cursor_id = cursor["id"] - if order == SortOrder.ASC: - conditions.append( - or_( - sort_expr > cursor_sort, - and_(sort_expr == cursor_sort, ft.c.id > cursor_id), - ) - ) - else: - conditions.append( - or_( - sort_expr < cursor_sort, - and_(sort_expr == cursor_sort, ft.c.id < cursor_id), - ) - ) + conditions.append(page.after((cursor["s"], cursor["id"]))) where_clause = and_(*conditions) if conditions else true() diff --git a/server/osa/infrastructure/persistence/adapter/feature_reader.py b/server/osa/infrastructure/persistence/adapter/feature_reader.py index beeb07c..3145f3c 100644 --- a/server/osa/infrastructure/persistence/adapter/feature_reader.py +++ b/server/osa/infrastructure/persistence/adapter/feature_reader.py @@ -4,7 +4,7 @@ from typing import Any -from sqlalchemy import func, literal, 
select, union_all +from sqlalchemy import String, func, literal, select, type_coerce, union_all from sqlalchemy.ext.asyncio import AsyncSession from osa.domain.shared.model.srn import RecordSRN @@ -49,7 +49,7 @@ async def get_features_for_record( # Build jsonb_build_object('col1', col1, 'col2', col2, ...) jsonb_args: list[Any] = [] for col in dcols: - jsonb_args.extend([literal(col.key), col]) + jsonb_args.extend([type_coerce(literal(col.key), String), col]) row_data_expr = ( func.jsonb_build_object(*jsonb_args) if jsonb_args else func.jsonb_build_object() diff --git a/server/osa/infrastructure/persistence/keyset.py b/server/osa/infrastructure/persistence/keyset.py new file mode 100644 index 0000000..dc42392 --- /dev/null +++ b/server/osa/infrastructure/persistence/keyset.py @@ -0,0 +1,114 @@ +"""Keyset pagination helpers with correct NULL semantics. + +Derives both ORDER BY and WHERE predicate from a single sort specification +so that NULL handling is consistent between the two. + +Key insight for NULLS LAST ordering: +- Non-null cursor value: "strictly after" must include ``OR expr IS NULL`` + because NULLs sort after all non-null values. +- Null cursor value: only the tiebreaker applies + (``sort IS NULL AND id > cursor_id``), since nothing comes after the + NULL region except more NULLs distinguished by the tiebreaker. + +NULLS FIRST is the mirror image. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Sequence + +from sqlalchemy import ColumnElement, UnaryExpression, and_, false, or_ + + +@dataclass(frozen=True) +class SortKey: + """One column in a multi-column keyset sort.""" + + expression: ColumnElement[Any] + descending: bool = False + nulls_last: bool = True + + def order_clause(self) -> UnaryExpression[Any]: + clause = self.expression.desc() if self.descending else self.expression.asc() + return clause.nullslast() if self.nulls_last else clause.nullsfirst() + + +class KeysetPage: + """Build ORDER BY + WHERE predicate for keyset pagination. + + Usage:: + + page = KeysetPage([ + SortKey(sort_expr, descending=is_desc, nulls_last=True), + SortKey(t.c.id, descending=is_desc), + ]) + stmt = stmt.order_by(*page.order_by()) + if cursor: + stmt = stmt.where(page.after(cursor_values)) + """ + + def __init__(self, keys: Sequence[SortKey]) -> None: + self._keys = list(keys) + + def order_by(self) -> list[UnaryExpression[Any]]: + return [k.order_clause() for k in self._keys] + + def after(self, cursor_values: tuple[Any, ...]) -> ColumnElement[Any]: + """Build the WHERE predicate for "rows strictly after this cursor".""" + if len(cursor_values) != len(self._keys): + raise ValueError( + f"Cursor length {len(cursor_values)} does not match key length {len(self._keys)}" + ) + + # Build from right to left: for keys (k0, k1), the predicate is + # strictly_after(k0, v0) OR (eq(k0, v0) AND strictly_after(k1, v1)) + result: ColumnElement[Any] = false() + for i in range(len(self._keys) - 1, -1, -1): + after_i = _strictly_after(self._keys[i], cursor_values[i]) + if after_i is None: + # No rows can follow on this key alone (null + nulls_last) + # but tiebreaker may still apply via the eq branch below + eq_part = _null_eq(self._keys[i].expression, cursor_values[i]) + result = and_(eq_part, result) + elif i == len(self._keys) - 1: + result = after_i + else: + eq_part = 
_null_eq(self._keys[i].expression, cursor_values[i]) + result = or_(after_i, and_(eq_part, result)) + + return result + + +def _null_eq(expr: ColumnElement[Any], value: Any) -> ColumnElement[Any]: + """``IS NULL`` when value is None, else ``= value``.""" + if value is None: + return expr.is_(None) + return expr == value + + +def _strictly_after(key: SortKey, value: Any) -> ColumnElement[Any] | None: + """Rows that come strictly after *value* according to this key's ordering. + + Returns ``None`` when no rows can follow (null cursor + nulls_last). + """ + expr = key.expression + + if value is None: + # Cursor is at the NULL region + if key.nulls_last: + # NULLs are last → nothing comes after + return None + else: + # NULLs are first → everything non-null comes after + return expr.is_not(None) + + # Cursor is at a non-null value + gt = expr < value if key.descending else expr > value + + if key.nulls_last: + # NULLs come after all non-nulls → include them + return or_(gt, expr.is_(None)) + else: + # NULLs came before all non-nulls → they're already passed + return gt diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py index 55eb836..b989dca 100644 --- a/server/tests/unit/domain/discovery/test_search_features.py +++ b/server/tests/unit/domain/discovery/test_search_features.py @@ -71,6 +71,16 @@ async def test_delegates_to_service(self) -> None: await handler.run(SearchFeatures(hook_name="detect_pockets")) mock_service.search_features.assert_called_once() + async def test_invalid_record_srn_raises_validation_error(self) -> None: + mock_service = AsyncMock() + handler = SearchFeaturesHandler(discovery_service=mock_service) + + with pytest.raises(ValidationError, match="not an OSA SRN") as exc_info: + await handler.run(SearchFeatures(hook_name="detect_pockets", record_srn="not-a-srn")) + + assert exc_info.value.field == "record_srn" + mock_service.search_features.assert_not_called() + async 
def test_maps_rows_with_record_srn(self) -> None: srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") mock_service = AsyncMock() diff --git a/server/tests/unit/infrastructure/persistence/test_keyset.py b/server/tests/unit/infrastructure/persistence/test_keyset.py new file mode 100644 index 0000000..62e5a02 --- /dev/null +++ b/server/tests/unit/infrastructure/persistence/test_keyset.py @@ -0,0 +1,192 @@ +"""Tests for keyset pagination helpers — NULL × direction matrix.""" + +from __future__ import annotations + +import pytest +from typing import Any + +from sqlalchemy import column + +from osa.infrastructure.persistence.keyset import KeysetPage, SortKey + + +def _compile(clause: Any) -> str: + """Compile a SQLAlchemy clause element to a raw SQL string for assertions.""" + return str(clause.compile(compile_kwargs={"literal_binds": True})) + + +# --------------------------------------------------------------------------- +# SortKey.order_clause +# --------------------------------------------------------------------------- + + +class TestSortKeyOrderClause: + def test_asc_nulls_last(self) -> None: + key = SortKey(expression=column("score"), descending=False, nulls_last=True) + sql = _compile(key.order_clause()) + assert "ASC" in sql + assert "NULLS LAST" in sql + + def test_desc_nulls_last(self) -> None: + key = SortKey(expression=column("score"), descending=True, nulls_last=True) + sql = _compile(key.order_clause()) + assert "DESC" in sql + assert "NULLS LAST" in sql + + def test_asc_nulls_first(self) -> None: + key = SortKey(expression=column("score"), descending=False, nulls_last=False) + sql = _compile(key.order_clause()) + assert "ASC" in sql + assert "NULLS FIRST" in sql + + def test_desc_nulls_first(self) -> None: + key = SortKey(expression=column("score"), descending=True, nulls_last=False) + sql = _compile(key.order_clause()) + assert "DESC" in sql + assert "NULLS FIRST" in sql + + +# --------------------------------------------------------------------------- 
+# KeysetPage.after — non-null cursor sort value +# --------------------------------------------------------------------------- + + +class TestKeysetAfterNonNull: + """Cursor sort value is NOT None — standard keyset with NULL awareness.""" + + def test_asc_nulls_last(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=True), + SortKey(column("id"), descending=False), + ] + ) + sql = _compile(page.after((5, "abc"))) + # Must include OR score IS NULL (NULLs come after non-nulls) + assert "score > 5" in sql + assert "score IS NULL" in sql + assert "id > 'abc'" in sql + + def test_desc_nulls_last(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=True, nulls_last=True), + SortKey(column("id"), descending=True), + ] + ) + sql = _compile(page.after((5, "abc"))) + assert "score < 5" in sql + assert "score IS NULL" in sql + assert "id < 'abc'" in sql + + def test_asc_nulls_first(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=False), + SortKey(column("id"), descending=False, nulls_last=False), + ] + ) + sql = _compile(page.after((5, "abc"))) + assert "score > 5" in sql + # No score IS NULL — nulls already came before + assert "score IS NULL" not in sql + assert "id > 'abc'" in sql + + def test_desc_nulls_first(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=True, nulls_last=False), + SortKey(column("id"), descending=True, nulls_last=False), + ] + ) + sql = _compile(page.after((5, "abc"))) + assert "score < 5" in sql + assert "score IS NULL" not in sql + assert "id < 'abc'" in sql + + +# --------------------------------------------------------------------------- +# KeysetPage.after — null cursor sort value (the core bug) +# --------------------------------------------------------------------------- + + +class TestKeysetAfterNull: + """Cursor sort IS None — must avoid `sort > NULL` which is always false in SQL.""" + + def 
test_asc_nulls_last_null_cursor(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=True), + SortKey(column("id"), descending=False), + ] + ) + sql = _compile(page.after((None, "abc"))) + # Nothing after NULLs in NULLS LAST except by tiebreaker + assert "score IS NULL" in sql + assert "id > 'abc'" in sql + # Must NOT contain score > or score < + assert "score >" not in sql + assert "score <" not in sql + + def test_desc_nulls_last_null_cursor(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=True, nulls_last=True), + SortKey(column("id"), descending=True), + ] + ) + sql = _compile(page.after((None, "abc"))) + assert "score IS NULL" in sql + assert "id < 'abc'" in sql + assert "score >" not in sql + assert "score <" not in sql + + def test_asc_nulls_first_null_cursor(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=False), + SortKey(column("id"), descending=False), + ] + ) + sql = _compile(page.after((None, "abc"))) + # NULLs are first, so everything non-null comes after, plus tiebreaker among NULLs + assert "IS NOT NULL" in sql + assert "id > 'abc'" in sql + + def test_desc_nulls_first_null_cursor(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=True, nulls_last=False), + SortKey(column("id"), descending=True), + ] + ) + sql = _compile(page.after((None, "abc"))) + assert "IS NOT NULL" in sql + assert "id < 'abc'" in sql + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestKeysetEdgeCases: + def test_mismatched_cursor_length_raises(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=True), + SortKey(column("id"), descending=False), + ] + ) + with pytest.raises(ValueError, match="length"): + page.after((1, 2, 3)) + + def 
test_order_by_returns_all_keys(self) -> None: + page = KeysetPage( + [ + SortKey(column("score"), descending=False, nulls_last=True), + SortKey(column("id"), descending=False), + ] + ) + clauses = page.order_by() + assert len(clauses) == 2 From d6f36651b1dc939791be7d1a65d94a67ded64bd3 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 7 Mar 2026 04:00:30 +0000 Subject: [PATCH 6/6] fix: eliminate has_more false positive with N+1 pagination trick Service now fetches limit+1 rows from the adapter and uses the extra row as the "more" signal, then slices back to limit before returning. This prevents clients from making a wasted round-trip on exact-limit pages. --- .../osa/domain/discovery/service/discovery.py | 10 ++- .../discovery/test_discovery_service.py | 80 ++++++++++++++++++- .../domain/discovery/test_search_features.py | 65 +++++++++++++++ 3 files changed, 147 insertions(+), 8 deletions(-) diff --git a/server/osa/domain/discovery/service/discovery.py b/server/osa/domain/discovery/service/discovery.py index 0edd638..5d9bab0 100644 --- a/server/osa/domain/discovery/service/discovery.py +++ b/server/osa/domain/discovery/service/discovery.py @@ -94,11 +94,12 @@ async def search_records( sort=sort, order=order, cursor=decoded_cursor, - limit=limit, + limit=limit + 1, field_types=field_map, ) - has_more = len(results) == limit + has_more = len(results) > limit + results = results[:limit] next_cursor = None if has_more and results: last = results[-1] @@ -189,10 +190,11 @@ async def search_features( sort=sort, order=order, cursor=decoded_cursor, - limit=limit, + limit=limit + 1, ) - has_more = len(rows) == limit + has_more = len(rows) > limit + rows = rows[:limit] next_cursor = None if has_more and rows: last = rows[-1] diff --git a/server/tests/unit/domain/discovery/test_discovery_service.py b/server/tests/unit/domain/discovery/test_discovery_service.py index 5a2d6b8..c38a79d 100644 --- a/server/tests/unit/domain/discovery/test_discovery_service.py +++ 
b/server/tests/unit/domain/discovery/test_discovery_service.py @@ -161,7 +161,7 @@ async def test_delegates_to_read_store( assert len(call_kwargs.kwargs["filters"]) == 1 assert call_kwargs.kwargs["q"] is None assert call_kwargs.kwargs["sort"] == "published_at" - assert call_kwargs.kwargs["limit"] == 20 + assert call_kwargs.kwargs["limit"] == 21 # N+1 trick async def test_extracts_text_fields_for_q( self, service: DiscoveryService, mock_read_store: AsyncMock @@ -218,8 +218,9 @@ async def test_encodes_next_cursor_from_results( ) -> None: srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") ts = datetime(2026, 1, 1, tzinfo=UTC) + # Return limit+1 rows so the service detects has_more=True mock_read_store.search_records.return_value = [ - RecordSummary(srn=srn, published_at=ts, metadata={"title": "Test"}) + RecordSummary(srn=srn, published_at=ts, metadata={"title": f"r{i}"}) for i in range(2) ] result = await service.search_records( @@ -233,6 +234,7 @@ async def test_encodes_next_cursor_from_results( assert result.has_more is True assert result.cursor is not None + assert len(result.results) == 1 # trimmed back to limit from osa.domain.discovery.model.value import decode_cursor @@ -257,6 +259,70 @@ async def test_no_cursor_when_no_more_results( assert result.has_more is False +class TestSearchRecordsPagination: + async def test_has_more_false_when_exactly_limit_rows( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Exactly limit rows should NOT report has_more (no false positive).""" + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + ts = datetime(2026, 1, 1, tzinfo=UTC) + mock_read_store.search_records.return_value = [ + RecordSummary(srn=srn, published_at=ts, metadata={"title": f"r{i}"}) for i in range(3) + ] + + result = await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=3, + ) + + assert result.has_more is False + assert result.cursor is None + assert 
len(result.results) == 3 + + async def test_has_more_true_when_more_than_limit_rows( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Adapter returning limit+1 rows signals more pages exist.""" + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + ts = datetime(2026, 1, 1, tzinfo=UTC) + mock_read_store.search_records.return_value = [ + RecordSummary(srn=srn, published_at=ts, metadata={"title": f"r{i}"}) for i in range(4) + ] + + result = await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=3, + ) + + assert result.has_more is True + assert result.cursor is not None + assert len(result.results) == 3 # trimmed back to limit + + async def test_passes_limit_plus_one_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Service should fetch one extra row to detect more pages.""" + await service.search_records( + filters=[], + q=None, + sort="published_at", + order=SortOrder.DESC, + cursor=None, + limit=20, + ) + + call_kwargs = mock_read_store.search_records.call_args + assert call_kwargs.kwargs["limit"] == 21 + + class TestSearchRecordsFieldTypes: async def test_passes_field_types_to_read_store( self, service: DiscoveryService, mock_read_store: AsyncMock @@ -288,8 +354,10 @@ async def test_cursor_encodes_row_id( columns=[ColumnInfo(name="score", type="number", required=True)], record_count=0, ) + # Return limit+1 rows so the service detects has_more=True mock_read_store.search_features.return_value = [ - FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66}) + FeatureRow(row_id=42, record_srn=srn, data={"score": 7.66}), + FeatureRow(row_id=43, record_srn=srn, data={"score": 6.0}), ] service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) @@ -305,6 +373,7 @@ async def test_cursor_encodes_row_id( assert result.has_more is True assert result.cursor is not None + assert len(result.rows) == 1 
decoded = decode_cursor(result.cursor) assert decoded["id"] == 42 assert decoded["s"] == 7.66 @@ -320,8 +389,10 @@ async def test_cursor_uses_row_id_for_id_sort( columns=[ColumnInfo(name="score", type="number", required=True)], record_count=0, ) + # Return limit+1 rows so the service detects has_more=True mock_read_store.search_features.return_value = [ - FeatureRow(row_id=99, record_srn=srn, data={"score": 5.0}) + FeatureRow(row_id=99, record_srn=srn, data={"score": 5.0}), + FeatureRow(row_id=98, record_srn=srn, data={"score": 4.0}), ] service = DiscoveryService(read_store=mock_read_store, field_reader=mock_field_reader) @@ -337,6 +408,7 @@ async def test_cursor_uses_row_id_for_id_sort( assert result.has_more is True assert result.cursor is not None + assert len(result.rows) == 1 decoded = decode_cursor(result.cursor) # When sort is "id", sort_val should be the row_id itself assert decoded["s"] == 99 diff --git a/server/tests/unit/domain/discovery/test_search_features.py b/server/tests/unit/domain/discovery/test_search_features.py index b989dca..8c46b1f 100644 --- a/server/tests/unit/domain/discovery/test_search_features.py +++ b/server/tests/unit/domain/discovery/test_search_features.py @@ -216,3 +216,68 @@ async def test_delegates_to_read_store( call_kwargs = mock_read_store.search_features.call_args assert call_kwargs.kwargs["hook_name"] == "detect_pockets" assert len(call_kwargs.kwargs["filters"]) == 1 + + +class TestSearchFeaturesPagination: + async def test_has_more_false_when_exactly_limit_rows( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Exactly limit rows should NOT report has_more (no false positive).""" + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + mock_read_store.search_features.return_value = [ + FeatureRow(row_id=i, record_srn=srn, data={"score": float(i)}) for i in range(3) + ] + + result = await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="score", + 
order=SortOrder.DESC, + cursor=None, + limit=3, + ) + + assert result.has_more is False + assert result.cursor is None + assert len(result.rows) == 3 + + async def test_has_more_true_when_more_than_limit_rows( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Adapter returning limit+1 rows signals more pages exist.""" + srn = RecordSRN.parse("urn:osa:localhost:rec:abc@1") + mock_read_store.search_features.return_value = [ + FeatureRow(row_id=i, record_srn=srn, data={"score": float(i)}) for i in range(4) + ] + + result = await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="score", + order=SortOrder.DESC, + cursor=None, + limit=3, + ) + + assert result.has_more is True + assert result.cursor is not None + assert len(result.rows) == 3 + + async def test_passes_limit_plus_one_to_read_store( + self, service: DiscoveryService, mock_read_store: AsyncMock + ) -> None: + """Service should fetch one extra row to detect more pages.""" + await service.search_features( + hook_name="detect_pockets", + filters=[], + record_srn=None, + sort="score", + order=SortOrder.DESC, + cursor=None, + limit=50, + ) + + call_kwargs = mock_read_store.search_features.call_args + assert call_kwargs.kwargs["limit"] == 51