diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e40dab2..3637451 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: [main] pull_request: - branches: [main] + # Run on PRs to any branch (not just main) env: PYTHON_VERSION: "3.13" diff --git a/.gitignore b/.gitignore index ce19d52..55d654d 100644 --- a/.gitignore +++ b/.gitignore @@ -237,3 +237,4 @@ web/.turbo/ # Environment .env !.env.example +server/osa.yaml diff --git a/Justfile b/Justfile index 482c602..0add92f 100644 --- a/Justfile +++ b/Justfile @@ -74,6 +74,10 @@ dev-detached: dev-down: docker compose -f deploy/docker-compose.yml -f deploy/docker-compose.dev.yml down +# Open the web UI in browser +open-ui: + open http://localhost:8080 + # === Individual Service Development === # Run server independently (requires database) diff --git a/server/Dockerfile b/server/Dockerfile index 989d0eb..6b0d381 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -29,6 +29,9 @@ COPY osa.yaml ./config.yaml RUN --mount=type=cache,target=/root/.cache/uv \ uv sync --frozen --no-dev +# Pre-download embedding model to bake into image (avoids runtime download) +RUN /app/.venv/bin/python -c "from sentence_transformers import SentenceTransformer; SentenceTransformer('all-MiniLM-L6-v2')" + # Stage 2: Runtime FROM python:3.13-slim-bookworm AS runtime @@ -43,6 +46,9 @@ WORKDIR /app # Copy the virtual environment from builder COPY --from=builder --chown=appuser:appuser /app/.venv /app/.venv + +# Copy pre-downloaded embedding model cache (avoids runtime download) +COPY --from=builder --chown=appuser:appuser /root/.cache/huggingface /home/appuser/.cache/huggingface COPY --from=builder --chown=appuser:appuser /app/osa /app/osa COPY --from=builder --chown=appuser:appuser /app/sources /app/sources COPY --from=builder --chown=appuser:appuser /app/migrations /app/migrations @@ -65,7 +71,7 @@ USER appuser EXPOSE 8000 # Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ +HEALTHCHECK --interval=5s --timeout=10s --start-period=5s --retries=3 \ CMD curl --fail http://localhost:8000/api/v1/health || exit 1 ENTRYPOINT ["/app/entrypoint.sh"] diff --git a/server/migrations/versions/add_worker_columns.py b/server/migrations/versions/add_worker_columns.py new file mode 100644 index 0000000..d9c6c7b --- /dev/null +++ b/server/migrations/versions/add_worker_columns.py @@ -0,0 +1,72 @@ +"""add_worker_columns + +Add columns and indexes to events table for pull-based worker architecture. + +Revision ID: add_worker_columns +Revises: 0d9fbacf8e58 +Create Date: 2026-02-02 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "add_worker_columns" +down_revision: Union[str, Sequence[str], None] = "0d9fbacf8e58" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add worker columns to events table.""" + # Add new columns for pull-based claiming + op.add_column("events", sa.Column("routing_key", sa.String(255), nullable=True)) + op.add_column( + "events", sa.Column("retry_count", sa.Integer(), nullable=False, server_default="0") + ) + op.add_column("events", sa.Column("claimed_at", sa.DateTime(timezone=True), nullable=True)) + op.add_column( + "events", + sa.Column( + "updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now() + ), + ) + + # Create partial index for efficient claiming query + # Covers: status=pending/claimed, event_type, routing_key, created_at + op.create_index( + "idx_events_claim", + "events", + ["delivery_status", "event_type", "routing_key", "created_at"], + postgresql_where=sa.text("delivery_status IN ('pending', 'claimed')"), + ) + + # Create partial index for stale claim detection + op.create_index( + "idx_events_stale_claims", + "events", + ["claimed_at"], + postgresql_where=sa.text("delivery_status = 'claimed'"), + ) + + # Create partial index for failed event queries + op.create_index( + "idx_events_failed", + "events", + ["event_type", "created_at"], + postgresql_where=sa.text("delivery_status = 'failed'"), + ) + + +def downgrade() -> None: + """Remove worker columns from events table.""" + op.drop_index("idx_events_failed", table_name="events") + op.drop_index("idx_events_stale_claims", table_name="events") + op.drop_index("idx_events_claim", table_name="events") + op.drop_column("events", "updated_at") + op.drop_column("events", "claimed_at") + op.drop_column("events", "retry_count") + op.drop_column("events", "routing_key") diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index cd6af1d..f997a6f 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -9,12 +9,10 @@ from osa.application.api.v1.routes import events, health, records, search, stats, validation from osa.application.di import create_container from osa.config import Config, configure_logging -from osa.domain.index.service import IndexService from osa.domain.shared.error import OSAError -from osa.infrastructure.event.worker import BackgroundWorker +from osa.infrastructure.event.worker import WorkerPool from osa.infrastructure.source.discovery import validate_sources_at_startup from osa.util.di.fastapi import setup_dishka -from osa.util.di.scope import Scope logger = logging.getLogger(__name__) @@ -23,16 +21,11 @@ async def lifespan(app: FastAPI): container = app.state.dishka_container - # Run background worker (emits ServerStarted internally) - worker = await container.get(BackgroundWorker) - async with worker: - yield + # Run unified worker pool (pull-based event handlers + scheduled tasks) + worker_pool = await container.get(WorkerPool) - # Flush all index backends on shutdown to ensure buffered records are persisted - logger.info("Flushing index backends on shutdown...") - async with container(scope=Scope.UOW) as scope: - index_service = await scope.get(IndexService) - await index_service.flush_all() + async with worker_pool: + yield await container.close() diff --git a/server/osa/domain/curation/handler/__init__.py b/server/osa/domain/curation/handler/__init__.py new file mode 100644 index 0000000..0115200 --- /dev/null +++ b/server/osa/domain/curation/handler/__init__.py @@ -0,0 +1,5 @@ +"""Curation domain event handlers.""" + +from osa.domain.curation.handler.auto_approve_curation import AutoApproveCuration + +__all__ = ["AutoApproveCuration"] diff --git a/server/osa/domain/curation/listener/auto_approve_curation_tool.py b/server/osa/domain/curation/handler/auto_approve_curation.py similarity index 84% rename from server/osa/domain/curation/listener/auto_approve_curation_tool.py rename to server/osa/domain/curation/handler/auto_approve_curation.py index 1dc551e..12a6d99 100644 --- a/server/osa/domain/curation/listener/auto_approve_curation_tool.py +++ b/server/osa/domain/curation/handler/auto_approve_curation.py @@ -1,10 +1,10 @@ -"""AutoApproveCurationTool - auto-approves depositions on validation completion.""" +"""AutoApproveCuration - auto-approves depositions on validation completion.""" import logging from uuid import uuid4 from osa.domain.curation.event.deposition_approved import DepositionApproved -from osa.domain.shared.event import EventId, EventListener +from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.outbox import Outbox from osa.domain.validation.event.validation_completed import ValidationCompleted from osa.domain.validation.model import RunStatus @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -class AutoApproveCurationTool(EventListener[ValidationCompleted]): +class AutoApproveCuration(EventHandler[ValidationCompleted]): """Auto-approves validation and emits DepositionApproved. 0 curation = instant approve.""" outbox: Outbox @@ -40,6 +40,5 @@ async def handle(self, event: ValidationCompleted) -> None: ) await self.outbox.append(approved) - # Session commit handled by BackgroundWorker logger.debug(f"Deposition approved: {event.deposition_srn}") diff --git a/server/osa/domain/curation/listener/__init__.py b/server/osa/domain/curation/listener/__init__.py deleted file mode 100644 index 90e03c0..0000000 --- a/server/osa/domain/curation/listener/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Curation domain listeners.""" - -from osa.domain.curation.listener.auto_approve_curation_tool import ( - AutoApproveCurationTool, -) - -__all__ = ["AutoApproveCurationTool"] diff --git a/server/osa/domain/index/handler/__init__.py b/server/osa/domain/index/handler/__init__.py new file mode 100644 index 0000000..75cfab4 --- /dev/null +++ b/server/osa/domain/index/handler/__init__.py @@ -0,0 +1,7 @@ +"""Index domain event handlers.""" + +from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends +from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler +from osa.domain.index.handler.vector_index_handler import VectorIndexHandler + +__all__ = ["FanOutToIndexBackends", "KeywordIndexHandler", "VectorIndexHandler"] diff --git a/server/osa/domain/index/listener/fanout_listener.py b/server/osa/domain/index/handler/fanout_to_index_backends.py similarity index 75% rename from server/osa/domain/index/listener/fanout_listener.py rename to server/osa/domain/index/handler/fanout_to_index_backends.py index d1c66f4..cedbf92 100644 --- a/server/osa/domain/index/listener/fanout_listener.py +++ b/server/osa/domain/index/handler/fanout_to_index_backends.py @@ -1,28 +1,34 @@ """FanOutToIndexBackends - creates per-backend IndexRecord events from RecordPublished.""" import logging +from typing import ClassVar from uuid import uuid4 from osa.domain.index.event.index_record import IndexRecord from osa.domain.index.model.registry import IndexRegistry from osa.domain.record.event.record_published import RecordPublished -from osa.domain.shared.event import EventId, EventListener +from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.outbox import Outbox logger = logging.getLogger(__name__) -class FanOutToIndexBackends(EventListener[RecordPublished]): +class FanOutToIndexBackends(EventHandler[RecordPublished]): """Creates per-backend IndexRecord events from RecordPublished. - When a record is published, this listener creates one IndexRecord event + When a record is published, this handler creates one IndexRecord event per registered backend. Each IndexRecord is stored in the outbox, enabling independent retry and failure isolation per backend. This replaces the previous pattern where a single RecordPublished event triggered immediate indexing to all backends in a single transaction. + + Batch processing is used for efficiency when multiple records are + published in quick succession. """ + __batch_size__: ClassVar[int] = 10 + indexes: IndexRegistry outbox: Outbox @@ -38,4 +44,4 @@ async def handle(self, event: RecordPublished) -> None: record_srn=event.record_srn, metadata=event.metadata, ) - await self.outbox.append(index_event) + await self.outbox.append(index_event, routing_key=backend_name) diff --git a/server/osa/domain/index/handler/keyword_index_handler.py b/server/osa/domain/index/handler/keyword_index_handler.py new file mode 100644 index 0000000..488524d --- /dev/null +++ b/server/osa/domain/index/handler/keyword_index_handler.py @@ -0,0 +1,39 @@ +"""KeywordIndexHandler - processes IndexRecord events for keyword backends.""" + +import logging +from typing import ClassVar + +from osa.domain.index.event.index_record import IndexRecord +from osa.domain.index.model.registry import IndexRegistry +from osa.domain.shared.error import SkippedEvents +from osa.domain.shared.event import EventHandler + +logger = logging.getLogger(__name__) + + +class KeywordIndexHandler(EventHandler[IndexRecord]): + """Processes IndexRecord events for the keyword backend. + + Claims events with routing_key="keyword" and processes them immediately + (batch_size=1) since keyword indexing doesn't benefit from batching. + """ + + __routing_key__: ClassVar[str | None] = "keyword" + __batch_size__: ClassVar[int] = 1 + + indexes: IndexRegistry + + async def handle(self, event: IndexRecord) -> None: + """Process a single IndexRecord event.""" + backend = self.indexes.get("keyword") + if backend is None: + raise SkippedEvents( + event_ids=[event.id], + reason="Keyword backend not available", + ) + + record = (str(event.record_srn), event.metadata) + + await backend.ingest_batch([record]) + + logger.debug(f"KeywordIndexHandler: indexed event {event.id}") diff --git a/server/osa/domain/index/handler/vector_index_handler.py b/server/osa/domain/index/handler/vector_index_handler.py new file mode 100644 index 0000000..4034184 --- /dev/null +++ b/server/osa/domain/index/handler/vector_index_handler.py @@ -0,0 +1,49 @@ +"""VectorIndexHandler - processes IndexRecord events for vector backends.""" + +import logging +from typing import ClassVar + +from osa.domain.index.event.index_record import IndexRecord +from osa.domain.index.model.registry import IndexRegistry +from osa.domain.shared.error import SkippedEvents +from osa.domain.shared.event import EventHandler + +logger = logging.getLogger(__name__) + + +class VectorIndexHandler(EventHandler[IndexRecord]): + """Processes IndexRecord events for the vector backend. + + Claims events with routing_key="vector" and processes them in batches + for efficient embedding generation. + """ + + __routing_key__: ClassVar[str | None] = "vector" + __batch_size__: ClassVar[int] = 100 + __batch_timeout__: ClassVar[float] = 5.0 + + indexes: IndexRegistry + + async def handle_batch(self, events: list[IndexRecord]) -> None: + """Process a batch of IndexRecord events. + + Converts events to records and calls ingest_batch on the backend. + """ + if not events: + return + + backend = self.indexes.get("vector") + if backend is None: + raise SkippedEvents( + event_ids=[e.id for e in events], + reason="Vector backend not available", + ) + + # Prepare records for batch ingestion + records = [(str(e.record_srn), e.metadata) for e in events] + + logger.debug(f"VectorIndexHandler: ingesting {len(records)} records to backend") + + await backend.ingest_batch(records) + + logger.debug(f"VectorIndexHandler: ingested {len(events)} records") diff --git a/server/osa/domain/index/listener/__init__.py b/server/osa/domain/index/listener/__init__.py deleted file mode 100644 index 04f3ccc..0000000 --- a/server/osa/domain/index/listener/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -"""Index domain listeners.""" - -from osa.domain.index.listener.fanout_listener import FanOutToIndexBackends -from osa.domain.index.listener.index_batch_listener import IndexRecordBatch - -__all__ = [ - "FanOutToIndexBackends", - "IndexRecordBatch", -] diff --git a/server/osa/domain/index/listener/index_batch_listener.py b/server/osa/domain/index/listener/index_batch_listener.py deleted file mode 100644 index 7e414bc..0000000 --- a/server/osa/domain/index/listener/index_batch_listener.py +++ /dev/null @@ -1,92 +0,0 @@ -"""IndexRecordBatch - batch processes IndexRecord events per backend.""" - -import logging -from collections import defaultdict - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEventsError -from osa.domain.shared.event import BatchEventListener - -logger = logging.getLogger(__name__) - - -class IndexRecordBatch(BatchEventListener[IndexRecord]): - """Batch processes IndexRecord events by grouping per backend. - - The BackgroundWorker groups IndexRecord events and calls handle_batch() - with all events of this type. This listener further groups events by - backend_name and calls ingest_batch() on each backend. - - This enables: - - Efficient batch embedding generation - - Crash-safe processing (events remain in outbox until committed) - - Note: If a backend is not found (e.g., removed from config), raises - SkippedEventsError so those events are marked as skipped rather than failed. - """ - - indexes: IndexRegistry - - async def handle_batch(self, events: list[IndexRecord]) -> None: - """Process a batch of IndexRecord events grouped by backend. - - Args: - events: List of IndexRecord events to process - - Raises: - SkippedEventsError: If backend not found (events should be skipped) - RuntimeError: If backend fails to index (events will be retried) - """ - if not events: - return - - # Group events by backend - by_backend: dict[str, list[IndexRecord]] = defaultdict(list) - for event in events: - by_backend[event.backend_name].append(event) - - logger.debug( - f"IndexRecordBatch: grouping {len(events)} events for {len(by_backend)} backends" - ) - - # Process each backend's batch - for backend_name, backend_events in by_backend.items(): - backend = self.indexes.get(backend_name) - if backend is None: - # Backend not found (may have been removed from config) - # Raise SkippedEventsError so events are marked as skipped, not failed - record_srns = [str(e.record_srn) for e in backend_events] - reason = ( - f"Backend '{backend_name}' not found (may have been removed). " - f"Skipping {len(backend_events)} events. " - f"Records: {record_srns[:5]}{'...' if len(record_srns) > 5 else ''}" - ) - logger.error(reason) - raise SkippedEventsError( - event_ids=[e.id for e in backend_events], - reason=reason, - ) - - # Prepare records for batch ingestion - records = [(str(event.record_srn), event.metadata) for event in backend_events] - - logger.debug( - f"Batch indexing {len(records)} records to backend '{backend_name}' " - f"(batch efficiency: {len(records)} records in single call)" - ) - - try: - await backend.ingest_batch(records) - logger.debug(f"Indexed {len(records)} records to backend '{backend_name}'") - except Exception as e: - # Enhanced error with backend name and record SRNs (T025, T026) - record_srns = [srn for srn, _ in records] - error_context = ( - f"Backend '{backend_name}' failed to index {len(records)} records. " - f"Records: {record_srns[:3]}{'...' if len(record_srns) > 3 else ''}. " - f"Error: {e}" - ) - logger.error(error_context) - # Re-raise with context so worker can record in delivery_error - raise RuntimeError(error_context) from e diff --git a/server/osa/domain/record/handler/__init__.py b/server/osa/domain/record/handler/__init__.py new file mode 100644 index 0000000..b5b23b2 --- /dev/null +++ b/server/osa/domain/record/handler/__init__.py @@ -0,0 +1,5 @@ +"""Record domain event handlers.""" + +from osa.domain.record.handler.convert_deposition_to_record import ConvertDepositionToRecord + +__all__ = ["ConvertDepositionToRecord"] diff --git a/server/osa/domain/record/listener/record_creation_listener.py b/server/osa/domain/record/handler/convert_deposition_to_record.py similarity index 65% rename from server/osa/domain/record/listener/record_creation_listener.py rename to server/osa/domain/record/handler/convert_deposition_to_record.py index 8d4a283..1beeadf 100644 --- a/server/osa/domain/record/listener/record_creation_listener.py +++ b/server/osa/domain/record/handler/convert_deposition_to_record.py @@ -1,14 +1,14 @@ -"""RecordCreationListener - creates records when depositions are approved.""" +"""ConvertDepositionToRecord - creates records when depositions are approved.""" from osa.domain.curation.event.deposition_approved import DepositionApproved from osa.domain.record.service import RecordService -from osa.domain.shared.event import EventListener +from osa.domain.shared.event import EventHandler -class ConvertDepositionToRecord(EventListener[DepositionApproved]): +class ConvertDepositionToRecord(EventHandler[DepositionApproved]): """Creates and persists records when depositions are approved. - This listener delegates to RecordService for all business logic. + This handler delegates to RecordService for all business logic. """ service: RecordService diff --git a/server/osa/domain/record/listener/__init__.py b/server/osa/domain/record/listener/__init__.py deleted file mode 100644 index fb025b3..0000000 --- a/server/osa/domain/record/listener/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Record domain listeners.""" - -from osa.domain.record.listener.record_creation_listener import ( - ConvertDepositionToRecord, -) - -__all__ = ["ConvertDepositionToRecord"] diff --git a/server/osa/domain/shared/error.py b/server/osa/domain/shared/error.py index 8cbb887..ab50e83 100644 --- a/server/osa/domain/shared/error.py +++ b/server/osa/domain/shared/error.py @@ -77,11 +77,10 @@ class ConfigurationError(InfrastructureError): # ============================================================================= -class SkippedEventsError(Exception): - """Raised when some events should be skipped (not failed, not delivered). +class SkippedEvents(Exception): + """Raised when events should be skipped (not failed, not delivered). - Used when a backend is intentionally removed and leftover events should - be marked as skipped for clean semantic separation from failures. + Control flow exception for when a backend is unavailable or removed. """ def __init__(self, event_ids: list, reason: str) -> None: diff --git a/server/osa/domain/shared/event.py b/server/osa/domain/shared/event.py index e106324..914f2e2 100644 --- a/server/osa/domain/shared/event.py +++ b/server/osa/domain/shared/event.py @@ -1,12 +1,14 @@ -"""Domain events, event listeners, and scheduled tasks.""" +"""Domain events, event handlers, scheduled tasks, and worker infrastructure.""" from abc import ABC, ABCMeta, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import UTC, datetime +from enum import Enum from typing import ( Any, ClassVar, Generic, + Iterator, NewType, TypeVar, dataclass_transform, @@ -15,7 +17,7 @@ ) from uuid import UUID -from pydantic import Field +from pydantic import BaseModel, Field, field_validator, model_validator from osa.domain.shared.model.entity import Entity @@ -46,15 +48,114 @@ def __init_subclass__(cls, **kwargs: Any) -> None: cls._registry[cls.__name__] = cls -# --- EventListener (Subscription) --- +# --- Worker Infrastructure --- + + +class WorkerConfig(BaseModel): + """Configuration for a single worker instance. + + Attributes: + name: Unique worker identifier. + event_types: Event types to claim. + routing_key: Optional routing key filter. + batch_size: Max events per batch (default: 1). + batch_timeout: Max seconds to wait for batch (default: 5.0). + poll_interval: Seconds between polls when idle (default: 0.5). + max_retries: Max retry attempts before marking failed (default: 3). + claim_timeout: Seconds before claim considered stale (default: 300.0). + """ + + model_config = {"frozen": True} + + name: str + event_types: tuple[type["Event"], ...] + routing_key: str | None = None + batch_size: int = Field(default=1, ge=1) + batch_timeout: float = Field(default=5.0, gt=0) + poll_interval: float = Field(default=0.5, gt=0) + max_retries: int = Field(default=3, ge=0) + claim_timeout: float = Field(default=300.0, gt=0) + + @field_validator("event_types") + @classmethod + def event_types_not_empty(cls, v: tuple) -> tuple: + if not v: + raise ValueError("event_types must not be empty") + return v + + @model_validator(mode="after") + def claim_timeout_greater_than_batch_timeout(self) -> "WorkerConfig": + if self.claim_timeout <= self.batch_timeout: + raise ValueError("claim_timeout must be > batch_timeout") + return self + + +class WorkerStatus(Enum): + """Status of a running worker.""" + + IDLE = "idle" + CLAIMING = "claiming" + PROCESSING = "processing" + STOPPING = "stopping" + + +@dataclass +class WorkerState: + """Runtime state for a running worker (not persisted). + + Attributes: + config: Worker configuration. + status: Current worker status. + current_batch: Events currently being processed. + last_claim_at: When last claim was made. + processed_count: Total events processed. + failed_count: Total events failed. + error: Last error if any. + """ + + config: WorkerConfig + status: WorkerStatus = WorkerStatus.IDLE + current_batch: list["Event"] = field(default_factory=list) + last_claim_at: datetime | None = None + processed_count: int = 0 + failed_count: int = 0 + error: Exception | None = None + + +@dataclass(frozen=True) +class ClaimResult: + """Result of a claim operation. + + Attributes: + events: Claimed events (locked). + claimed_at: Timestamp of claim. + """ + + events: list["Event"] + claimed_at: datetime + + def __bool__(self) -> bool: + """Return True if events are present.""" + return len(self.events) > 0 + + def __len__(self) -> int: + """Return number of events.""" + return len(self.events) + + def __iter__(self) -> Iterator["Event"]: + """Iterate over events.""" + return iter(self.events) + + +# --- EventHandler --- def _extract_event_type(cls: type) -> type["Event"] | None: - """Extract the event type E from EventListener[E] or BatchEventListener[E] in class bases.""" + """Extract the event type E from EventHandler[E] in class bases.""" for base in getattr(cls, "__orig_bases__", []): origin = get_origin(base) origin_name = getattr(origin, "__name__", None) - if origin is not None and origin_name in ("EventListener", "BatchEventListener"): + if origin is not None and origin_name == "EventHandler": args = get_args(base) if args and isinstance(args[0], type) and issubclass(args[0], Event): return args[0] @@ -62,8 +163,8 @@ def _extract_event_type(cls: type) -> type["Event"] | None: @dataclass_transform() -class _EventListenerMeta(ABCMeta): - """Metaclass that applies @dataclass and extracts __event_type__ from EventListener[E].""" +class _EventHandlerMeta(ABCMeta): + """Metaclass that applies @dataclass and extracts __event_type__ from EventHandler[E].""" def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) -> type: cls = super().__new__(mcs, name, bases, namespace) @@ -76,65 +177,78 @@ def __new__(mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]) return cls -class EventListener(Generic[E], metaclass=_EventListenerMeta): - """Base class for event listeners (subscriptions). +class EventHandler(Generic[E], metaclass=_EventHandlerMeta): + """Base class for pull-based event handlers. - Subclasses are automatically dataclasses and have __event_type__ set - based on their generic parameter. - - Example: - class SourceListener(EventListener[SourceRequested]): - outbox: Outbox - config: Config - - async def handle(self, event: SourceRequested) -> None: - ... - - # SourceListener.__event_type__ == SourceRequested - """ - - __event_type__: ClassVar[type[Event]] + EventHandler replaces both EventListener and BatchEventListener with a unified pattern. + Workers claim events from the outbox and delegate to handlers for processing. - @abstractmethod - async def handle(self, event: Any) -> None: - """Handle the event. Subclasses should type event as their specific event type.""" - ... + Subclasses are automatically dataclasses with DI-injected dependencies. + The __event_type__ is extracted from the generic parameter. + Configuration is via class variables: + __routing_key__: Optional filter for routing key (default: None) + __batch_size__: Max events to claim at once (default: 1) + __batch_timeout__: Timeout for partial batches in seconds (default: 5.0) + __poll_interval__: Seconds between polls when idle (default: 0.5) + __max_retries__: Max retry attempts before marking failed (default: 3) + __claim_timeout__: Seconds before claim considered stale (default: 300.0) -class BatchEventListener(Generic[E], metaclass=_EventListenerMeta): - """Base class for event listeners that process events in batches. + Example (single event): + class TriggerInitialSourceRun(EventHandler[ServerStarted]): + _config: Config + _outbox: Outbox - The BackgroundWorker detects batch listeners and groups events - by type before calling handle_batch(). This enables efficient - batch operations (e.g., batch embedding generation). + async def handle(self, event: ServerStarted) -> None: + for source in self._config.sources: + if source.initial_run and source.initial_run.enabled: + await self._outbox.append(SourceRequested(...)) - Events in a batch are all of the same type and should be processed - atomically - all succeed or all fail together. + Example (batch processing): + class VectorIndexHandler(EventHandler[IndexRecord]): + __routing_key__ = "vector" + __batch_size__ = 100 + __batch_timeout__ = 5.0 - Example: - class IndexRecordBatch(BatchEventListener[IndexRecord]): - indexes: IndexRegistry + _backend: VectorStorageBackend async def handle_batch(self, events: list[IndexRecord]) -> None: - # Group by backend and call ingest_batch - ... - - # IndexRecordBatch.__event_type__ == IndexRecord + records = [(str(e.record_srn), e.metadata) for e in events] + await self._backend.ingest_batch(records) """ __event_type__: ClassVar[type[Event]] + __routing_key__: ClassVar[str | None] = None + __batch_size__: ClassVar[int] = 1 + __batch_timeout__: ClassVar[float] = 5.0 + __poll_interval__: ClassVar[float] = 0.5 + __max_retries__: ClassVar[int] = 3 + __claim_timeout__: ClassVar[float] = 300.0 - @abstractmethod - async def handle_batch(self, events: list[Any]) -> None: - """Process a batch of events. + async def handle(self, event: E) -> None: + """Handle a single event. Override for single-event processing. Args: - events: List of events to process (all same type) + event: The event to handle. Raises: - Exception: If batch processing fails (all events will be retried) + NotImplementedError: If neither handle() nor handle_batch() is overridden. """ - ... + raise NotImplementedError( + f"{type(self).__name__} must implement handle() or handle_batch()" + ) + + async def handle_batch(self, events: list[E]) -> None: + """Handle a batch of events. Override for batch processing. + + Default implementation loops over handle() for each event. + Override for more efficient batch operations. + + Args: + events: List of events to handle (all same type). + """ + for event in events: + await self.handle(event) # --- Schedule --- diff --git a/server/osa/domain/shared/outbox.py b/server/osa/domain/shared/outbox.py index 7a74ccc..545b7c2 100644 --- a/server/osa/domain/shared/outbox.py +++ b/server/osa/domain/shared/outbox.py @@ -2,7 +2,7 @@ from typing import TypeVar -from osa.domain.shared.event import Event, EventId +from osa.domain.shared.event import ClaimResult, Event, EventId from osa.domain.shared.port.event_repository import EventRepository from osa.domain.shared.service import Service @@ -13,15 +13,20 @@ class Outbox(Service): """Domain service for reliable event delivery via the transactional outbox pattern. Wraps EventRepository with delivery semantics. Business code uses this - to append events and query event history. The BackgroundWorker uses this - to fetch pending events and mark them as delivered/failed. + to append events and query event history. Workers use this to claim + events for processing and mark them as delivered/failed. """ _repo: EventRepository - async def append(self, event: Event) -> None: - """Add an event to the outbox for delivery.""" - await self._repo.save(event, status="pending") + async def append(self, event: Event, routing_key: str | None = None) -> None: + """Add an event to the outbox for delivery. + + Args: + event: The event to append. + routing_key: Optional routing key for worker filtering. + """ + await self._repo.save(event, status="pending", routing_key=routing_key) async def fetch_pending(self, limit: int = 100, fair: bool = True) -> list[Event]: """Fetch events awaiting delivery. @@ -47,3 +52,60 @@ async def mark_skipped(self, event_id: EventId, reason: str) -> None: async def find_latest(self, event_type: type[E]) -> E | None: """Find the most recent event of a given type.""" return await self._repo.find_latest_by_type(event_type) + + async def claim( + self, + event_types: list[type[Event]], + limit: int, + routing_key: str | None = None, + ) -> ClaimResult: + """Claim pending events for processing. + + Uses FOR UPDATE SKIP LOCKED to ensure concurrent workers claim + different events without blocking. + + Args: + event_types: Event classes to claim. + limit: Maximum number of events to claim. + routing_key: Optional routing key filter. + + Returns: + ClaimResult containing claimed events and timestamp. + """ + event_type_names = [et.__name__ for et in event_types] + return await self._repo.claim( + event_types=event_type_names, + limit=limit, + routing_key=routing_key, + ) + + async def mark_failed_with_retry( + self, + event_id: EventId, + error: str, + max_retries: int, + ) -> None: + """Mark an event as failed, with retry logic. + + If retry_count < max_retries, resets status to pending for retry. + If retry_count >= max_retries, sets status to failed permanently. + + Args: + event_id: The event ID. + error: Error message. + max_retries: Maximum retry attempts before marking as failed. + """ + await self._repo.mark_failed_with_retry(event_id, error=error, max_retries=max_retries) + + async def reset_stale_claims(self, timeout_seconds: float) -> int: + """Reset events that have been claimed for too long. + + Called periodically to recover from crashed workers. + + Args: + timeout_seconds: Consider claims older than this as stale. + + Returns: + Number of events reset. + """ + return await self._repo.reset_stale_claims(timeout_seconds) diff --git a/server/osa/domain/shared/port/event_repository.py b/server/osa/domain/shared/port/event_repository.py index 74e648c..2989db0 100644 --- a/server/osa/domain/shared/port/event_repository.py +++ b/server/osa/domain/shared/port/event_repository.py @@ -2,7 +2,7 @@ from typing import Protocol, TypeVar -from osa.domain.shared.event import Event, EventId +from osa.domain.shared.event import ClaimResult, Event, EventId E = TypeVar("E", bound=Event) @@ -13,8 +13,10 @@ class EventRepository(Protocol): Delivery semantics (pending/delivered/failed) are handled by the Outbox service. """ - async def save(self, event: Event, status: str = "pending") -> None: - """Persist an event with initial status.""" + async def save( + self, event: Event, status: str = "pending", routing_key: str | None = None + ) -> None: + """Persist an event with initial status and optional routing key.""" ... async def get(self, event_id: EventId) -> Event | None: @@ -68,3 +70,60 @@ async def list_events( async def count(self, event_types: list[str] | None = None) -> int: """Count events, optionally filtered by types.""" ... + + async def claim( + self, + event_types: list[str], + limit: int, + routing_key: str | None = None, + ) -> ClaimResult: + """Claim pending events for processing using FOR UPDATE SKIP LOCKED. + + This atomically: + 1. Selects pending events matching event_types and routing_key + 2. Locks them with FOR UPDATE SKIP LOCKED (concurrent workers skip) + 3. Sets status to 'claimed' and claimed_at to current timestamp + + Args: + event_types: Event type names to claim (class names). + limit: Maximum number of events to claim. + routing_key: Optional routing key filter. If None, claims unrouted events only. + + Returns: + ClaimResult containing claimed events and timestamp. + """ + ... + + async def reset_stale_claims(self, timeout_seconds: float) -> int: + """Reset events that have been claimed for longer than timeout. + + Sets status back to 'pending' for events where: + - status = 'claimed' + - claimed_at < now() - timeout_seconds + + Args: + timeout_seconds: Consider claims older than this as stale. + + Returns: + Number of events reset. + """ + ... + + async def mark_failed_with_retry( + self, + event_id: "EventId", + error: str, + max_retries: int, + ) -> None: + """Mark an event as failed with retry logic. + + If retry_count < max_retries, increments retry_count and resets + status to 'pending' for retry. + If retry_count >= max_retries, sets status to 'failed' permanently. + + Args: + event_id: The event ID. + error: Error message to record. + max_retries: Maximum retry attempts before marking as failed. + """ + ... diff --git a/server/osa/domain/source/handler/__init__.py b/server/osa/domain/source/handler/__init__.py new file mode 100644 index 0000000..3cc4466 --- /dev/null +++ b/server/osa/domain/source/handler/__init__.py @@ -0,0 +1,6 @@ +"""Source domain event handlers.""" + +from osa.domain.source.handler.pull_from_source import PullFromSource +from osa.domain.source.handler.trigger_initial_source_run import TriggerInitialSourceRun + +__all__ = ["PullFromSource", "TriggerInitialSourceRun"] diff --git a/server/osa/domain/source/listener/source_listener.py b/server/osa/domain/source/handler/pull_from_source.py similarity index 75% rename from server/osa/domain/source/listener/source_listener.py rename to server/osa/domain/source/handler/pull_from_source.py index 42369b5..351cca6 100644 --- a/server/osa/domain/source/listener/source_listener.py +++ b/server/osa/domain/source/handler/pull_from_source.py @@ -1,14 +1,14 @@ -"""SourceListener - handles SourceRequested events.""" +"""PullFromSource - handles SourceRequested events.""" -from osa.domain.shared.event import EventListener +from osa.domain.shared.event import EventHandler from osa.domain.source.event.source_requested import SourceRequested from osa.domain.source.service import SourceService -class PullFromSource(EventListener[SourceRequested]): +class PullFromSource(EventHandler[SourceRequested]): """Pulls from a data source and creates depositions. - This listener delegates to SourceService for all business logic. + This handler delegates to SourceService for all business logic. Supports chunked processing with continuation events. """ diff --git a/server/osa/domain/source/listener/initial_source_listener.py b/server/osa/domain/source/handler/trigger_initial_source_run.py similarity index 89% rename from server/osa/domain/source/listener/initial_source_listener.py rename to server/osa/domain/source/handler/trigger_initial_source_run.py index 14fd555..321a93b 100644 --- a/server/osa/domain/source/listener/initial_source_listener.py +++ b/server/osa/domain/source/handler/trigger_initial_source_run.py @@ -1,11 +1,11 @@ -"""InitialSourceListener - triggers source pull on server startup if configured.""" +"""TriggerInitialSourceRun - triggers source pull on server startup if configured.""" import logging from uuid import uuid4 from osa.application.event import ServerStarted from osa.config import Config -from osa.domain.shared.event import EventId, EventListener +from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.outbox import Outbox from osa.domain.source.event.source_requested import SourceRequested from osa.domain.source.event.source_run_completed import SourceRunCompleted @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class TriggerInitialSourceRun(EventListener[ServerStarted]): +class TriggerInitialSourceRun(EventHandler[ServerStarted]): """Emits SourceRequested on server startup for sources with initial_run enabled.""" config: Config @@ -66,4 +66,3 @@ async def handle(self, event: ServerStarted) -> None: limit=limit, ) ) - # Session commit handled by BackgroundWorker diff --git a/server/osa/domain/source/listener/__init__.py b/server/osa/domain/source/listener/__init__.py deleted file mode 100644 index 7435bc6..0000000 --- a/server/osa/domain/source/listener/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Source domain listeners.""" - -from osa.domain.source.listener.initial_source_listener import TriggerInitialSourceRun -from osa.domain.source.listener.source_listener import PullFromSource - -__all__ = ["PullFromSource", "TriggerInitialSourceRun"] diff --git a/server/osa/domain/validation/handler.py b/server/osa/domain/validation/handler.py deleted file mode 100644 index 6717fc1..0000000 --- a/server/osa/domain/validation/handler.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -from uuid import uuid4 - -import logfire - -from osa.domain.deposition.event.submitted import DepositionSubmittedEvent -from osa.domain.shared.event import EventId, EventListener -from osa.domain.shared.model.srn import Domain, LocalId, ValidationRunSRN -from osa.domain.shared.outbox import Outbox -from osa.domain.validation.event.validation_completed import ValidationCompleted -from osa.domain.validation.model import RunStatus - - -class BeginMockValidation(EventListener[DepositionSubmittedEvent]): - """Stub handler that simulates validation. Replace with real ValidationService wiring.""" - - outbox: Outbox - - async def handle(self, event: DepositionSubmittedEvent) -> None: - with logfire.span("ValidationHandler"): - logfire.info( - "Received DepositionSubmitted, starting validation simulation", - deposition_id=str(event.deposition_id), - ) - - # Simulate async work - await asyncio.sleep(1) - - # Create a mock validation run SRN - validation_run_srn = ValidationRunSRN( - domain=Domain("localhost"), - id=LocalId("mock-validation-run"), - version=None, - ) - - # Emit ValidationCompleted via outbox - completed_event = ValidationCompleted( - id=EventId(uuid4()), - validation_run_srn=validation_run_srn, - deposition_srn=event.deposition_id, - status=RunStatus.COMPLETED, - results=[], - metadata=event.metadata, # Pass through the original metadata - ) - await self.outbox.append(completed_event) - logfire.info("Validation completed event saved to outbox") diff --git a/server/osa/domain/validation/handler/__init__.py b/server/osa/domain/validation/handler/__init__.py new file mode 100644 index 0000000..e19ffb1 --- /dev/null +++ b/server/osa/domain/validation/handler/__init__.py @@ -0,0 +1,5 @@ +"""Validation domain event handlers.""" + +from osa.domain.validation.handler.validate_deposition import ValidateDeposition + +__all__ = ["ValidateDeposition"] diff --git a/server/osa/domain/validation/listener/validation_listener.py b/server/osa/domain/validation/handler/validate_deposition.py similarity index 88% rename from server/osa/domain/validation/listener/validation_listener.py rename to server/osa/domain/validation/handler/validate_deposition.py index 0006ae2..5fe172c 100644 --- a/server/osa/domain/validation/listener/validation_listener.py +++ b/server/osa/domain/validation/handler/validate_deposition.py @@ -1,11 +1,11 @@ -"""ValidationListener - handles DepositionSubmitted events.""" +"""ValidateDeposition - handles DepositionSubmitted events.""" import logging from uuid import uuid4 from osa.config import Config from osa.domain.deposition.event.submitted import DepositionSubmittedEvent -from osa.domain.shared.event import EventId, EventListener +from osa.domain.shared.event import EventHandler, EventId from osa.domain.shared.model.srn import Domain, LocalId, ValidationRunSRN from osa.domain.shared.outbox import Outbox from osa.domain.validation.event.validation_completed import ValidationCompleted @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) -class ValidateNewDeposition(EventListener[DepositionSubmittedEvent]): +class ValidateDeposition(EventHandler[DepositionSubmittedEvent]): """Runs validation on depositions. 0 validators = instant pass.""" outbox: Outbox @@ -55,6 +55,5 @@ async def handle(self, event: DepositionSubmittedEvent) -> None: ) await self.outbox.append(completed) - # Session commit handled by BackgroundWorker logger.debug(f"Validation completed for: {event.deposition_id}") diff --git a/server/osa/domain/validation/listener/__init__.py b/server/osa/domain/validation/listener/__init__.py deleted file mode 100644 index c2aeffb..0000000 --- a/server/osa/domain/validation/listener/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Validation domain listeners.""" - -from osa.domain.validation.listener.validation_listener import ValidateNewDeposition - -__all__ = ["ValidateNewDeposition"] diff --git a/server/osa/domain/validation/util/di/provider.py b/server/osa/domain/validation/util/di/provider.py index 7d3209b..ab6cadb 100644 --- a/server/osa/domain/validation/util/di/provider.py +++ b/server/osa/domain/validation/util/di/provider.py @@ -1,15 +1,13 @@ -from osa.util.di.scope import Scope from dishka import provide from osa.config import Config from osa.domain.shared.model.srn import Domain -from osa.domain.validation.handler import BeginMockValidation from osa.domain.validation.service import ValidationService from osa.util.di.base import Provider +from osa.util.di.scope import Scope class ValidationProvider(Provider): - validation_handler = provide(BeginMockValidation, scope=Scope.UOW) service = provide(ValidationService, scope=Scope.UOW) @provide(scope=Scope.UOW) diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index 15010ae..6cdc045 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -1,24 +1,25 @@ """Dependency injection provider for event system.""" import logging +from typing import Any, NewType from dishka import AsyncContainer, provide from osa.config import Config -from osa.domain.curation.listener import AutoApproveCurationTool -from osa.domain.index.listener import FanOutToIndexBackends, IndexRecordBatch -from osa.domain.source.listener import PullFromSource, TriggerInitialSourceRun -from osa.domain.source.schedule import SourceSchedule -from osa.domain.record.listener import ConvertDepositionToRecord +from osa.domain.curation.handler import AutoApproveCuration +from osa.domain.index.handler import FanOutToIndexBackends, KeywordIndexHandler, VectorIndexHandler +from osa.domain.record.handler import ConvertDepositionToRecord +from osa.domain.shared.event import EventHandler from osa.domain.shared.event_log import EventLog from osa.domain.shared.outbox import Outbox from osa.domain.shared.port.event_repository import EventRepository -from osa.domain.validation.listener import ValidateNewDeposition +from osa.domain.source.handler import PullFromSource, TriggerInitialSourceRun +from osa.domain.source.schedule import SourceSchedule +from osa.domain.validation.handler import ValidateDeposition from osa.infrastructure.event.worker import ( - BackgroundWorker, ScheduleConfig, ScheduleConfigs, - Subscriptions, + WorkerPool, ) from osa.util.di.base import Provider from osa.util.di.scope import Scope @@ -26,18 +27,25 @@ logger = logging.getLogger(__name__) -# All event listeners - single source of truth -# Includes both EventListener (single event) and BatchEventListener (batch of events) -LISTENER_TYPES: Subscriptions = Subscriptions( +# Type alias for handler list +HandlerTypes = NewType("HandlerTypes", list[type[EventHandler[Any]]]) + +# All event handlers for WorkerPool registration +HANDLERS: HandlerTypes = HandlerTypes( [ + # Source handlers TriggerInitialSourceRun, PullFromSource, - ValidateNewDeposition, - AutoApproveCurationTool, + # Validation handlers + ValidateDeposition, + # Curation handlers + AutoApproveCuration, + # Record handlers ConvertDepositionToRecord, - # Index listeners: FanOut creates IndexRecord events, IndexRecordBatch processes them + # Index handlers FanOutToIndexBackends, - IndexRecordBatch, + VectorIndexHandler, + KeywordIndexHandler, ] ) @@ -45,8 +53,8 @@ class EventProvider(Provider): """Provides event system components. - Listeners, Schedules, and Outbox are UOW-scoped (fresh per unit of work). - BackgroundWorker is APP-scoped singleton. + Handlers, Schedules, and Outbox are UOW-scoped (fresh per unit of work). + WorkerPool is APP-scoped singleton. """ # UOW-scoped Outbox (wraps EventRepository) - write side @@ -59,17 +67,17 @@ def get_outbox(self, repo: EventRepository) -> Outbox: def get_event_log(self, repo: EventRepository) -> EventLog: return EventLog(repo) - # UOW-scoped providers for listeners - for _listener_type in LISTENER_TYPES: - locals()[_listener_type.__name__] = provide(_listener_type, scope=Scope.UOW) + # UOW-scoped providers for handlers + for _handler_type in HANDLERS: + locals()[_handler_type.__name__] = provide(_handler_type, scope=Scope.UOW) # UOW-scoped provider for SourceSchedule source_schedule = provide(SourceSchedule, scope=Scope.UOW) @provide(scope=Scope.APP) - def get_subscriptions(self) -> Subscriptions: - """Return the listener types for BackgroundWorker registration.""" - return LISTENER_TYPES + def get_handler_types(self) -> HandlerTypes: + """Return the handler types for WorkerPool registration.""" + return HANDLERS @provide(scope=Scope.APP) def get_schedule_configs(self, config: Config) -> ScheduleConfigs: @@ -95,18 +103,22 @@ def get_schedule_configs(self, config: Config) -> ScheduleConfigs: return ScheduleConfigs(configs) @provide(scope=Scope.APP) - def get_background_worker( + def get_worker_pool( self, container: AsyncContainer, - subscriptions: Subscriptions, + handler_types: HandlerTypes, schedules: ScheduleConfigs, - config: Config, - ) -> BackgroundWorker: - """BackgroundWorker that polls outbox and runs scheduled tasks.""" - return BackgroundWorker( - container, - subscriptions, - schedules, - poll_interval=config.worker.poll_interval, - batch_size=config.worker.batch_size, - ) + ) -> WorkerPool: + """WorkerPool with pull-based event handlers. + + Registers all handlers and scheduled tasks. + Workers use the container to create scoped dependencies per operation. + """ + pool = WorkerPool(container=container, stale_claim_interval=60.0, schedules=schedules) + + # Register all handlers + for handler_type in handler_types: + pool.register(handler_type) + + logger.info(f"WorkerPool created with {len(pool.workers)} workers") + return pool diff --git a/server/osa/infrastructure/event/worker.py b/server/osa/infrastructure/event/worker.py index 102993b..fb5e758 100644 --- a/server/osa/infrastructure/event/worker.py +++ b/server/osa/infrastructure/event/worker.py @@ -1,34 +1,33 @@ -"""BackgroundWorker - unified background work using APScheduler.""" +"""Worker and WorkerPool for pull-based event processing.""" import asyncio import logging -from collections import defaultdict from contextlib import AsyncExitStack from dataclasses import dataclass, field -from typing import Any, NewType, Union, cast +from typing import Any, NewType from uuid import uuid4 from apscheduler import AsyncScheduler from apscheduler.triggers.cron import CronTrigger -from apscheduler.triggers.interval import IntervalTrigger from dishka import AsyncContainer from sqlalchemy.ext.asyncio import AsyncSession from osa.application.event import ServerStarted -from osa.domain.shared.error import SkippedEventsError -from osa.domain.shared.event import BatchEventListener, Event, EventId, EventListener, Schedule +from osa.domain.shared.error import SkippedEvents +from osa.domain.shared.event import ( + EventHandler, + EventId, + Schedule, + WorkerConfig, + WorkerState, + WorkerStatus, +) from osa.domain.shared.outbox import Outbox from osa.util.di.scope import Scope logger = logging.getLogger(__name__) -# Type aliases for DI -# Listener can be either EventListener (single event) or BatchEventListener (batch of events) -ListenerType = Union[type[EventListener[Any]], type[BatchEventListener[Any]]] -Subscriptions = NewType("Subscriptions", list[ListenerType]) - - @dataclass class ScheduleConfig: """Configuration for a scheduled task.""" @@ -42,75 +41,323 @@ class ScheduleConfig: ScheduleConfigs = NewType("ScheduleConfigs", list[ScheduleConfig]) -class BackgroundWorker: - """Unified background worker: outbox polling + scheduled tasks. +class Worker: + """Pull-based event worker that delegates to an EventHandler. + + Workers claim events from the outbox using FOR UPDATE SKIP LOCKED, + enabling concurrent processing without coordination. The Worker + handles all polling/transaction logic while the EventHandler + contains the business logic. + + Configuration is read from the handler's class variables: + __event_type__: Event type to claim + __routing_key__: Optional routing key filter + __batch_size__: Max events per batch + __batch_timeout__: Timeout for partial batches + __poll_interval__: Seconds between polls when idle + __max_retries__: Max retry attempts before marking failed + __claim_timeout__: Seconds before claim considered stale + + Example: + class VectorIndexHandler(EventHandler[IndexRecord]): + __routing_key__ = "vector" + __batch_size__ = 100 + + _backend: VectorStorageBackend + + async def handle_batch(self, events: list[IndexRecord]) -> None: + records = [(str(e.record_srn), e.metadata) for e in events] + await self._backend.ingest_batch(records) + + # Worker created from handler type + worker = Worker(VectorIndexHandler) + worker.set_container(container) + worker.start() + """ + + def __init__(self, handler_type: type[EventHandler[Any]]) -> None: + """Initialize worker from handler type. + + Args: + handler_type: EventHandler subclass with config in classvars. + """ + self._handler_type = handler_type + + # Read config from handler classvars + self._event_type = handler_type.__event_type__ + self._routing_key = handler_type.__routing_key__ + self._batch_size = handler_type.__batch_size__ + self._batch_timeout = handler_type.__batch_timeout__ + self._poll_interval = handler_type.__poll_interval__ + self._max_retries = handler_type.__max_retries__ + self._claim_timeout = handler_type.__claim_timeout__ + + # Create WorkerConfig for state tracking (backwards compat) + self._config = WorkerConfig( + name=handler_type.__name__, + event_types=(self._event_type,), + routing_key=self._routing_key, + batch_size=self._batch_size, + batch_timeout=self._batch_timeout, + poll_interval=self._poll_interval, + max_retries=self._max_retries, + claim_timeout=self._claim_timeout, + ) + self._state = WorkerState(config=self._config) + self._shutdown = False + self._task: asyncio.Task | None = None + self._container: AsyncContainer | None = None + + @property + def name(self) -> str: + """Worker name (handler class name).""" + return self._handler_type.__name__ + + @property + def handler_type(self) -> type[EventHandler[Any]]: + """The EventHandler type this worker delegates to.""" + return self._handler_type + + @property + def config(self) -> WorkerConfig: + """Worker configuration (derived from handler classvars).""" + return self._config + + @property + def state(self) -> WorkerState: + """Current worker state.""" + return self._state + + def set_container(self, container: AsyncContainer) -> None: + """Set the DI container for scoped dependency resolution.""" + self._container = container + + def start(self) -> asyncio.Task: + """Start the worker in a background task. + + Returns: + The asyncio.Task running the worker. + """ + if self._container is None: + raise RuntimeError("Container not set. Call set_container() first.") + + self._shutdown = False + self._task = asyncio.create_task(self._run(), name=f"worker-{self.name}") + logger.info(f"Worker '{self.name}' started") + return self._task + + def stop(self) -> None: + """Signal the worker to stop gracefully. + + The worker will finish processing its current batch before stopping. + """ + self._shutdown = True + self._state.status = WorkerStatus.STOPPING + logger.info(f"Worker '{self.name}' stopping...") + + async def _run(self) -> None: + """Main worker loop.""" + try: + while not self._shutdown: + had_events = await self._poll_once() + if not had_events: + await asyncio.sleep(self._poll_interval) + except asyncio.CancelledError: + logger.info(f"Worker '{self.name}' cancelled") + raise + except Exception as e: + logger.exception(f"Worker '{self.name}' crashed: {e}") + self._state.error = e + raise + finally: + logger.info(f"Worker '{self.name}' stopped") + + async def _poll_once(self) -> bool: + """Execute one poll cycle: claim, process, repeat within UOW scope. + + Returns: + True if events were processed, False if idle. + """ + if self._container is None: + raise RuntimeError("Container not set") + + self._state.status = WorkerStatus.CLAIMING + + # Claim and process within a UOW scope + async with self._container(scope=Scope.UOW) as scope: + outbox = await scope.get(Outbox) + session = await scope.get(AsyncSession) + + # Claim events + result = await outbox.claim( + event_types=[self._event_type], + limit=self._batch_size, + routing_key=self._routing_key, + ) + + if not result.events: + # No events available - commit and return + await session.commit() + self._state.status = WorkerStatus.IDLE + return False + + # Process claimed events via handler + self._state.status = WorkerStatus.PROCESSING + self._state.current_batch = result.events + self._state.last_claim_at = result.claimed_at + + try: + # Get handler instance from DI container + handler = await scope.get(self._handler_type) - Uses APScheduler for all timing. APP-scoped, spawns UOW scopes per work unit. + # Delegate to handler's batch method + if self._batch_size > 1: + await handler.handle_batch(result.events) + else: + await handler.handle(result.events[0]) + + # Mark all events as delivered + for event in result.events: + await outbox.mark_delivered(event.id) - - Outbox polling: IntervalTrigger, dispatches events to EventListeners - - Scheduled tasks: CronTrigger, runs Schedule.run() with params + await session.commit() + self._state.processed_count += len(result.events) - Usage: - async with worker: - # worker is running, yield to application - ... + except SkippedEvents as e: + # Mark specific events as skipped (not the whole batch) + logger.warning( + f"Worker '{self.name}' skipping {len(e.event_ids)} events: {e.reason}" + ) + for event_id in e.event_ids: + await outbox.mark_skipped(event_id, e.reason) + # Mark remaining events as delivered + skipped_set = set(e.event_ids) + for event in result.events: + if event.id not in skipped_set: + await outbox.mark_delivered(event.id) + await session.commit() + self._state.processed_count += len(result.events) - len(e.event_ids) + + except Exception as e: + self._state.failed_count += len(result.events) + self._state.error = e + logger.error(f"Worker '{self.name}' batch failed: {e}") + # Mark all as failed with retry + for event in result.events: + await outbox.mark_failed_with_retry( + event.id, + str(e), + max_retries=self._max_retries, + ) + await session.commit() + + finally: + self._state.current_batch = [] + self._state.status = WorkerStatus.IDLE + + return True + + +class WorkerPool: + """Manages multiple workers, scheduled tasks, and handles stale claim cleanup. + + Usage with handler types (preferred): + pool = WorkerPool(container) + pool.register(VectorIndexHandler) + pool.register(KeywordIndexHandler) + + async with pool: + # Workers are running + await some_long_running_task() + # Workers are stopped + + Legacy usage with Worker instances (deprecated): + pool = WorkerPool(container) + pool.add_worker(VectorIndexWorker(config, backend)) """ def __init__( self, - container: AsyncContainer, - subscriptions: Subscriptions, - schedules: ScheduleConfigs, - poll_interval: float = 0.5, - batch_size: int = 100, + container: AsyncContainer | None = None, + stale_claim_interval: float = 60.0, + schedules: "ScheduleConfigs | None" = None, ) -> None: self._container = container - self._poll_interval = poll_interval - self._batch_size = batch_size - - # Maps event type -> listener TYPE (not instance!) - # Supports both EventListener (single) and BatchEventListener (batch) - self._listener_types: dict[type[Event], ListenerType] = {} - self._batch_listener_types: set[type[Event]] = set() - - for listener_type in subscriptions: - event_type = listener_type.__event_type__ - self._listener_types[event_type] = listener_type - - # Track which event types use batch listeners - if hasattr(listener_type, "handle_batch"): - self._batch_listener_types.add(event_type) - logger.debug( - f"Registered batch listener {listener_type.__name__} for {event_type.__name__}" - ) - else: - logger.debug(f"Registered {listener_type.__name__} for {event_type.__name__}") + self._workers: list[Worker] = [] + self._stale_claim_interval = stale_claim_interval + self._stale_claim_task: asyncio.Task | None = None + self._shutdown = False + self._schedules = schedules or ScheduleConfigs([]) + self._scheduler: AsyncScheduler | None = None + self._exit_stack: AsyncExitStack | None = None + self._schedule_failures: dict[str, int] = {} - # Schedule configs - self._schedules = schedules + def set_container(self, container: AsyncContainer) -> None: + """Set the DI container for all workers.""" + self._container = container + for worker in self._workers: + worker.set_container(container) - self._scheduler = AsyncScheduler() - self._exit_stack: AsyncExitStack | None = None + @property + def workers(self) -> list[Worker]: + """List of managed workers.""" + return self._workers - # Track schedule failures for alerting - self._schedule_failures: dict[str, int] = {} + def register(self, handler_type: type[EventHandler[Any]]) -> Worker: + """Register an EventHandler type and create a Worker for it. + + This is the preferred way to add handlers to the pool. + The Worker is created internally and configured from handler classvars. + + Args: + handler_type: EventHandler subclass to register. + + Returns: + The created Worker instance. + """ + worker = Worker(handler_type) + if self._container is not None: + worker.set_container(self._container) + self._workers.append(worker) + logger.debug(f"Registered handler '{handler_type.__name__}' as worker") + return worker - async def __aenter__(self) -> "BackgroundWorker": - """Start the background worker.""" + def add_worker(self, worker: Worker) -> None: + """Add a worker to the pool. + + DEPRECATED: Use register() with EventHandler types instead. + """ + if self._container is not None: + worker.set_container(self._container) + self._workers.append(worker) + logger.debug(f"Added worker '{worker.name}' to pool") + + def get_worker(self, name: str) -> Worker | None: + """Get a worker by name.""" + for worker in self._workers: + if worker.name == name: + return worker + return None + + async def start(self) -> None: + """Start all workers, scheduled tasks, and the stale claim cleanup task.""" + if self._container is None: + raise RuntimeError("Container not set. Call set_container() first.") + + self._shutdown = False + + # Ensure all workers have the container + for worker in self._workers: + if worker._container is None: + worker.set_container(self._container) + + # Setup scheduler for cron tasks self._exit_stack = AsyncExitStack() await self._exit_stack.__aenter__() - # Enter scheduler context (keeps it alive until __aexit__) + self._scheduler = AsyncScheduler() await self._exit_stack.enter_async_context(self._scheduler) - # Register outbox polling as interval task - await self._scheduler.add_schedule( - self._poll_outbox, - IntervalTrigger(seconds=self._poll_interval), - id="outbox-poll", - ) - logger.debug(f"Registered outbox polling (interval={self._poll_interval}s)") - # Register schedules as cron tasks for config in self._schedules: await self._scheduler.add_schedule( @@ -119,27 +366,33 @@ async def __aenter__(self) -> "BackgroundWorker": id=config.id, kwargs={"config": config}, ) - logger.debug(f"Registered {config.id} (cron={config.cron})") + logger.debug(f"Registered schedule {config.id} (cron={config.cron})") await self._scheduler.start_in_background() - logger.info( - f"BackgroundWorker started with {len(self._listener_types)} listeners, " - f"{len(self._schedules)} schedules" - ) - # Emit ServerStarted to trigger startup listeners + # Emit ServerStarted event to trigger startup handlers await self._emit_server_started() - return self + # Start all workers + for worker in self._workers: + worker.start() - async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001 - """Stop the background worker.""" - if self._exit_stack: - await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb) - logger.info("BackgroundWorker stopped") + # Start stale claim cleanup task + if self._stale_claim_interval > 0: + self._stale_claim_task = asyncio.create_task( + self._run_stale_claim_cleanup(), name="stale-claim-cleanup" + ) + + logger.info( + f"WorkerPool started with {len(self._workers)} workers, " + f"{len(self._schedules)} schedules" + ) async def _emit_server_started(self) -> None: - """Emit ServerStarted event to trigger startup listeners.""" + """Emit ServerStarted event to trigger startup handlers.""" + if self._container is None: + return + async with self._container(scope=Scope.UOW) as scope: outbox = await scope.get(Outbox) await outbox.append(ServerStarted(id=EventId(uuid4()))) @@ -147,158 +400,45 @@ async def _emit_server_started(self) -> None: await session.commit() logger.info("ServerStarted event emitted") - async def _poll_outbox(self) -> None: - """Interval task: fetch pending events and dispatch. - - Events are grouped by type. Batch listeners receive all events of their - type together, while regular listeners receive events one-by-one. - """ - try: - # Fetch pending events in one scope - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - events = await outbox.fetch_pending(self._batch_size) - session = await scope.get(AsyncSession) - await session.commit() - - if not events: - return - - # Group events by type for batch processing - by_type: dict[type[Event], list[Event]] = defaultdict(list) - for event in events: - by_type[type(event)].append(event) - - # Log the distribution of event types (shows round-robin working) - type_counts = {t.__name__: len(evts) for t, evts in by_type.items()} - logger.info(f"Processing {len(events)} events: {type_counts}") - - # Process each event type - for event_type, type_events in by_type.items(): - if event_type in self._batch_listener_types: - # Batch listener - dispatch all events together - await self._dispatch_batch(type_events) - else: - # Regular listener - dispatch one-by-one - for event in type_events: - await self._dispatch(event) - - except (asyncio.CancelledError, SystemExit, KeyboardInterrupt): - # Let control exceptions propagate for graceful shutdown - raise - except Exception as e: - # Log but don't re-raise - let the scheduler continue polling - logger.exception(f"Error in outbox poll cycle: {e}") - - async def _dispatch(self, event: Event) -> None: - """Dispatch a single event to its listener in UOW scope.""" - listener_type = self._listener_types.get(type(event)) - if listener_type is None: - # No listener - mark as delivered so it doesn't stay in outbox forever - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - await outbox.mark_delivered(event.id) - session = await scope.get(AsyncSession) - await session.commit() - logger.debug(f"No listener for {type(event).__name__}, marked as delivered") - return - - try: - logger.debug(f"Dispatching {type(event).__name__} -> {listener_type.__name__}") - async with self._container(scope=Scope.UOW) as scope: - # Dishka creates a fresh listener instance with injected deps - listener = cast(EventListener[Any], await scope.get(listener_type)) - await listener.handle(event) - - # Mark delivered and commit - outbox = await scope.get(Outbox) - await outbox.mark_delivered(event.id) - session = await scope.get(AsyncSession) - await session.commit() - - logger.debug(f"Delivered {type(event).__name__} (id={event.id})") - - except Exception as e: - logger.error(f"Failed to handle {type(event).__name__} (id={event.id}): {e}") - # Mark failed in a new scope - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - await outbox.mark_failed(event.id, str(e)) - session = await scope.get(AsyncSession) - await session.commit() - - async def _dispatch_batch(self, events: list[Event]) -> None: - """Dispatch a batch of events to a BatchEventListener in UOW scope. + async def stop(self, timeout: float = 30.0) -> None: + """Stop all workers gracefully. - All events in the batch are of the same type and processed together. - On success, all events are marked delivered. On failure, all are marked failed. + Args: + timeout: Maximum time to wait for workers to stop. """ - if not events: - return + self._shutdown = True + + # Signal all workers to stop + for worker in self._workers: + worker.stop() + + # Stop stale claim cleanup task + if self._stale_claim_task and not self._stale_claim_task.done(): + self._stale_claim_task.cancel() + try: + await self._stale_claim_task + except asyncio.CancelledError: + pass + + # Wait for workers to finish with timeout + tasks = [w._task for w in self._workers if w._task and not w._task.done()] + if tasks: + done, pending = await asyncio.wait(tasks, timeout=timeout) + for task in pending: + task.cancel() + + # Stop scheduler + if self._exit_stack: + await self._exit_stack.__aexit__(None, None, None) + self._exit_stack = None - event_type = type(events[0]) - listener_type = self._listener_types.get(event_type) + logger.info("WorkerPool stopped") - if listener_type is None: - # No listener - mark all as delivered - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - for event in events: - await outbox.mark_delivered(event.id) - session = await scope.get(AsyncSession) - await session.commit() - logger.debug( - f"No listener for {event_type.__name__}, marked {len(events)} as delivered" - ) + async def _run_schedule(self, config: "ScheduleConfig") -> None: + """Cron task: run a scheduled task in UOW scope.""" + if self._container is None: return - event_ids = [e.id for e in events] - - try: - logger.debug( - f"Dispatching batch of {len(events)} {event_type.__name__} -> {listener_type.__name__}" - ) - async with self._container(scope=Scope.UOW) as scope: - # Dishka creates a fresh batch listener instance with injected deps - listener = cast(BatchEventListener[Any], await scope.get(listener_type)) - await listener.handle_batch(events) - - # Mark all as delivered and commit - outbox = await scope.get(Outbox) - for event_id in event_ids: - await outbox.mark_delivered(event_id) - session = await scope.get(AsyncSession) - await session.commit() - - logger.debug(f"Delivered batch of {len(events)} {event_type.__name__} events") - - except SkippedEventsError as e: - # Mark specific events as skipped (not the whole batch) - logger.warning( - f"Skipping {len(e.event_ids)} events for {event_type.__name__}: {e.reason}" - ) - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - for event_id in e.event_ids: - await outbox.mark_skipped(event_id, e.reason) - session = await scope.get(AsyncSession) - await session.commit() - - except Exception as e: - error_msg = str(e) - logger.error( - f"Failed to handle batch of {len(events)} {event_type.__name__} events: {error_msg}" - ) - # Mark all as failed in a new scope - async with self._container(scope=Scope.UOW) as scope: - outbox = await scope.get(Outbox) - for event_id in event_ids: - await outbox.mark_failed(event_id, error_msg) - session = await scope.get(AsyncSession) - await session.commit() - - async def _run_schedule(self, config: ScheduleConfig) -> None: - """Cron task: run a scheduled task in UOW scope.""" try: async with self._container(scope=Scope.UOW) as scope: schedule = await scope.get(config.schedule_type) @@ -320,3 +460,39 @@ async def _run_schedule(self, config: ScheduleConfig) -> None: logger.error(f"Failed to run schedule {config.id} (failures: {failures}): {e}") if failures >= 5: logger.critical(f"Schedule {config.id} has failed {failures} consecutive times") + + async def __aenter__(self) -> "WorkerPool": + """Start the pool as async context manager.""" + await self.start() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: # noqa: ANN001 + """Stop the pool on context exit.""" + await self.stop() + + async def _run_stale_claim_cleanup(self) -> None: + """Periodically reset stale claims.""" + while not self._shutdown: + try: + await asyncio.sleep(self._stale_claim_interval) + + if self._shutdown or self._container is None: + break + + # Get max claim_timeout from all workers + if self._workers: + max_timeout = max(w.config.claim_timeout for w in self._workers) + + # Use a scoped outbox for cleanup + async with self._container(scope=Scope.UOW) as scope: + outbox = await scope.get(Outbox) + session = await scope.get(AsyncSession) + count = await outbox.reset_stale_claims(max_timeout) + await session.commit() + if count > 0: + logger.info(f"Reset {count} stale claims") + + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Stale claim cleanup failed: {e}") diff --git a/server/osa/infrastructure/persistence/repository/event.py b/server/osa/infrastructure/persistence/repository/event.py index 0cbe947..22d0e4a 100644 --- a/server/osa/infrastructure/persistence/repository/event.py +++ b/server/osa/infrastructure/persistence/repository/event.py @@ -1,13 +1,15 @@ """SQLAlchemy adapter implementing EventRepository.""" import logging -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta from typing import TypeVar -from sqlalchemy import func, insert, select, update +from sqlalchemy import func, insert, or_, select, update +from sqlalchemy.dialects.postgresql import INTERVAL +from sqlalchemy.sql import literal from sqlalchemy.ext.asyncio import AsyncSession -from osa.domain.shared.event import Event, EventId +from osa.domain.shared.event import ClaimResult, Event, EventId from osa.domain.shared.port.event_repository import EventRepository from osa.infrastructure.persistence.tables import events_table @@ -25,14 +27,20 @@ class SQLAlchemyEventRepository(EventRepository): def __init__(self, session: AsyncSession) -> None: self._session = session - async def save(self, event: Event, status: str = "pending") -> None: + async def save( + self, event: Event, status: str = "pending", routing_key: str | None = None + ) -> None: """Persist an event.""" + now = datetime.now(UTC) stmt = insert(events_table).values( id=str(event.id), event_type=type(event).__name__, payload=event.model_dump(mode="json"), - created_at=datetime.now(UTC), + created_at=now, delivery_status=status, + routing_key=routing_key, + retry_count=0, + updated_at=now, ) await self._session.execute(stmt) @@ -59,9 +67,11 @@ async def update_status( error: str | None = None, ) -> None: """Update an event's delivery status.""" + now = datetime.now(UTC) values: dict = { "delivery_status": status, - "delivered_at": datetime.now(UTC), + "delivered_at": now, + "updated_at": now, } if error is not None: values["delivery_error"] = error @@ -213,3 +223,171 @@ def _deserialize(self, event_type: str, payload: dict | str) -> Event | None: except Exception as e: logger.error(f"Failed to deserialize event type '{event_type}': {e}") return None + + async def claim( + self, + event_types: list[str], + limit: int, + routing_key: str | None = None, + ) -> ClaimResult: + """Claim pending events using FOR UPDATE SKIP LOCKED. + + This atomically selects and locks events for processing. Concurrent + workers will skip already-locked events. + + Args: + event_types: Event type names to claim. + limit: Maximum number of events to claim. + routing_key: Optional routing key filter. + + Returns: + ClaimResult containing claimed events and timestamp. + """ + now = datetime.now(UTC) + + # Build WHERE clause for pending events + # Include events that are pending and eligible for retry (based on backoff) + where_clauses = [ + events_table.c.delivery_status == "pending", + events_table.c.event_type.in_(event_types), + ] + + # Routing key filter + if routing_key is not None: + where_clauses.append(events_table.c.routing_key == routing_key) + else: + # When routing_key is None, only claim unrouted events + where_clauses.append(events_table.c.routing_key.is_(None)) + + # Backoff eligibility: either first attempt (retry_count=0) or enough time passed + # Backoff formula: min(30, 5^retry_count) seconds + # Events must have updated_at <= now - backoff_seconds to be eligible + backoff_seconds = func.least( + literal(30), + func.power(literal(5), events_table.c.retry_count), + ) + backoff_interval = func.cast(func.concat(backoff_seconds, literal(" seconds")), INTERVAL) + backoff_eligible = or_( + events_table.c.retry_count == 0, + events_table.c.updated_at <= func.now() - backoff_interval, + ) + where_clauses.append(backoff_eligible) + + # Select with FOR UPDATE SKIP LOCKED + stmt = ( + select(events_table.c.id, events_table.c.event_type, events_table.c.payload) + .where(*where_clauses) + .order_by(events_table.c.created_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True) + ) + + result = await self._session.execute(stmt) + rows = result.fetchall() + + if not rows: + return ClaimResult(events=[], claimed_at=now) + + # Update status to 'claimed' and set claimed_at + event_ids = [row[0] for row in rows] + update_stmt = ( + update(events_table) + .where(events_table.c.id.in_(event_ids)) + .values(delivery_status="claimed", claimed_at=now, updated_at=now) + ) + await self._session.execute(update_stmt) + + # Deserialize events + events: list[Event] = [] + for row in rows: + _, event_type, payload = row + event = self._deserialize(event_type, payload) + if event is not None: + events.append(event) + + return ClaimResult(events=events, claimed_at=now) + + async def reset_stale_claims(self, timeout_seconds: float) -> int: + """Reset events that have been claimed for too long. + + Args: + timeout_seconds: Consider claims older than this as stale. + + Returns: + Number of events reset. + """ + cutoff = datetime.now(UTC) - timedelta(seconds=timeout_seconds) + + stmt = ( + update(events_table) + .where( + events_table.c.delivery_status == "claimed", + events_table.c.claimed_at < cutoff, + ) + .values( + delivery_status="pending", + claimed_at=None, + updated_at=datetime.now(UTC), + ) + ) + + result = await self._session.execute(stmt) + count = result.rowcount + if count > 0: + logger.info(f"Reset {count} stale claims (older than {timeout_seconds}s)") + return count + + async def mark_failed_with_retry( + self, + event_id: EventId, + error: str, + max_retries: int, + ) -> None: + """Mark an event as failed with retry logic. + + If retry_count < max_retries, increments retry_count and resets + status to 'pending' for retry. + If retry_count >= max_retries, sets status to 'failed' permanently. + """ + now = datetime.now(UTC) + + # First, get the current retry_count + select_stmt = select(events_table.c.retry_count).where(events_table.c.id == str(event_id)) + result = await self._session.execute(select_stmt) + row = result.first() + + if row is None: + logger.warning(f"Event {event_id} not found for mark_failed_with_retry") + return + + current_retry_count = row[0] or 0 + new_retry_count = current_retry_count + 1 + + if new_retry_count >= max_retries: + # Exceeded max retries - mark as permanently failed + update_stmt = ( + update(events_table) + .where(events_table.c.id == str(event_id)) + .values( + delivery_status="failed", + delivery_error=error, + retry_count=new_retry_count, + updated_at=now, + delivered_at=now, + ) + ) + else: + # Reset to pending for retry + update_stmt = ( + update(events_table) + .where(events_table.c.id == str(event_id)) + .values( + delivery_status="pending", + delivery_error=error, + retry_count=new_retry_count, + claimed_at=None, + updated_at=now, + ) + ) + + await self._session.execute(update_stmt) diff --git a/server/osa/infrastructure/persistence/tables.py b/server/osa/infrastructure/persistence/tables.py index c0d2e40..6b66044 100644 --- a/server/osa/infrastructure/persistence/tables.py +++ b/server/osa/infrastructure/persistence/tables.py @@ -4,10 +4,12 @@ Column, DateTime, Index, + Integer, MetaData, String, Table, Text, + text, ) from sqlalchemy.types import JSON @@ -77,9 +79,16 @@ Column("event_type", String(128), nullable=False), Column("payload", JSON, nullable=False), Column("created_at", DateTime(timezone=True), nullable=False), - Column("delivery_status", String(32), nullable=False), # pending, delivered, failed + Column( + "delivery_status", String(32), nullable=False + ), # pending, claimed, delivered, failed, skipped Column("delivered_at", DateTime(timezone=True), nullable=True), Column("delivery_error", Text, nullable=True), + # Pull-based worker columns + Column("routing_key", String(255), nullable=True), + Column("retry_count", Integer, nullable=False, default=0), + Column("claimed_at", DateTime(timezone=True), nullable=True), + Column("updated_at", DateTime(timezone=True), nullable=False), ) Index( @@ -88,3 +97,28 @@ events_table.c.created_at.desc(), ) Index("idx_events_delivery_status", events_table.c.delivery_status) + +# Partial index for efficient claiming query +Index( + "idx_events_claim", + events_table.c.delivery_status, + events_table.c.event_type, + events_table.c.routing_key, + events_table.c.created_at, + postgresql_where=text("delivery_status IN ('pending', 'claimed')"), +) + +# Partial index for stale claim detection +Index( + "idx_events_stale_claims", + events_table.c.claimed_at, + postgresql_where=text("delivery_status = 'claimed'"), +) + +# Partial index for failed event queries +Index( + "idx_events_failed", + events_table.c.event_type, + events_table.c.created_at, + postgresql_where=text("delivery_status = 'failed'"), +) diff --git a/server/tests/contract/test_worker_lifecycle.py b/server/tests/contract/test_worker_lifecycle.py new file mode 100644 index 0000000..600001d --- /dev/null +++ b/server/tests/contract/test_worker_lifecycle.py @@ -0,0 +1,218 @@ +"""Contract tests for WorkerPool lifecycle in FastAPI lifespan. + +Tests for Phase 7: Migration. +""" + +import asyncio +from datetime import UTC, datetime +from typing import Any +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest + +from osa.domain.index.event.index_record import IndexRecord +from osa.domain.index.handler.keyword_index_handler import KeywordIndexHandler +from osa.domain.index.handler.vector_index_handler import VectorIndexHandler +from osa.domain.index.model.registry import IndexRegistry +from osa.domain.shared.event import ClaimResult, EventId +from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion +from osa.infrastructure.event.worker import WorkerPool + + +class FakeBackend: + """Fake storage backend for testing.""" + + def __init__(self, name: str): + self._name = name + self.ingested: list[tuple[str, dict]] = [] + + @property + def name(self) -> str: + return self._name + + async def ingest_batch(self, records: list[tuple[str, dict]]) -> None: + self.ingested.extend(records) + + +def make_mock_container(handler_type: type, handler_instance: Any): + """Create a mock DI container that provides scoped dependencies. + + Creates a container mock that returns scoped Outbox and AsyncSession + when called as an async context manager with scope parameter. + """ + from osa.domain.shared.outbox import Outbox + + outbox = AsyncMock() + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + outbox.reset_stale_claims.return_value = 0 + + session = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + + async def get_dependency(cls: type) -> Any: + """Return the appropriate dependency based on the requested class.""" + if cls == Outbox: + return outbox + if cls == handler_type: + return handler_instance + return session + + # Create scope that returns dependencies + scope = AsyncMock() + scope.get = AsyncMock(side_effect=get_dependency) + + # Create async context manager + context = MagicMock() + context.__aenter__ = AsyncMock(return_value=scope) + context.__aexit__ = AsyncMock(return_value=None) + + # Container callable returns the context manager + container = MagicMock() + container.return_value = context + + return container, outbox, session + + +class TestWorkerPoolLifecycle: + """Tests for WorkerPool lifecycle management.""" + + @pytest.mark.asyncio + async def test_worker_pool_starts_and_stops_cleanly(self): + """WorkerPool should start and stop without errors.""" + vector_backend = FakeBackend("vector") + keyword_backend = FakeBackend("keyword") + registry = IndexRegistry({"vector": vector_backend, "keyword": keyword_backend}) + + # Create handler instance + vector_handler = VectorIndexHandler(indexes=registry) + + container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) + + pool = WorkerPool(container=container, stale_claim_interval=60.0) + pool.register(VectorIndexHandler) + pool.register(KeywordIndexHandler) + + # Start + await pool.start() + await asyncio.sleep(0.05) + + # Verify workers are running + assert len(pool.workers) == 2 + for worker in pool.workers: + assert worker._task is not None + assert not worker._task.done() + + # Stop + await pool.stop() + + # Verify workers are stopped + for worker in pool.workers: + assert worker._shutdown is True + + @pytest.mark.asyncio + async def test_worker_pool_as_context_manager(self): + """WorkerPool should work as async context manager.""" + vector_backend = FakeBackend("vector") + registry = IndexRegistry({"vector": vector_backend}) + vector_handler = VectorIndexHandler(indexes=registry) + + container, outbox, session = make_mock_container(VectorIndexHandler, vector_handler) + + pool = WorkerPool(container=container, stale_claim_interval=60.0) + pool.register(VectorIndexHandler) + + async with pool: + # Workers should be running + assert pool.workers[0]._task is not None + await asyncio.sleep(0.02) + + # After exit, workers should be stopped + assert pool.workers[0]._shutdown is True + + +class TestIndexHandlers: + """Tests for concrete index handlers.""" + + @pytest.mark.asyncio + async def test_vector_handler_processes_batch(self): + """VectorIndexHandler should process IndexRecord events in batches.""" + backend = FakeBackend("vector") + registry = IndexRegistry({"vector": backend}) + handler = VectorIndexHandler(indexes=registry) + + # Create test events + events = [ + IndexRecord( + id=EventId(uuid4()), + backend_name="vector", + record_srn=RecordSRN( + domain=Domain("test.example.com"), + id=LocalId(f"rec-{i}"), + version=RecordVersion(1), + ), + metadata={"title": f"Record {i}"}, + ) + for i in range(5) + ] + + # Process + await handler.handle_batch(events) + + # Verify backend received all records + assert len(backend.ingested) == 5 + + @pytest.mark.asyncio + async def test_keyword_handler_processes_individually(self): + """KeywordIndexHandler should process IndexRecord events one at a time.""" + backend = FakeBackend("keyword") + registry = IndexRegistry({"keyword": backend}) + handler = KeywordIndexHandler(indexes=registry) + + # batch_size should be 1 + assert KeywordIndexHandler.__batch_size__ == 1 + + # Create test event + event = IndexRecord( + id=EventId(uuid4()), + backend_name="keyword", + record_srn=RecordSRN( + domain=Domain("test.example.com"), + id=LocalId("rec-1"), + version=RecordVersion(1), + ), + metadata={"title": "Record 1"}, + ) + + # Process + await handler.handle(event) + + # Verify + assert len(backend.ingested) == 1 + + @pytest.mark.asyncio + async def test_handler_raises_on_backend_failure(self): + """Handlers should raise when backend fails (Worker handles retry).""" + # Backend that fails + failing_backend = AsyncMock() + failing_backend.name = "vector" + failing_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend error")) + + registry = IndexRegistry({"vector": failing_backend}) + handler = VectorIndexHandler(indexes=registry) + + event = IndexRecord( + id=EventId(uuid4()), + backend_name="vector", + record_srn=RecordSRN( + domain=Domain("test.example.com"), + id=LocalId("rec-1"), + version=RecordVersion(1), + ), + metadata={"title": "Record 1"}, + ) + + # Process - should raise (Worker handles retry) + with pytest.raises(Exception, match="Backend error"): + await handler.handle_batch([event]) diff --git a/server/tests/integration/event/__init__.py b/server/tests/integration/event/__init__.py new file mode 100644 index 0000000..20b1c3f --- /dev/null +++ b/server/tests/integration/event/__init__.py @@ -0,0 +1 @@ +"""Integration tests for event processing.""" diff --git a/server/tests/integration/event/test_claim.py b/server/tests/integration/event/test_claim.py new file mode 100644 index 0000000..9707021 --- /dev/null +++ b/server/tests/integration/event/test_claim.py @@ -0,0 +1,339 @@ +"""Integration tests for event claiming with FOR UPDATE SKIP LOCKED. + +Tests for User Story 1: Reliable Event Processing. +""" + +from datetime import UTC, datetime +from typing import Any +from uuid import uuid4 + +import pytest + +from osa.domain.shared.event import ClaimResult, Event, EventId + + +class DummyEvent(Event): + """Test event for claim tests.""" + + id: EventId + data: str + + +class FakeEventRepository: + """Fake event repository simulating FOR UPDATE SKIP LOCKED behavior. + + Tracks claimed events and ensures concurrent claims skip locked rows. + """ + + def __init__(self): + self.events: dict[str, dict[str, Any]] = {} + self._locked_ids: set[str] = set() + + async def save( + self, event: Event, status: str = "pending", routing_key: str | None = None + ) -> None: + """Save an event.""" + self.events[str(event.id)] = { + "event": event, + "status": status, + "routing_key": routing_key, + "retry_count": 0, + "claimed_at": None, + "updated_at": datetime.now(UTC), + } + + async def claim( + self, + event_types: list[str], + limit: int, + routing_key: str | None = None, + ) -> ClaimResult: + """Claim events, simulating FOR UPDATE SKIP LOCKED.""" + claimed = [] + now = datetime.now(UTC) + + for event_id, data in self.events.items(): + if len(claimed) >= limit: + break + + # Skip already locked events (simulates SKIP LOCKED) + if event_id in self._locked_ids: + continue + + # Check status + if data["status"] != "pending": + continue + + # Check event type + if type(data["event"]).__name__ not in event_types: + continue + + # Check routing key + if routing_key is not None and data["routing_key"] != routing_key: + continue + + # Lock and claim + self._locked_ids.add(event_id) + data["status"] = "claimed" + data["claimed_at"] = now + claimed.append(data["event"]) + + return ClaimResult(events=claimed, claimed_at=now) + + async def update_status( + self, + event_id: EventId, + status: str, + error: str | None = None, + ) -> None: + """Update event status and release lock.""" + event_id_str = str(event_id) + if event_id_str in self.events: + self.events[event_id_str]["status"] = status + self.events[event_id_str]["updated_at"] = datetime.now(UTC) + if error: + self.events[event_id_str]["delivery_error"] = error + # Release lock when delivered/failed + if status in ("delivered", "failed", "skipped"): + self._locked_ids.discard(event_id_str) + + async def mark_failed_with_retry( + self, + event_id: EventId, + error: str, + max_retries: int, + ) -> None: + """Mark failed with retry logic.""" + event_id_str = str(event_id) + if event_id_str not in self.events: + return + + data = self.events[event_id_str] + data["retry_count"] += 1 + data["updated_at"] = datetime.now(UTC) + + if data["retry_count"] >= max_retries: + data["status"] = "failed" + data["delivery_error"] = error + else: + # Reset to pending for retry + data["status"] = "pending" + data["claimed_at"] = None + + self._locked_ids.discard(event_id_str) + + def release_lock(self, event_id: str) -> None: + """Release lock (simulates transaction rollback).""" + self._locked_ids.discard(event_id) + + +class TestClaimWithSkipLocked: + """Tests for FOR UPDATE SKIP LOCKED behavior.""" + + @pytest.fixture + def repo(self) -> FakeEventRepository: + """Create a fake event repository.""" + return FakeEventRepository() + + @pytest.mark.asyncio + async def test_claim_returns_pending_events(self, repo: FakeEventRepository): + """Claim should return pending events matching event_types.""" + # Arrange + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + event2 = DummyEvent(id=EventId(uuid4()), data="event2") + await repo.save(event1) + await repo.save(event2) + + # Act + result = await repo.claim(event_types=["DummyEvent"], limit=10) + + # Assert + assert len(result) == 2 + assert result.events[0].data in ("event1", "event2") + assert result.events[1].data in ("event1", "event2") + + @pytest.mark.asyncio + async def test_claim_respects_limit(self, repo: FakeEventRepository): + """Claim should respect the limit parameter.""" + # Arrange + for i in range(10): + await repo.save(DummyEvent(id=EventId(uuid4()), data=f"event{i}")) + + # Act + result = await repo.claim(event_types=["DummyEvent"], limit=3) + + # Assert + assert len(result) == 3 + + @pytest.mark.asyncio + async def test_claim_skips_already_claimed_events(self, repo: FakeEventRepository): + """Concurrent claims should skip already locked events.""" + # Arrange + events = [DummyEvent(id=EventId(uuid4()), data=f"event{i}") for i in range(5)] + for event in events: + await repo.save(event) + + # Act - First worker claims some events + result1 = await repo.claim(event_types=["DummyEvent"], limit=3) + # Second worker tries to claim - should skip locked ones + result2 = await repo.claim(event_types=["DummyEvent"], limit=3) + + # Assert + assert len(result1) == 3 + assert len(result2) == 2 # Only remaining unclaimed events + + # Verify no overlap + ids1 = {e.id for e in result1.events} + ids2 = {e.id for e in result2.events} + assert ids1.isdisjoint(ids2) + + @pytest.mark.asyncio + async def test_claim_filters_by_routing_key(self, repo: FakeEventRepository): + """Claim should filter by routing_key when specified.""" + # Arrange + event1 = DummyEvent(id=EventId(uuid4()), data="vector-event") + event2 = DummyEvent(id=EventId(uuid4()), data="keyword-event") + event3 = DummyEvent(id=EventId(uuid4()), data="unrouted-event") + + await repo.save(event1, routing_key="vector") + await repo.save(event2, routing_key="keyword") + await repo.save(event3, routing_key=None) + + # Act + vector_result = await repo.claim(event_types=["DummyEvent"], limit=10, routing_key="vector") + + # Assert + assert len(vector_result) == 1 + assert vector_result.events[0].data == "vector-event" + + @pytest.mark.asyncio + async def test_claim_sets_status_to_claimed(self, repo: FakeEventRepository): + """Claim should set event status to 'claimed'.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + + # Act + await repo.claim(event_types=["DummyEvent"], limit=1) + + # Assert + assert repo.events[str(event.id)]["status"] == "claimed" + + @pytest.mark.asyncio + async def test_claim_sets_claimed_at_timestamp(self, repo: FakeEventRepository): + """Claim should set claimed_at timestamp.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + + # Act + before = datetime.now(UTC) + result = await repo.claim(event_types=["DummyEvent"], limit=1) + after = datetime.now(UTC) + + # Assert + claimed_at = repo.events[str(event.id)]["claimed_at"] + assert claimed_at is not None + assert before <= claimed_at <= after + assert result.claimed_at == claimed_at + + +class TestPartialFailureRecovery: + """Tests for partial failure recovery.""" + + @pytest.fixture + def repo(self) -> FakeEventRepository: + """Create a fake event repository.""" + return FakeEventRepository() + + @pytest.mark.asyncio + async def test_mark_delivered_releases_lock(self, repo: FakeEventRepository): + """mark_delivered should release the lock and set status to delivered.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + await repo.claim(event_types=["DummyEvent"], limit=1) + + # Act + await repo.update_status(event.id, status="delivered") + + # Assert + assert repo.events[str(event.id)]["status"] == "delivered" + assert str(event.id) not in repo._locked_ids + + @pytest.mark.asyncio + async def test_mark_failed_with_retry_resets_to_pending(self, repo: FakeEventRepository): + """mark_failed_with_retry should reset to pending if retries remain.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + await repo.claim(event_types=["DummyEvent"], limit=1) + + # Act - First failure (retry_count becomes 1, max is 3) + await repo.mark_failed_with_retry(event.id, "Error 1", max_retries=3) + + # Assert - Should be pending for retry + assert repo.events[str(event.id)]["status"] == "pending" + assert repo.events[str(event.id)]["retry_count"] == 1 + assert str(event.id) not in repo._locked_ids + + @pytest.mark.asyncio + async def test_mark_failed_after_max_retries_sets_failed(self, repo: FakeEventRepository): + """mark_failed_with_retry should set status=failed after max_retries.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + + # Simulate 3 failures (max_retries=3) + for i in range(3): + await repo.claim(event_types=["DummyEvent"], limit=1) + await repo.mark_failed_with_retry(event.id, f"Error {i + 1}", max_retries=3) + + # Assert - After 3 retries, should be failed + assert repo.events[str(event.id)]["status"] == "failed" + assert repo.events[str(event.id)]["retry_count"] == 3 + + @pytest.mark.asyncio + async def test_partial_batch_some_succeed_some_fail(self, repo: FakeEventRepository): + """In a batch, some events can succeed while others fail.""" + # Arrange + events = [DummyEvent(id=EventId(uuid4()), data=f"event{i}") for i in range(5)] + for event in events: + await repo.save(event) + + # Claim all events + await repo.claim(event_types=["DummyEvent"], limit=5) + + # Act - Mark first 3 as delivered, last 2 as failed + for event in events[:3]: + await repo.update_status(event.id, status="delivered") + for event in events[3:]: + await repo.mark_failed_with_retry(event.id, "Processing error", max_retries=3) + + # Assert + delivered = [e for e in events if repo.events[str(e.id)]["status"] == "delivered"] + pending = [e for e in events if repo.events[str(e.id)]["status"] == "pending"] + + assert len(delivered) == 3 + assert len(pending) == 2 # Failed ones reset to pending for retry + + @pytest.mark.asyncio + async def test_released_events_can_be_reclaimed(self, repo: FakeEventRepository): + """Events released (via rollback or retry) can be claimed by other workers.""" + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + await repo.save(event) + + # First worker claims + await repo.claim(event_types=["DummyEvent"], limit=1) + assert repo.events[str(event.id)]["status"] == "claimed" + + # First worker fails and releases + await repo.mark_failed_with_retry(event.id, "Error", max_retries=3) + + # Act - Second worker can now claim + result = await repo.claim(event_types=["DummyEvent"], limit=1) + + # Assert + assert len(result) == 1 + assert result.events[0].id == event.id diff --git a/server/tests/integration/event/test_concurrent_workers.py b/server/tests/integration/event/test_concurrent_workers.py new file mode 100644 index 0000000..37c36b3 --- /dev/null +++ b/server/tests/integration/event/test_concurrent_workers.py @@ -0,0 +1,247 @@ +"""Integration tests for concurrent workers claiming different events. + +Tests for User Story 2: Concurrent Event Processing. +""" + +import asyncio +from datetime import UTC, datetime +from uuid import uuid4 + +import pytest + +from osa.domain.shared.event import ClaimResult, Event, EventId + + +class DummyEvent(Event): + """Test event for concurrent worker tests.""" + + id: EventId + data: str + + +class FakeConcurrentRepository: + """Fake repository simulating concurrent access with SKIP LOCKED. + + Uses asyncio.Lock to simulate row-level locking behavior. + """ + + def __init__(self): + self.events: dict[str, dict] = {} + self._locks: dict[str, asyncio.Lock] = {} + self._global_lock = asyncio.Lock() + + async def save(self, event: Event, status: str = "pending", routing_key: str | None = None): + """Save an event.""" + self.events[str(event.id)] = { + "event": event, + "status": status, + "routing_key": routing_key, + } + self._locks[str(event.id)] = asyncio.Lock() + + async def claim( + self, + event_types: list[str], + limit: int, + routing_key: str | None = None, + ) -> ClaimResult: + """Claim events with SKIP LOCKED simulation.""" + claimed = [] + now = datetime.now(UTC) + + async with self._global_lock: + for event_id, data in self.events.items(): + if len(claimed) >= limit: + break + + # Skip if locked (simulates SKIP LOCKED) + if self._locks[event_id].locked(): + continue + + if data["status"] != "pending": + continue + + if type(data["event"]).__name__ not in event_types: + continue + + if routing_key is not None and data["routing_key"] != routing_key: + continue + + # Try to acquire lock + if self._locks[event_id].locked(): + continue + + await self._locks[event_id].acquire() + data["status"] = "claimed" + claimed.append(data["event"]) + + return ClaimResult(events=claimed, claimed_at=now) + + async def release(self, event_id: str, new_status: str = "delivered"): + """Release lock and update status.""" + if event_id in self.events: + self.events[event_id]["status"] = new_status + if event_id in self._locks and self._locks[event_id].locked(): + self._locks[event_id].release() + + async def reset_stale_claims(self, timeout_seconds: float) -> int: + """Reset stale claims (stub).""" + return 0 + + +class TestConcurrentWorkerClaiming: + """Tests for concurrent workers claiming different events.""" + + @pytest.mark.asyncio + async def test_concurrent_workers_claim_different_events(self): + """Multiple workers running concurrently should claim different events.""" + # Arrange + repo = FakeConcurrentRepository() + + # Create 10 events + events = [] + for i in range(10): + event = DummyEvent(id=EventId(uuid4()), data=f"event-{i}") + events.append(event) + await repo.save(event) + + # Simulate two workers claiming concurrently + async def worker_claim(worker_id: int, limit: int) -> list[Event]: + result = await repo.claim(event_types=["DummyEvent"], limit=limit) + # Simulate some processing time + await asyncio.sleep(0.01) + # Release all claimed events + for event in result.events: + await repo.release(str(event.id)) + return result.events + + # Act - Run two workers concurrently + results = await asyncio.gather( + worker_claim(1, 5), + worker_claim(2, 5), + ) + + worker1_events = results[0] + worker2_events = results[1] + + # Assert - Each worker should have claimed some events + assert len(worker1_events) + len(worker2_events) == 10 + + # No overlap - SKIP LOCKED ensures each event claimed by only one worker + worker1_ids = {e.id for e in worker1_events} + worker2_ids = {e.id for e in worker2_events} + assert worker1_ids.isdisjoint(worker2_ids) + + @pytest.mark.asyncio + async def test_skip_locked_prevents_double_claiming(self): + """SKIP LOCKED should prevent the same event from being claimed twice.""" + # Arrange + repo = FakeConcurrentRepository() + event = DummyEvent(id=EventId(uuid4()), data="single-event") + await repo.save(event) + + claim_count = 0 + claimed_by = [] + + async def try_claim(worker_id: int) -> bool: + nonlocal claim_count + result = await repo.claim(event_types=["DummyEvent"], limit=1) + if result.events: + claim_count += 1 + claimed_by.append(worker_id) + # Hold the lock briefly + await asyncio.sleep(0.02) + await repo.release(str(result.events[0].id)) + return True + return False + + # Act - Multiple workers try to claim the same event simultaneously + results = await asyncio.gather( + try_claim(1), + try_claim(2), + try_claim(3), + ) + + # Assert - Only one worker should have claimed the event + successful_claims = sum(results) + assert successful_claims == 1 + assert claim_count == 1 + + @pytest.mark.asyncio + async def test_routing_key_isolation(self): + """Workers with different routing keys should not interfere.""" + # Arrange + repo = FakeConcurrentRepository() + + # Create events for different routing keys + for i in range(5): + event = DummyEvent(id=EventId(uuid4()), data=f"vector-{i}") + await repo.save(event, routing_key="vector") + + for i in range(5): + event = DummyEvent(id=EventId(uuid4()), data=f"keyword-{i}") + await repo.save(event, routing_key="keyword") + + # Act - Two workers with different routing keys claim concurrently + vector_result = await repo.claim(event_types=["DummyEvent"], limit=10, routing_key="vector") + keyword_result = await repo.claim( + event_types=["DummyEvent"], limit=10, routing_key="keyword" + ) + + # Assert - Each worker only gets their routed events + assert len(vector_result.events) == 5 + assert len(keyword_result.events) == 5 + + assert all("vector" in e.data for e in vector_result.events) + assert all("keyword" in e.data for e in keyword_result.events) + + @pytest.mark.asyncio + async def test_high_concurrency_no_duplicates(self): + """Under high concurrency, no events should be processed by multiple workers.""" + # Arrange + repo = FakeConcurrentRepository() + + # Create 100 events + for i in range(100): + event = DummyEvent(id=EventId(uuid4()), data=f"event-{i}") + await repo.save(event) + + all_claimed_ids: list[set] = [] + + async def worker_claim_all(worker_id: int) -> set: + """Worker keeps claiming until no more events.""" + claimed_ids = set() + while True: + result = await repo.claim(event_types=["DummyEvent"], limit=10) + if not result.events: + break + for event in result.events: + claimed_ids.add(event.id) + await repo.release(str(event.id)) + await asyncio.sleep(0.001) # Small delay to simulate processing + return claimed_ids + + # Act - Run 5 workers concurrently + results = await asyncio.gather( + worker_claim_all(1), + worker_claim_all(2), + worker_claim_all(3), + worker_claim_all(4), + worker_claim_all(5), + ) + + # Collect all claimed IDs + all_ids = set() + for claimed_ids in results: + all_ids.update(claimed_ids) + all_claimed_ids.append(claimed_ids) + + # Assert - All 100 events should be claimed exactly once + assert len(all_ids) == 100 + + # Check no duplicates across workers + for i, ids1 in enumerate(all_claimed_ids): + for j, ids2 in enumerate(all_claimed_ids): + if i != j: + overlap = ids1 & ids2 + assert len(overlap) == 0, f"Workers {i} and {j} both claimed: {overlap}" diff --git a/server/tests/integration/test_crash_recovery.py b/server/tests/integration/test_crash_recovery.py index f80e511..18ab740 100644 --- a/server/tests/integration/test_crash_recovery.py +++ b/server/tests/integration/test_crash_recovery.py @@ -9,7 +9,7 @@ import pytest from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.listener.index_batch_listener import IndexRecordBatch +from osa.domain.index.handler.vector_index_handler import VectorIndexHandler from osa.domain.index.model.registry import IndexRegistry from osa.domain.shared.event import EventId from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion @@ -86,11 +86,11 @@ async def test_events_remain_pending_on_crash(self): # Simulate a backend that fails mid-batch backend = TrackingBackend("vector", fail_at=1) # Fail on first call registry = IndexRegistry({"vector": backend}) - listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Act - Try to process, expecting failure with pytest.raises(Exception, match="Simulated crash"): - await listener.handle_batch(events) + await handler.handle_batch(events) # Assert - Events should remain pending (outbox unchanged) assert len(outbox.pending) == 10 # All still pending @@ -107,10 +107,10 @@ async def test_recovery_processes_all_pending_events(self): # First attempt: backend fails failing_backend = TrackingBackend("vector", fail_at=1) failing_registry = IndexRegistry({"vector": failing_backend}) - failing_listener = IndexRecordBatch(indexes=failing_registry) + failing_handler = VectorIndexHandler(indexes=failing_registry) with pytest.raises(Exception): - await failing_listener.handle_batch(events) + await failing_handler.handle_batch(events) # Events still pending assert len(outbox.pending) == 10 @@ -118,10 +118,10 @@ async def test_recovery_processes_all_pending_events(self): # Second attempt (recovery): backend works working_backend = TrackingBackend("vector") working_registry = IndexRegistry({"vector": working_backend}) - working_listener = IndexRecordBatch(indexes=working_registry) + working_handler = VectorIndexHandler(indexes=working_registry) # Act - Retry processing - await working_listener.handle_batch(outbox.pending) + await working_handler.handle_batch(outbox.pending) # Assert - All events processed assert len(working_backend.indexed_records) == 10 @@ -147,15 +147,15 @@ async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: backend = PartialFailBackend() registry = IndexRegistry({"vector": backend}) - listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Act with pytest.raises(Exception, match="Partial failure"): - await listener.handle_batch(events) + await handler.handle_batch(events) # Assert - Some records were processed but batch should be atomic # In production, the outbox marks all events as failed together - # The listener correctly propagates the error for retry + # The handler correctly propagates the error for retry assert len(backend.processed) == 2 # Backend saw 2 before crash @pytest.mark.asyncio @@ -166,11 +166,11 @@ async def test_idempotent_reprocessing(self): backend = TrackingBackend("vector") registry = IndexRegistry({"vector": backend}) - listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Act - Process twice - await listener.handle_batch(events) - await listener.handle_batch(events) # Reprocess same events + await handler.handle_batch(events) + await handler.handle_batch(events) # Reprocess same events # Assert - Records should be in backend (upsert semantics handle duplicates) # The backend receives all records from both batches @@ -198,10 +198,10 @@ async def ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: backend = StatelessBackend() registry = IndexRegistry({"vector": backend}) - listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Act - await listener.handle_batch(events) + await handler.handle_batch(events) # Assert - All records in single batch call, immediately persisted assert len(backend.batch_calls) == 1 diff --git a/server/tests/integration/test_event_batch_processing.py b/server/tests/integration/test_event_batch_processing.py index bdb63f1..304d6f9 100644 --- a/server/tests/integration/test_event_batch_processing.py +++ b/server/tests/integration/test_event_batch_processing.py @@ -10,8 +10,8 @@ import pytest from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.listener.fanout_listener import FanOutToIndexBackends -from osa.domain.index.listener.index_batch_listener import IndexRecordBatch +from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends +from osa.domain.index.handler.vector_index_handler import VectorIndexHandler from osa.domain.index.model.registry import IndexRegistry from osa.domain.record.event.record_published import RecordPublished from osa.domain.shared.event import EventId @@ -42,7 +42,7 @@ class FakeOutbox: def __init__(self): self.events: list[Any] = [] - async def append(self, event: Any) -> None: + async def append(self, event: Any, routing_key: str | None = None) -> None: self.events.append(event) @@ -92,21 +92,14 @@ async def test_fanout_creates_index_records_per_backend(self): assert backend_names == {"vector", "keyword"} @pytest.mark.asyncio - async def test_batch_listener_groups_by_backend(self): - """IndexRecordBatch should group events by backend and batch ingest.""" + async def test_handler_processes_batch(self): + """VectorIndexHandler should process events in batches.""" # Arrange vector_backend = FakeBackend("vector") - keyword_backend = FakeBackend("keyword") - registry = IndexRegistry( - { - "vector": vector_backend, - "keyword": keyword_backend, - } - ) - - batch_listener = IndexRecordBatch(indexes=registry) + registry = IndexRegistry({"vector": vector_backend}) + handler = VectorIndexHandler(indexes=registry) - # Create mixed events for both backends + # Create events for vector backend events = [ IndexRecord( id=EventId(uuid4()), @@ -119,33 +112,16 @@ async def test_batch_listener_groups_by_backend(self): metadata={"id": i}, ) for i in range(5) - ] + [ - IndexRecord( - id=EventId(uuid4()), - backend_name="keyword", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(3) ] # Act - await batch_listener.handle_batch(events) + await handler.handle_batch(events) # Assert - vector backend received 5 records in single batch call assert len(vector_backend.batch_calls) == 1 assert len(vector_backend.batch_calls[0]) == 5 assert vector_backend.total_records == 5 - # Assert - keyword backend received 3 records in single batch call - assert len(keyword_backend.batch_calls) == 1 - assert len(keyword_backend.batch_calls[0]) == 3 - assert keyword_backend.total_records == 3 - @pytest.mark.asyncio async def test_end_to_end_fanout_to_batch(self): """End-to-end: RecordPublished -> FanOut -> BatchProcess.""" @@ -155,7 +131,7 @@ async def test_end_to_end_fanout_to_batch(self): outbox = FakeOutbox() fanout = FanOutToIndexBackends(indexes=registry, outbox=outbox) - batch_listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Create multiple RecordPublished events num_records = 10 @@ -171,7 +147,7 @@ async def test_end_to_end_fanout_to_batch(self): assert all(e.backend_name == "vector" for e in outbox.events) # Act - Step 2: Batch process all IndexRecord events - await batch_listener.handle_batch(outbox.events) + await handler.handle_batch(outbox.events) # Assert - All records indexed in single batch call assert len(vector_backend.batch_calls) == 1 @@ -183,7 +159,7 @@ async def test_batch_efficiency_large_batch(self): # Arrange backend = FakeBackend("vector") registry = IndexRegistry({"vector": backend}) - batch_listener = IndexRecordBatch(indexes=registry) + handler = VectorIndexHandler(indexes=registry) # Create 1000+ events num_events = 1000 @@ -202,7 +178,7 @@ async def test_batch_efficiency_large_batch(self): ] # Act - await batch_listener.handle_batch(events) + await handler.handle_batch(events) # Assert - All records in single batch call (not 1000 individual calls) assert len(backend.batch_calls) == 1, "Should use single batch call, not individual calls" @@ -210,41 +186,20 @@ async def test_batch_efficiency_large_batch(self): assert len(backend.batch_calls[0]) == num_events @pytest.mark.asyncio - async def test_failure_isolation_between_backends(self): - """Verify failures are isolated per-backend at the event level.""" - # Arrange - working_backend = FakeBackend("working") - failing_backend = MagicMock() - failing_backend.name = "failing" - failing_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend failure")) - - registry = IndexRegistry( - { - "working": working_backend, - "failing": failing_backend, - } - ) - batch_listener = IndexRecordBatch(indexes=registry) - - # Create events for both backends - working_events = [ - IndexRecord( - id=EventId(uuid4()), - backend_name="working", - record_srn=RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ), - metadata={"id": i}, - ) - for i in range(3) - ] + async def test_failure_propagates_from_handler(self): + """Verify failures propagate from handler (Worker handles retry).""" + # Arrange - VectorIndexHandler looks up "vector" backend specifically + vector_backend = MagicMock() + vector_backend.name = "vector" + vector_backend.ingest_batch = AsyncMock(side_effect=Exception("Backend failure")) + + registry = IndexRegistry({"vector": vector_backend}) + handler = VectorIndexHandler(indexes=registry) - failing_events = [ + events = [ IndexRecord( id=EventId(uuid4()), - backend_name="failing", + backend_name="vector", record_srn=RecordSRN( domain=Domain("test.example.com"), id=LocalId(str(uuid4())), @@ -255,14 +210,6 @@ async def test_failure_isolation_between_backends(self): for i in range(2) ] - # Note: In the new design, events for different backends are in separate - # outbox entries and processed independently by the worker. This test - # verifies that at the listener level, each backend batch is independent. - - # Process working backend events - await batch_listener.handle_batch(working_events) - assert working_backend.total_records == 3 - - # Process failing backend events - should raise + # Process events - should raise (Worker handles retry) with pytest.raises(Exception, match="Backend failure"): - await batch_listener.handle_batch(failing_events) + await handler.handle_batch(events) diff --git a/server/tests/unit/domain/index/test_fanout_listener.py b/server/tests/unit/domain/index/test_fanout_listener.py index 3308e56..5033680 100644 --- a/server/tests/unit/domain/index/test_fanout_listener.py +++ b/server/tests/unit/domain/index/test_fanout_listener.py @@ -1,4 +1,4 @@ -"""Unit tests for FanOutToIndexBackends listener.""" +"""Unit tests for FanOutToIndexBackends handler.""" from typing import Any from unittest.mock import AsyncMock @@ -7,7 +7,7 @@ import pytest from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.listener.fanout_listener import FanOutToIndexBackends +from osa.domain.index.handler.fanout_to_index_backends import FanOutToIndexBackends from osa.domain.index.model.registry import IndexRegistry from osa.domain.record.event.record_published import RecordPublished from osa.domain.shared.event import EventId @@ -32,7 +32,7 @@ def __init__(self): self.events: list[Any] = [] self.append = AsyncMock(side_effect=self._append) - async def _append(self, event: Any) -> None: + async def _append(self, event: Any, routing_key: str | None = None) -> None: self.events.append(event) @@ -66,7 +66,7 @@ def sample_metadata() -> dict: class TestFanOutToIndexBackends: - """Tests for FanOutToIndexBackends listener.""" + """Tests for FanOutToIndexBackends handler.""" @pytest.mark.asyncio async def test_creates_index_record_per_backend( @@ -75,14 +75,14 @@ async def test_creates_index_record_per_backend( sample_deposition_srn: DepositionSRN, sample_metadata: dict, ): - """Listener should create one IndexRecord event per registered backend.""" + """Handler should create one IndexRecord event per registered backend.""" # Arrange backend1 = FakeBackend("vector") backend2 = FakeBackend("keyword") registry = IndexRegistry({"vector": backend1, "keyword": backend2}) outbox = FakeOutbox() - listener = FanOutToIndexBackends(indexes=registry, outbox=outbox) + handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) event = RecordPublished( id=EventId(uuid4()), @@ -92,7 +92,7 @@ async def test_creates_index_record_per_backend( ) # Act - await listener.handle(event) + await handler.handle(event) # Assert assert len(outbox.events) == 2 @@ -121,7 +121,7 @@ async def test_creates_unique_event_ids( ) outbox = FakeOutbox() - listener = FanOutToIndexBackends(indexes=registry, outbox=outbox) + handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) event = RecordPublished( id=EventId(uuid4()), @@ -131,7 +131,7 @@ async def test_creates_unique_event_ids( ) # Act - await listener.handle(event) + await handler.handle(event) # Assert event_ids = [e.id for e in outbox.events] @@ -149,7 +149,7 @@ async def test_empty_registry_creates_no_events( registry = IndexRegistry({}) outbox = FakeOutbox() - listener = FanOutToIndexBackends(indexes=registry, outbox=outbox) + handler = FanOutToIndexBackends(indexes=registry, outbox=outbox) event = RecordPublished( id=EventId(uuid4()), @@ -159,7 +159,7 @@ async def test_empty_registry_creates_no_events( ) # Act - await listener.handle(event) + await handler.handle(event) # Assert assert len(outbox.events) == 0 diff --git a/server/tests/unit/domain/index/test_index_batch_listener.py b/server/tests/unit/domain/index/test_index_batch_listener.py deleted file mode 100644 index ed7a55a..0000000 --- a/server/tests/unit/domain/index/test_index_batch_listener.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Unit tests for IndexRecordBatch listener.""" - -from typing import Any -from unittest.mock import AsyncMock -from uuid import uuid4 - -import pytest - -from osa.domain.index.event.index_record import IndexRecord -from osa.domain.index.listener.index_batch_listener import IndexRecordBatch -from osa.domain.index.model.registry import IndexRegistry -from osa.domain.shared.error import SkippedEventsError -from osa.domain.shared.event import EventId -from osa.domain.shared.model.srn import Domain, LocalId, RecordSRN, RecordVersion - - -class FakeBackend: - """Fake storage backend for testing.""" - - def __init__(self, name: str): - self._name = name - self.batches: list[list[tuple[str, dict]]] = [] - self.ingest_batch = AsyncMock(side_effect=self._ingest_batch) - - @property - def name(self) -> str: - return self._name - - async def _ingest_batch(self, records: list[tuple[str, dict[str, Any]]]) -> None: - self.batches.append(list(records)) - - -class FailingBackend: - """Backend that always fails for testing error handling.""" - - def __init__(self, name: str): - self._name = name - self.ingest_batch = AsyncMock(side_effect=Exception("Backend failure")) - - @property - def name(self) -> str: - return self._name - - -def make_index_record( - backend_name: str, - srn: RecordSRN | None = None, - metadata: dict | None = None, -) -> IndexRecord: - """Create an IndexRecord for testing.""" - if srn is None: - srn = RecordSRN( - domain=Domain("test.example.com"), - id=LocalId(str(uuid4())), - version=RecordVersion(1), - ) - return IndexRecord( - id=EventId(uuid4()), - backend_name=backend_name, - record_srn=srn, - metadata=metadata or {"title": "Test"}, - ) - - -class TestIndexRecordBatch: - """Tests for IndexRecordBatch listener.""" - - @pytest.mark.asyncio - async def test_groups_events_by_backend(self): - """Listener should group events by backend and call ingest_batch per backend.""" - # Arrange - vector_backend = FakeBackend("vector") - keyword_backend = FakeBackend("keyword") - registry = IndexRegistry({"vector": vector_backend, "keyword": keyword_backend}) - - listener = IndexRecordBatch(indexes=registry) - - events = [ - make_index_record("vector", metadata={"id": 1}), - make_index_record("keyword", metadata={"id": 2}), - make_index_record("vector", metadata={"id": 3}), - ] - - # Act - await listener.handle_batch(events) - - # Assert - vector backend received 2 records in one batch - assert len(vector_backend.batches) == 1 - assert len(vector_backend.batches[0]) == 2 - - # Assert - keyword backend received 1 record in one batch - assert len(keyword_backend.batches) == 1 - assert len(keyword_backend.batches[0]) == 1 - - @pytest.mark.asyncio - async def test_passes_correct_srn_and_metadata(self): - """Listener should pass correct SRN and metadata to backend.""" - # Arrange - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - - listener = IndexRecordBatch(indexes=registry) - - srn = RecordSRN( - domain=Domain("test.example.com"), - id=LocalId("test-record-id"), - version=RecordVersion(1), - ) - metadata = {"title": "Test Record", "organism": "human"} - - events = [make_index_record("vector", srn=srn, metadata=metadata)] - - # Act - await listener.handle_batch(events) - - # Assert - assert len(backend.batches) == 1 - assert len(backend.batches[0]) == 1 - record_srn, record_meta = backend.batches[0][0] - assert record_srn == str(srn) - assert record_meta == metadata - - @pytest.mark.asyncio - async def test_handles_empty_batch(self): - """Listener should handle empty batch without error.""" - # Arrange - backend = FakeBackend("vector") - registry = IndexRegistry({"vector": backend}) - - listener = IndexRecordBatch(indexes=registry) - - # Act - await listener.handle_batch([]) - - # Assert - no calls made - assert len(backend.batches) == 0 - - @pytest.mark.asyncio - async def test_raises_skipped_for_unknown_backend(self): - """Unknown backend should raise SkippedEventsError.""" - # Arrange - vector_backend = FakeBackend("vector") - registry = IndexRegistry({"vector": vector_backend}) - - listener = IndexRecordBatch(indexes=registry) - - unknown_event = make_index_record("unknown", metadata={"id": 1}) - events = [unknown_event] - - # Act & Assert - should raise SkippedEventsError - with pytest.raises(SkippedEventsError) as exc_info: - await listener.handle_batch(events) - - assert unknown_event.id in exc_info.value.event_ids - assert "unknown" in exc_info.value.reason - - @pytest.mark.asyncio - async def test_raises_on_backend_failure(self): - """Listener should propagate backend failures for retry.""" - # Arrange - failing_backend = FailingBackend("vector") - registry = IndexRegistry({"vector": failing_backend}) - - listener = IndexRecordBatch(indexes=registry) - - events = [make_index_record("vector")] - - # Act & Assert - with pytest.raises(Exception, match="Backend failure"): - await listener.handle_batch(events) - - -class TestIndexRecordBatchFailureVisibility: - """Tests for failure visibility in IndexRecordBatch (US4).""" - - @pytest.mark.asyncio - async def test_failure_includes_backend_name_in_error(self): - """Failures should include backend name for visibility.""" - # Arrange - failing_backend = FailingBackend("vector") - registry = IndexRegistry({"vector": failing_backend}) - - listener = IndexRecordBatch(indexes=registry) - - events = [make_index_record("vector")] - - # Act & Assert - # The backend failure should propagate with context - with pytest.raises(Exception): - await listener.handle_batch(events) - - # Backend's ingest_batch was called - failing_backend.ingest_batch.assert_called_once() diff --git a/server/tests/unit/domain/shared/test_claim_result.py b/server/tests/unit/domain/shared/test_claim_result.py new file mode 100644 index 0000000..89ce770 --- /dev/null +++ b/server/tests/unit/domain/shared/test_claim_result.py @@ -0,0 +1,92 @@ +"""Unit tests for ClaimResult value object. + +Tests result of a claim operation. +""" + +from datetime import UTC, datetime +from uuid import uuid4 + +import pytest + +from osa.domain.shared.event import ClaimResult, Event, EventId + + +class DummyEvent(Event): + """Test event for claim result tests.""" + + id: EventId + data: str + + +class TestClaimResult: + """Tests for ClaimResult value object.""" + + def test_create_with_events(self): + """ClaimResult should hold claimed events and timestamp.""" + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + event2 = DummyEvent(id=EventId(uuid4()), data="event2") + now = datetime.now(UTC) + + result = ClaimResult(events=[event1, event2], claimed_at=now) + + assert len(result.events) == 2 + assert result.events[0] is event1 + assert result.events[1] is event2 + assert result.claimed_at == now + + def test_create_empty(self): + """ClaimResult can be created with empty events list.""" + now = datetime.now(UTC) + result = ClaimResult(events=[], claimed_at=now) + + assert result.events == [] + assert result.claimed_at == now + + def test_immutable(self): + """ClaimResult should be immutable (frozen dataclass).""" + now = datetime.now(UTC) + result = ClaimResult(events=[], claimed_at=now) + + with pytest.raises(AttributeError): + result.events = [] # type: ignore[misc] + + with pytest.raises(AttributeError): + result.claimed_at = datetime.now(UTC) # type: ignore[misc] + + def test_events_required(self): + """ClaimResult events is required.""" + with pytest.raises(TypeError): + ClaimResult(claimed_at=datetime.now(UTC)) # type: ignore[call-arg] + + def test_claimed_at_required(self): + """ClaimResult claimed_at is required.""" + with pytest.raises(TypeError): + ClaimResult(events=[]) # type: ignore[call-arg] + + def test_bool_true_when_has_events(self): + """ClaimResult should be truthy when events are present.""" + event = DummyEvent(id=EventId(uuid4()), data="test") + result = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + assert bool(result) is True + + def test_bool_false_when_empty(self): + """ClaimResult should be falsy when events are empty.""" + result = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + assert bool(result) is False + + def test_len(self): + """ClaimResult should support len() returning number of events.""" + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + event2 = DummyEvent(id=EventId(uuid4()), data="event2") + + result = ClaimResult(events=[event1, event2], claimed_at=datetime.now(UTC)) + assert len(result) == 2 + + def test_iter(self): + """ClaimResult should be iterable over events.""" + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + event2 = DummyEvent(id=EventId(uuid4()), data="event2") + + result = ClaimResult(events=[event1, event2], claimed_at=datetime.now(UTC)) + events = list(result) + assert events == [event1, event2] diff --git a/server/tests/unit/domain/shared/test_event.py b/server/tests/unit/domain/shared/test_event.py index a753296..906bf2b 100644 --- a/server/tests/unit/domain/shared/test_event.py +++ b/server/tests/unit/domain/shared/test_event.py @@ -1,9 +1,9 @@ """Unit tests for domain event infrastructure. -Regression tests for event listener metaclass behavior. +Tests for event handler metaclass behavior. """ -from osa.domain.shared.event import BatchEventListener, Event, EventId, EventListener +from osa.domain.shared.event import Event, EventHandler, EventId class DummyEvent(Event): @@ -13,55 +13,57 @@ class DummyEvent(Event): data: str -class TestEventListenerMetaclass: - """Tests for EventListener metaclass __event_type__ extraction.""" +class TestEventHandlerMetaclass: + """Tests for EventHandler metaclass __event_type__ extraction.""" - def test_event_listener_has_event_type_set(self): - """EventListener subclasses should have __event_type__ extracted from generic param.""" + def test_event_handler_has_event_type_set(self): + """EventHandler subclasses should have __event_type__ extracted from generic param.""" - class MyListener(EventListener[DummyEvent]): + class MyHandler(EventHandler[DummyEvent]): async def handle(self, event: DummyEvent) -> None: pass - assert hasattr(MyListener, "__event_type__") - assert MyListener.__event_type__ is DummyEvent + assert hasattr(MyHandler, "__event_type__") + assert MyHandler.__event_type__ is DummyEvent - def test_batch_event_listener_has_event_type_set(self): - """BatchEventListener subclasses should have __event_type__ extracted from generic param. + def test_event_handler_is_dataclass(self): + """EventHandler subclasses should be automatically converted to dataclasses.""" - Regression test: Previously _extract_event_type only checked for 'EventListener' - in the origin name, causing BatchEventListener subclasses to not get __event_type__. - """ + class HandlerWithDeps(EventHandler[DummyEvent]): + some_dep: str - class MyBatchListener(BatchEventListener[DummyEvent]): - async def handle_batch(self, events: list[DummyEvent]) -> None: + async def handle(self, event: DummyEvent) -> None: pass - assert hasattr(MyBatchListener, "__event_type__") - assert MyBatchListener.__event_type__ is DummyEvent + # Dataclass should allow instantiation with keyword args + handler = HandlerWithDeps(some_dep="test") + assert handler.some_dep == "test" - def test_event_listener_is_dataclass(self): - """EventListener subclasses should be automatically converted to dataclasses.""" - - class ListenerWithDeps(EventListener[DummyEvent]): - some_dep: str + def test_event_handler_default_classvars(self): + """EventHandler should have sensible default classvars.""" + class MyHandler(EventHandler[DummyEvent]): async def handle(self, event: DummyEvent) -> None: pass - # Dataclass should allow instantiation with keyword args - listener = ListenerWithDeps(some_dep="test") - assert listener.some_dep == "test" + assert MyHandler.__routing_key__ is None + assert MyHandler.__batch_size__ == 1 + assert MyHandler.__batch_timeout__ == 5.0 + assert MyHandler.__poll_interval__ == 0.5 + assert MyHandler.__max_retries__ == 3 + assert MyHandler.__claim_timeout__ == 300.0 - def test_batch_event_listener_is_dataclass(self): - """BatchEventListener subclasses should be automatically converted to dataclasses.""" + def test_event_handler_custom_classvars(self): + """EventHandler subclasses can override classvars.""" - class BatchListenerWithDeps(BatchEventListener[DummyEvent]): - some_dep: str + class BatchHandler(EventHandler[DummyEvent]): + __routing_key__ = "my-queue" + __batch_size__ = 100 + __batch_timeout__ = 10.0 async def handle_batch(self, events: list[DummyEvent]) -> None: pass - # Dataclass should allow instantiation with keyword args - listener = BatchListenerWithDeps(some_dep="test") - assert listener.some_dep == "test" + assert BatchHandler.__routing_key__ == "my-queue" + assert BatchHandler.__batch_size__ == 100 + assert BatchHandler.__batch_timeout__ == 10.0 diff --git a/server/tests/unit/domain/shared/test_outbox_claim.py b/server/tests/unit/domain/shared/test_outbox_claim.py new file mode 100644 index 0000000..0aa1e5c --- /dev/null +++ b/server/tests/unit/domain/shared/test_outbox_claim.py @@ -0,0 +1,165 @@ +"""Unit tests for Outbox claim, mark_delivered, and mark_failed operations. + +Tests for User Story 1: Reliable Event Processing. +""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from osa.domain.shared.event import ClaimResult, Event, EventId +from osa.domain.shared.outbox import Outbox + + +class DummyEvent(Event): + """Test event for outbox tests.""" + + id: EventId + data: str + + +class TestOutboxClaim: + """Tests for Outbox.claim() returning claimed events.""" + + @pytest.fixture + def mock_repo(self) -> AsyncMock: + """Create a mock EventRepository.""" + return AsyncMock() + + @pytest.fixture + def outbox(self, mock_repo: AsyncMock) -> Outbox: + """Create an Outbox with mocked repository.""" + outbox = Outbox.__new__(Outbox) + outbox._repo = mock_repo + return outbox + + async def test_claim_returns_claimed_events(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.claim() should return ClaimResult from repository.""" + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + event2 = DummyEvent(id=EventId(uuid4()), data="event2") + now = datetime.now(UTC) + + mock_repo.claim.return_value = ClaimResult(events=[event1, event2], claimed_at=now) + + result = await outbox.claim( + event_types=[DummyEvent], + limit=10, + routing_key=None, + ) + + assert isinstance(result, ClaimResult) + assert len(result) == 2 + assert result.events[0] is event1 + assert result.events[1] is event2 + mock_repo.claim.assert_called_once_with( + event_types=["DummyEvent"], + limit=10, + routing_key=None, + ) + + async def test_claim_with_routing_key(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.claim() should pass routing_key to repository.""" + event = DummyEvent(id=EventId(uuid4()), data="routed") + now = datetime.now(UTC) + + mock_repo.claim.return_value = ClaimResult(events=[event], claimed_at=now) + + result = await outbox.claim( + event_types=[DummyEvent], + limit=5, + routing_key="vector", + ) + + assert len(result) == 1 + mock_repo.claim.assert_called_once_with( + event_types=["DummyEvent"], + limit=5, + routing_key="vector", + ) + + async def test_claim_returns_empty_when_no_events(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.claim() should return empty ClaimResult when no events available.""" + now = datetime.now(UTC) + mock_repo.claim.return_value = ClaimResult(events=[], claimed_at=now) + + result = await outbox.claim( + event_types=[DummyEvent], + limit=10, + routing_key=None, + ) + + assert len(result) == 0 + assert bool(result) is False + + +class TestOutboxMarkDelivered: + """Tests for Outbox.mark_delivered() updating single event.""" + + @pytest.fixture + def mock_repo(self) -> AsyncMock: + """Create a mock EventRepository.""" + return AsyncMock() + + @pytest.fixture + def outbox(self, mock_repo: AsyncMock) -> Outbox: + """Create an Outbox with mocked repository.""" + outbox = Outbox.__new__(Outbox) + outbox._repo = mock_repo + return outbox + + async def test_mark_delivered_updates_status(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.mark_delivered() should update event status to delivered.""" + event_id = EventId(uuid4()) + + await outbox.mark_delivered(event_id) + + mock_repo.update_status.assert_called_once_with(event_id, status="delivered") + + +class TestOutboxMarkFailed: + """Tests for Outbox.mark_failed() with retry logic.""" + + @pytest.fixture + def mock_repo(self) -> AsyncMock: + """Create a mock EventRepository.""" + return AsyncMock() + + @pytest.fixture + def outbox(self, mock_repo: AsyncMock) -> Outbox: + """Create an Outbox with mocked repository.""" + outbox = Outbox.__new__(Outbox) + outbox._repo = mock_repo + return outbox + + async def test_mark_failed_increments_retry_count(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.mark_failed() should increment retry_count.""" + event_id = EventId(uuid4()) + + await outbox.mark_failed(event_id, "Connection error") + + mock_repo.update_status.assert_called_once() + call_args = mock_repo.update_status.call_args + assert call_args[0][0] == event_id + assert call_args[1]["status"] == "failed" + assert call_args[1]["error"] == "Connection error" + + async def test_mark_failed_with_max_retries_sets_failed_status( + self, outbox: Outbox, mock_repo: AsyncMock + ): + """Outbox.mark_failed() should set status=failed after max_retries exceeded. + + Note: The retry counting is handled by the repository implementation. + The Outbox service just forwards the mark_failed call. The repository + checks retry_count and either resets to pending (for retry) or marks failed. + """ + event_id = EventId(uuid4()) + error = "Persistent failure" + + # The mark_failed_with_retry method handles retry logic + await outbox.mark_failed_with_retry(event_id, error, max_retries=3) + + mock_repo.mark_failed_with_retry.assert_called_once_with( + event_id, error=error, max_retries=3 + ) diff --git a/server/tests/unit/domain/shared/test_outbox_routing.py b/server/tests/unit/domain/shared/test_outbox_routing.py new file mode 100644 index 0000000..f047467 --- /dev/null +++ b/server/tests/unit/domain/shared/test_outbox_routing.py @@ -0,0 +1,92 @@ +"""Unit tests for Outbox routing key filtering. + +Tests for User Story 4: Event Routing. +""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock +from uuid import uuid4 + +import pytest + +from osa.domain.shared.event import ClaimResult, Event, EventId +from osa.domain.shared.outbox import Outbox + + +class DummyEvent(Event): + """Test event for routing tests.""" + + id: EventId + data: str + + +class TestOutboxRoutingKey: + """Tests for Outbox claim() with routing_key filtering.""" + + @pytest.fixture + def mock_repo(self) -> AsyncMock: + """Create a mock EventRepository.""" + return AsyncMock() + + @pytest.fixture + def outbox(self, mock_repo: AsyncMock) -> Outbox: + """Create an Outbox with mocked repository.""" + outbox = Outbox.__new__(Outbox) + outbox._repo = mock_repo + return outbox + + async def test_claim_with_routing_key_filters_events( + self, outbox: Outbox, mock_repo: AsyncMock + ): + """Outbox.claim() with routing_key should filter to matching events.""" + event = DummyEvent(id=EventId(uuid4()), data="routed") + mock_repo.claim.return_value = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + + result = await outbox.claim( + event_types=[DummyEvent], + limit=10, + routing_key="vector", + ) + + mock_repo.claim.assert_called_once_with( + event_types=["DummyEvent"], + limit=10, + routing_key="vector", + ) + assert len(result) == 1 + + async def test_claim_with_routing_key_none_matches_unrouted( + self, outbox: Outbox, mock_repo: AsyncMock + ): + """Outbox.claim() with routing_key=None should match unrouted events.""" + event = DummyEvent(id=EventId(uuid4()), data="unrouted") + mock_repo.claim.return_value = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + + result = await outbox.claim( + event_types=[DummyEvent], + limit=10, + routing_key=None, + ) + + mock_repo.claim.assert_called_once_with( + event_types=["DummyEvent"], + limit=10, + routing_key=None, + ) + assert len(result) == 1 + + async def test_append_with_routing_key(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.append() should pass routing_key to repository.""" + event = DummyEvent(id=EventId(uuid4()), data="routed-event") + + await outbox.append(event, routing_key="keyword") + + mock_repo.save.assert_called_once_with(event, status="pending", routing_key="keyword") + + async def test_append_without_routing_key(self, outbox: Outbox, mock_repo: AsyncMock): + """Outbox.append() without routing_key should pass None.""" + event = DummyEvent(id=EventId(uuid4()), data="unrouted-event") + + await outbox.append(event) + + mock_repo.save.assert_called_once_with(event, status="pending", routing_key=None) diff --git a/server/tests/unit/domain/shared/test_worker_config.py b/server/tests/unit/domain/shared/test_worker_config.py new file mode 100644 index 0000000..156ff3b --- /dev/null +++ b/server/tests/unit/domain/shared/test_worker_config.py @@ -0,0 +1,105 @@ +"""Unit tests for WorkerConfig value object. + +Tests validation rules for worker configuration. +""" + +import pytest +from pydantic import ValidationError + +from osa.domain.shared.event import Event, EventId, WorkerConfig + + +class DummyEvent(Event): + """Test event for worker config tests.""" + + id: EventId + data: str + + +class TestWorkerConfigValidation: + """Tests for WorkerConfig validation rules.""" + + def test_valid_config(self): + """WorkerConfig with valid values should be created.""" + config = WorkerConfig( + name="test-worker", + event_types=(DummyEvent,), + routing_key="test-key", + batch_size=10, + batch_timeout=5.0, + poll_interval=0.5, + max_retries=3, + claim_timeout=300.0, + ) + assert config.name == "test-worker" + assert config.event_types == (DummyEvent,) + assert config.routing_key == "test-key" + assert config.batch_size == 10 + assert config.batch_timeout == 5.0 + assert config.poll_interval == 0.5 + assert config.max_retries == 3 + assert config.claim_timeout == 300.0 + + def test_defaults(self): + """WorkerConfig should have sensible defaults.""" + config = WorkerConfig( + name="test-worker", + event_types=(DummyEvent,), + ) + assert config.routing_key is None + assert config.batch_size == 1 + assert config.batch_timeout == 5.0 + assert config.poll_interval == 0.5 + assert config.max_retries == 3 + assert config.claim_timeout == 300.0 + + def test_name_required(self): + """WorkerConfig name is required.""" + with pytest.raises(ValidationError): + WorkerConfig(event_types=(DummyEvent,)) # type: ignore[call-arg] + + def test_event_types_required(self): + """WorkerConfig event_types is required.""" + with pytest.raises(ValidationError): + WorkerConfig(name="test-worker") # type: ignore[call-arg] + + def test_event_types_not_empty(self): + """WorkerConfig event_types must not be empty.""" + with pytest.raises(ValidationError, match="event_types must not be empty"): + WorkerConfig(name="test-worker", event_types=()) + + def test_batch_size_must_be_positive(self): + """WorkerConfig batch_size must be >= 1.""" + with pytest.raises(ValidationError, match="greater than or equal to 1"): + WorkerConfig(name="test-worker", event_types=(DummyEvent,), batch_size=0) + + def test_batch_timeout_must_be_positive(self): + """WorkerConfig batch_timeout must be > 0.""" + with pytest.raises(ValidationError, match="greater than 0"): + WorkerConfig(name="test-worker", event_types=(DummyEvent,), batch_timeout=0) + + def test_poll_interval_must_be_positive(self): + """WorkerConfig poll_interval must be > 0.""" + with pytest.raises(ValidationError, match="greater than 0"): + WorkerConfig(name="test-worker", event_types=(DummyEvent,), poll_interval=0) + + def test_max_retries_must_be_non_negative(self): + """WorkerConfig max_retries must be >= 0.""" + with pytest.raises(ValidationError, match="greater than or equal to 0"): + WorkerConfig(name="test-worker", event_types=(DummyEvent,), max_retries=-1) + + def test_claim_timeout_must_be_greater_than_batch_timeout(self): + """WorkerConfig claim_timeout must be > batch_timeout.""" + with pytest.raises(ValidationError, match="claim_timeout must be > batch_timeout"): + WorkerConfig( + name="test-worker", + event_types=(DummyEvent,), + batch_timeout=10.0, + claim_timeout=5.0, + ) + + def test_immutable(self): + """WorkerConfig should be immutable (frozen Pydantic model).""" + config = WorkerConfig(name="test-worker", event_types=(DummyEvent,)) + with pytest.raises(ValidationError, match="frozen"): + config.name = "new-name" # type: ignore[misc] diff --git a/server/tests/unit/domain/shared/test_worker_state.py b/server/tests/unit/domain/shared/test_worker_state.py new file mode 100644 index 0000000..1ba00ed --- /dev/null +++ b/server/tests/unit/domain/shared/test_worker_state.py @@ -0,0 +1,113 @@ +"""Unit tests for WorkerState and WorkerStatus. + +Tests runtime state tracking for workers. +""" + +from datetime import UTC, datetime + +import pytest + +from osa.domain.shared.event import ( + Event, + EventId, + WorkerConfig, + WorkerState, + WorkerStatus, +) + + +class DummyEvent(Event): + """Test event for worker state tests.""" + + id: EventId + data: str + + +class TestWorkerStatus: + """Tests for WorkerStatus enum.""" + + def test_status_values(self): + """WorkerStatus should have expected values.""" + assert WorkerStatus.IDLE.value == "idle" + assert WorkerStatus.CLAIMING.value == "claiming" + assert WorkerStatus.PROCESSING.value == "processing" + assert WorkerStatus.STOPPING.value == "stopping" + + def test_all_statuses(self): + """WorkerStatus should have exactly 4 values.""" + assert len(WorkerStatus) == 4 + + +class TestWorkerState: + """Tests for WorkerState runtime entity.""" + + @pytest.fixture + def config(self) -> WorkerConfig: + """Fixture providing a valid WorkerConfig.""" + return WorkerConfig( + name="test-worker", + event_types=(DummyEvent,), + ) + + def test_initial_state(self, config: WorkerConfig): + """WorkerState should initialize with idle status and zero counts.""" + state = WorkerState(config=config) + assert state.config is config + assert state.status == WorkerStatus.IDLE + assert state.current_batch == [] + assert state.last_claim_at is None + assert state.processed_count == 0 + assert state.failed_count == 0 + assert state.error is None + + def test_mutable_status(self, config: WorkerConfig): + """WorkerState status should be mutable.""" + state = WorkerState(config=config) + state.status = WorkerStatus.CLAIMING + assert state.status == WorkerStatus.CLAIMING + + def test_mutable_current_batch(self, config: WorkerConfig): + """WorkerState current_batch should be mutable.""" + state = WorkerState(config=config) + # In real usage, events would be appended here + state.current_batch = ["event1", "event2"] # type: ignore[list-item] + assert len(state.current_batch) == 2 + + def test_mutable_counters(self, config: WorkerConfig): + """WorkerState counters should be mutable.""" + state = WorkerState(config=config) + state.processed_count = 10 + state.failed_count = 2 + assert state.processed_count == 10 + assert state.failed_count == 2 + + def test_mutable_last_claim_at(self, config: WorkerConfig): + """WorkerState last_claim_at should be mutable.""" + state = WorkerState(config=config) + now = datetime.now(UTC) + state.last_claim_at = now + assert state.last_claim_at == now + + def test_mutable_error(self, config: WorkerConfig): + """WorkerState error should be mutable.""" + state = WorkerState(config=config) + error = ValueError("test error") + state.error = error + assert state.error is error + + def test_state_with_custom_initial_values(self, config: WorkerConfig): + """WorkerState should accept custom initial values.""" + now = datetime.now(UTC) + state = WorkerState( + config=config, + status=WorkerStatus.PROCESSING, + current_batch=[], + last_claim_at=now, + processed_count=100, + failed_count=5, + error=None, + ) + assert state.status == WorkerStatus.PROCESSING + assert state.last_claim_at == now + assert state.processed_count == 100 + assert state.failed_count == 5 diff --git a/server/tests/unit/infrastructure/event/__init__.py b/server/tests/unit/infrastructure/event/__init__.py new file mode 100644 index 0000000..bb8b32b --- /dev/null +++ b/server/tests/unit/infrastructure/event/__init__.py @@ -0,0 +1 @@ +"""Unit tests for event infrastructure.""" diff --git a/server/tests/unit/infrastructure/event/test_worker.py b/server/tests/unit/infrastructure/event/test_worker.py new file mode 100644 index 0000000..394ce4e --- /dev/null +++ b/server/tests/unit/infrastructure/event/test_worker.py @@ -0,0 +1,349 @@ +"""Unit tests for Worker poll loop lifecycle. + +Tests for pull-based event processing with EventHandler pattern. +""" + +import asyncio +from datetime import UTC, datetime +from typing import ClassVar +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.shared.event import ( + ClaimResult, + Event, + EventHandler, + EventId, + WorkerStatus, +) +from osa.domain.shared.outbox import Outbox + + +class DummyEvent(Event): + """Test event for worker tests.""" + + id: EventId + data: str + + +class DummyHandler(EventHandler[DummyEvent]): + """Test handler that tracks handle calls.""" + + __batch_size__: ClassVar[int] = 10 + __poll_interval__: ClassVar[float] = 0.1 + + processed_events: list[DummyEvent] + + async def handle(self, event: DummyEvent) -> None: + self.processed_events.append(event) + + +class FailingHandler(EventHandler[DummyEvent]): + """Handler that always raises an error.""" + + async def handle(self, event: DummyEvent) -> None: + raise RuntimeError("Processing failed") + + +def make_mock_container( + outbox: AsyncMock, + session: AsyncMock | None = None, + handler: EventHandler | None = None, +): + """Create a mock DI container that provides scoped dependencies. + + Creates a container mock that returns Outbox, AsyncSession, and handler + when called as an async context manager with scope parameter. + """ + if session is None: + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + session.rollback = AsyncMock() + + async def get_dependency(cls): + """Return the appropriate dependency based on the requested class.""" + if cls == Outbox: + return outbox + if cls == AsyncSession: + return session + if handler is not None and (cls is type(handler) or issubclass(cls, EventHandler)): + return handler + return session + + # Create scope that returns dependencies + scope = AsyncMock() + scope.get = AsyncMock(side_effect=get_dependency) + + # Create async context manager + context = MagicMock() + context.__aenter__ = AsyncMock(return_value=scope) + context.__aexit__ = AsyncMock(return_value=None) + + # Container callable returns the context manager + container = MagicMock() + container.return_value = context + + return container + + +class TestWorkerPollLoop: + """Tests for Worker poll loop lifecycle.""" + + @pytest.mark.asyncio + async def test_worker_claims_and_processes_events(self): + """Worker should claim events and call handler.handle().""" + from osa.infrastructure.event.worker import Worker + + # Arrange + event1 = DummyEvent(id=EventId(uuid4()), data="event1") + claim_result = ClaimResult(events=[event1], claimed_at=datetime.now(UTC)) + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = claim_result + outbox.mark_delivered = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + handler = DummyHandler(processed_events=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(DummyHandler) + worker.set_container(container) + + # Act - Run one poll cycle + await worker._poll_once() + + # Assert + assert len(handler.processed_events) == 1 + assert handler.processed_events[0] == event1 + outbox.claim.assert_called_once() + outbox.mark_delivered.assert_called_once_with(event1.id) + + @pytest.mark.asyncio + async def test_worker_returns_false_when_no_events(self): + """Worker._poll_once should return False when no events are available.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + handler = DummyHandler(processed_events=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(DummyHandler) + worker.set_container(container) + + # Act + had_events = await worker._poll_once() + + # Assert - Should return False when no events (sleep happens in _run()) + assert had_events is False + assert worker.state.status == WorkerStatus.IDLE + + @pytest.mark.asyncio + async def test_worker_updates_state_during_processing(self): + """Worker should update state as it processes events.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + claim_result = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = claim_result + outbox.mark_delivered = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + state_during_process: WorkerStatus | None = None + + class StateTrackingHandler(EventHandler[DummyEvent]): + async def handle(self, event: DummyEvent) -> None: + nonlocal state_during_process + state_during_process = worker.state.status + + handler = StateTrackingHandler() + container = make_mock_container(outbox, session, handler) + + worker = Worker(StateTrackingHandler) + worker.set_container(container) + + # Act + await worker._poll_once() + + # Assert - State should have been PROCESSING during handle() + assert state_during_process == WorkerStatus.PROCESSING + + +class TestWorkerStartStop: + """Tests for Worker.start() and Worker.stop().""" + + @pytest.mark.asyncio + async def test_start_creates_asyncio_task(self): + """Worker.start() should create an asyncio task.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + handler = DummyHandler(processed_events=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(DummyHandler) + worker.set_container(container) + + # Act + task = worker.start() + + # Assert + assert isinstance(task, asyncio.Task) + assert not task.done() + + # Cleanup + worker.stop() + await asyncio.sleep(0.15) # Give time for graceful shutdown + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_start_requires_container(self): + """Worker.start() should raise if container not set.""" + from osa.infrastructure.event.worker import Worker + + worker = Worker(DummyHandler) + + # Act & Assert - Should raise without container + with pytest.raises(RuntimeError, match="Container not set"): + worker.start() + + @pytest.mark.asyncio + async def test_stop_signals_graceful_shutdown(self): + """Worker.stop() should signal graceful shutdown.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + handler = DummyHandler(processed_events=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(DummyHandler) + worker.set_container(container) + + # Act + task = worker.start() + await asyncio.sleep(0.05) # Let it run a bit + worker.stop() + await asyncio.sleep(0.15) # Wait for shutdown + + # Assert - Task should complete (not be cancelled) + assert worker.state.status == WorkerStatus.STOPPING or task.done() + + # Cleanup + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_worker_finishes_current_batch_before_stopping(self): + """Worker should finish processing current event before stopping.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + outbox.mark_delivered = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + event_processed = asyncio.Event() + + class SlowHandler(EventHandler[DummyEvent]): + async def handle(self, event: DummyEvent) -> None: + await asyncio.sleep(0.1) # Simulate processing time + event_processed.set() + + handler = SlowHandler() + container = make_mock_container(outbox, session, handler) + + worker = Worker(SlowHandler) + worker.set_container(container) + + # Act + task = worker.start() + await asyncio.sleep(0.02) # Let it start processing + worker.stop() + + # Wait for event to complete + try: + await asyncio.wait_for(event_processed.wait(), timeout=0.5) + except asyncio.TimeoutError: + pass + + # Assert - Event should have been processed despite stop signal + assert event_processed.is_set() + + # Cleanup + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + @pytest.mark.asyncio + async def test_worker_handles_handler_error(self): + """Worker should handle errors in handler and mark event failed.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + claim_result = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = claim_result + outbox.mark_failed_with_retry = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + handler = FailingHandler() + container = make_mock_container(outbox, session, handler) + + worker = Worker(FailingHandler) + worker.set_container(container) + + # Act - Run one poll cycle + await worker._poll_once() + + # Assert - Event should be marked as failed + outbox.mark_failed_with_retry.assert_called_once() + assert worker.state.failed_count == 1 + assert worker.state.error is not None diff --git a/server/tests/unit/infrastructure/event/test_worker_batching.py b/server/tests/unit/infrastructure/event/test_worker_batching.py new file mode 100644 index 0000000..e26fb32 --- /dev/null +++ b/server/tests/unit/infrastructure/event/test_worker_batching.py @@ -0,0 +1,172 @@ +"""Unit tests for Worker batch processing. + +Tests for batch_size configuration and batch accumulation behavior. +""" + +from datetime import UTC, datetime +from typing import ClassVar +from unittest.mock import AsyncMock, MagicMock +from uuid import uuid4 + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.shared.event import ( + ClaimResult, + Event, + EventHandler, + EventId, +) +from osa.domain.shared.outbox import Outbox + + +class DummyEvent(Event): + """Test event for worker tests.""" + + id: EventId + data: str + + +def make_mock_container( + outbox: AsyncMock, + session: AsyncMock | None = None, + handler: EventHandler | None = None, +): + """Create a mock DI container.""" + if session is None: + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + session.rollback = AsyncMock() + + async def get_dependency(cls): + if cls == Outbox: + return outbox + if cls == AsyncSession: + return session + if handler is not None and (cls is type(handler) or issubclass(cls, EventHandler)): + return handler + return session + + scope = AsyncMock() + scope.get = AsyncMock(side_effect=get_dependency) + + context = MagicMock() + context.__aenter__ = AsyncMock(return_value=scope) + context.__aexit__ = AsyncMock(return_value=None) + + container = MagicMock() + container.return_value = context + + return container + + +class TestWorkerBatchSizeOne: + """Tests for batch_size=1 (immediate processing).""" + + @pytest.mark.asyncio + async def test_batch_size_one_processes_immediately(self): + """batch_size=1 should call handle() for single event.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + event = DummyEvent(id=EventId(uuid4()), data="test") + claim_result = ClaimResult(events=[event], claimed_at=datetime.now(UTC)) + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = claim_result + outbox.mark_delivered = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + class ImmediateHandler(EventHandler[DummyEvent]): + __batch_size__: ClassVar[int] = 1 + + processed_events: list[DummyEvent] + + async def handle(self, event: DummyEvent) -> None: + self.processed_events.append(event) + + handler = ImmediateHandler(processed_events=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(ImmediateHandler) + worker.set_container(container) + + # Act + await worker._poll_once() + + # Assert - handle() called (not handle_batch()) + assert len(handler.processed_events) == 1 + assert handler.processed_events[0] == event + + +class TestWorkerBatchAccumulation: + """Tests for batch accumulation with batch_size > 1.""" + + @pytest.mark.asyncio + async def test_batch_calls_handle_batch(self): + """batch_size > 1 should call handle_batch() with all claimed events.""" + from osa.infrastructure.event.worker import Worker + + # Arrange + events = [DummyEvent(id=EventId(uuid4()), data=f"event{i}") for i in range(5)] + claim_result = ClaimResult(events=events, claimed_at=datetime.now(UTC)) + + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = claim_result + outbox.mark_delivered = AsyncMock() + + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + + class BatchHandler(EventHandler[DummyEvent]): + __batch_size__: ClassVar[int] = 100 + + processed_batches: list[list[DummyEvent]] + + async def handle_batch(self, events: list[DummyEvent]) -> None: + self.processed_batches.append(list(events)) + + handler = BatchHandler(processed_batches=[]) + container = make_mock_container(outbox, session, handler) + + worker = Worker(BatchHandler) + worker.set_container(container) + + # Act + await worker._poll_once() + + # Assert - handle_batch() called with all events + assert len(handler.processed_batches) == 1 + assert handler.processed_batches[0] == events + + +class TestWorkerBatchingIntegration: + """Integration tests for different batch configurations.""" + + @pytest.mark.asyncio + async def test_different_handlers_different_batch_sizes(self): + """Different handlers can have different batch sizes.""" + from osa.infrastructure.event.worker import Worker + + # Arrange handlers with different batch sizes + class SmallBatchHandler(EventHandler[DummyEvent]): + __batch_size__: ClassVar[int] = 1 + + async def handle(self, event: DummyEvent) -> None: + pass + + class LargeBatchHandler(EventHandler[DummyEvent]): + __batch_size__: ClassVar[int] = 100 + + async def handle_batch(self, events: list[DummyEvent]) -> None: + pass + + # Create workers + small_worker = Worker(SmallBatchHandler) + large_worker = Worker(LargeBatchHandler) + + # Assert config is read correctly from classvars + assert small_worker.config.batch_size == 1 + assert large_worker.config.batch_size == 100 diff --git a/server/tests/unit/infrastructure/event/test_worker_pool.py b/server/tests/unit/infrastructure/event/test_worker_pool.py new file mode 100644 index 0000000..c7d3f19 --- /dev/null +++ b/server/tests/unit/infrastructure/event/test_worker_pool.py @@ -0,0 +1,242 @@ +"""Unit tests for WorkerPool management. + +Tests for WorkerPool lifecycle and handler registration. +""" + +import asyncio +from datetime import UTC, datetime +from typing import ClassVar +from unittest.mock import AsyncMock, MagicMock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from osa.domain.shared.event import ( + ClaimResult, + Event, + EventHandler, + EventId, +) +from osa.domain.shared.outbox import Outbox + + +class DummyEvent(Event): + """Test event for worker tests.""" + + id: EventId + data: str + + +class DummyHandler(EventHandler[DummyEvent]): + """Test handler for pool tests.""" + + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: DummyEvent) -> None: + pass + + +class AnotherHandler(EventHandler[DummyEvent]): + """Another test handler for pool tests.""" + + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: DummyEvent) -> None: + pass + + +def make_mock_container( + outbox: AsyncMock | None = None, + session: AsyncMock | None = None, + handler: EventHandler | None = None, +): + """Create a mock DI container.""" + if outbox is None: + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + outbox.append = AsyncMock() + outbox.reset_stale_claims = AsyncMock(return_value=0) + + if session is None: + session = AsyncMock(spec=AsyncSession) + session.commit = AsyncMock() + session.rollback = AsyncMock() + + async def get_dependency(cls): + if cls == Outbox: + return outbox + if cls == AsyncSession: + return session + if handler is not None and issubclass(cls, EventHandler): + return handler + # Return a default handler if requested + if issubclass(cls, EventHandler): + return cls() + return session + + scope = AsyncMock() + scope.get = AsyncMock(side_effect=get_dependency) + + context = MagicMock() + context.__aenter__ = AsyncMock(return_value=scope) + context.__aexit__ = AsyncMock(return_value=None) + + container = MagicMock() + container.return_value = context + + return container + + +class TestWorkerPoolManagement: + """Tests for WorkerPool management.""" + + def test_pool_register_creates_worker(self): + """WorkerPool.register() should create a Worker from handler type.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + pool = WorkerPool() + + # Act + worker = pool.register(DummyHandler) + + # Assert + assert len(pool.workers) == 1 + assert worker.handler_type is DummyHandler + assert worker.name == "DummyHandler" + + def test_pool_manages_multiple_handlers(self): + """WorkerPool should manage multiple registered handlers.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + pool = WorkerPool() + + # Act + pool.register(DummyHandler) + pool.register(AnotherHandler) + + # Assert + assert len(pool.workers) == 2 + names = {w.name for w in pool.workers} + assert names == {"DummyHandler", "AnotherHandler"} + + @pytest.mark.asyncio + async def test_pool_start_starts_all_workers(self): + """WorkerPool.start() should start all registered workers.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + container = make_mock_container() + pool = WorkerPool(container=container, stale_claim_interval=0) + + pool.register(DummyHandler) + pool.register(AnotherHandler) + + # Act + await pool.start() + await asyncio.sleep(0.02) # Let workers start + + # Assert - Workers should be running + assert all(w._task is not None for w in pool.workers) + assert all(not w._task.done() for w in pool.workers) + + # Cleanup + await pool.stop() + + @pytest.mark.asyncio + async def test_pool_stop_stops_all_workers(self): + """WorkerPool.stop() should stop all workers gracefully.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + container = make_mock_container() + pool = WorkerPool(container=container, stale_claim_interval=0) + + pool.register(DummyHandler) + + await pool.start() + await asyncio.sleep(0.02) + + # Act + await pool.stop() + + # Assert - Workers should be stopped + for worker in pool.workers: + assert worker._shutdown is True + + @pytest.mark.asyncio + async def test_pool_context_manager(self): + """WorkerPool should work as async context manager.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + container = make_mock_container() + pool = WorkerPool(container=container, stale_claim_interval=0) + pool.register(DummyHandler) + + # Act + async with pool: + # Assert - Pool should be running + assert all(w._task is not None for w in pool.workers) + + # Assert - Pool should be stopped after context exit + for worker in pool.workers: + assert worker._shutdown is True + + @pytest.mark.asyncio + async def test_pool_requires_container(self): + """WorkerPool.start() should raise if container not set.""" + from osa.infrastructure.event.worker import WorkerPool + + pool = WorkerPool() + pool.register(DummyHandler) + + # Act & Assert + with pytest.raises(RuntimeError, match="Container not set"): + await pool.start() + + def test_pool_set_container_propagates_to_workers(self): + """WorkerPool.set_container() should propagate to all workers.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + pool = WorkerPool() + pool.register(DummyHandler) + pool.register(AnotherHandler) + + container = make_mock_container() + + # Act + pool.set_container(container) + + # Assert + for worker in pool.workers: + assert worker._container is container + + +class TestWorkerPoolStaleClaims: + """Tests for stale claim cleanup in WorkerPool.""" + + @pytest.mark.asyncio + async def test_pool_runs_stale_claim_cleanup(self): + """WorkerPool should periodically reset stale claims.""" + from osa.infrastructure.event.worker import WorkerPool + + # Arrange + outbox = AsyncMock(spec=Outbox) + outbox.claim.return_value = ClaimResult(events=[], claimed_at=datetime.now(UTC)) + outbox.append = AsyncMock() + outbox.reset_stale_claims = AsyncMock(return_value=2) + + container = make_mock_container(outbox) + pool = WorkerPool(container=container, stale_claim_interval=0.05) + pool.register(DummyHandler) + + # Act + await pool.start() + await asyncio.sleep(0.1) # Wait for cleanup to run + + # Assert - Stale claim cleanup should have been called + # Note: The actual call might depend on timing + await pool.stop()