From 7214d3ad12e340ff334171f090553f7599e0bb88 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Fri, 20 Mar 2026 20:14:21 +0000 Subject: [PATCH 1/3] feat: add extension points for DI provider and event handler overrides create_app() now accepts `providers` and `extra_handlers` kwargs, allowing external packages to swap infrastructure adapters (e.g. container runners, storage backends) and register additional event handlers without duplicating any app wiring. - create_container() accepts *extra_providers and extra_handlers - EventProvider accepts extra_handlers, merging them with core handlers for subscription routing, WorkerPool registration, and DI resolution - Handler DI bindings moved from class-body locals() trick to dynamic self.provide() in __init__ to support runtime extension --- server/osa/application/api/rest/app.py | 34 +++++++++++-- server/osa/application/di.py | 32 +++++++++--- server/osa/config.py | 7 +++ server/osa/infrastructure/event/di.py | 70 ++++++++++++++++---------- 4 files changed, 105 insertions(+), 38 deletions(-) diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index 2dc477c..c746551 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -1,7 +1,9 @@ import logging from contextlib import asynccontextmanager +from typing import Any import logfire +from dishka import Provider as DishkaProvider from fastapi import FastAPI, Request from fastapi.responses import JSONResponse from sqlalchemy.ext.asyncio import AsyncEngine @@ -26,6 +28,7 @@ from osa.config import Config, configure_logging from osa.domain.shared.authorization.startup import validate_all_handlers from osa.domain.shared.error import OSAError +from osa.domain.shared.event import EventHandler from osa.infrastructure.event.worker import WorkerPool from osa.infrastructure.persistence.seed import ensure_system_user from osa.util.di.fastapi import setup_dishka @@ -50,8 +53,30 @@ async def lifespan(app: FastAPI): await container.close() -def create_app() -> FastAPI: - """Create FastAPI application.""" +def create_app( + *, + providers: list[DishkaProvider] | None = None, + extra_handlers: list[type[EventHandler[Any]]] | None = None, +) -> FastAPI: + """Create FastAPI application. + + This is the main entry point for running OSA. External hosts (e.g. + Amacrin) use the keyword arguments to customise the runtime without + duplicating any app wiring:: + + app = create_app( + providers=[K8sProvider()], + extra_handlers=[MeterUsage, SendNotification], + ) + + Args: + providers: Extra Dishka providers that override the built-in + bindings. For example, pass a ``K8sProvider`` to replace the + default OCI container runner with a Kubernetes-based one. + extra_handlers: Extra event handler types registered alongside + the core handlers for subscription routing, WorkerPool + registration, and DI resolution. + """ # Pydantic Settings populates from env vars at runtime config = Config() # type: ignore[call-arg] @@ -74,7 +99,10 @@ def create_app() -> FastAPI: logfire.instrument_fastapi(app_instance) # Setup dependency injection - container = create_container() + container = create_container( + *(providers or []), + extra_handlers=extra_handlers, + ) setup_dishka(container, app_instance) # Register v1 routes with /api/v1 prefix diff --git a/server/osa/application/di.py b/server/osa/application/di.py index 6384767..014740d 100644 --- a/server/osa/application/di.py +++ b/server/osa/application/di.py @@ -1,28 +1,43 @@ +from typing import Any + from dishka import AsyncContainer, make_async_container +from dishka import Provider as DishkaProvider -from osa.util.paths import OSAPaths from osa.config import Config from osa.domain.auth.util.di import AuthProvider -from osa.domain.discovery.util.di import DiscoveryProvider from osa.domain.deposition.util.di import DepositionProvider +from osa.domain.discovery.util.di import DiscoveryProvider from osa.domain.feature.util.di import FeatureProvider from osa.domain.semantics.util.di.provider import SemanticsProvider +from osa.domain.shared.event import EventHandler from osa.domain.validation.util.di import ValidationProvider from osa.infrastructure.auth import AuthInfraProvider from osa.infrastructure.event.di import EventProvider from osa.infrastructure.http.di import HttpProvider from osa.infrastructure.index.di import IndexProvider -from osa.infrastructure.source.di import SourceProvider from osa.infrastructure.oci import OciProvider from osa.infrastructure.persistence import PersistenceProvider +from osa.infrastructure.source.di import SourceProvider from osa.util.di.scope import Scope +from osa.util.paths import OSAPaths -def create_container() -> AsyncContainer: - # Pydantic Settings populates from env vars at runtime - config = Config() # type: ignore[call-arg] +def create_container( + *extra_providers: DishkaProvider, + extra_handlers: list[type[EventHandler[Any]]] | None = None, +) -> AsyncContainer: + """Create the DI container with all default providers. - # OSAPaths reads OSA_DATA_DIR from environment automatically + Args: + extra_providers: Additional Dishka providers appended after defaults. + Later providers override earlier ones for the same type, so these + can replace any built-in binding (e.g. swap OciProvider for a + Kubernetes runner). + extra_handlers: Additional event handler types to register alongside + the core handlers. They will be included in the subscription + registry, WorkerPool, and DI resolution automatically. + """ + config = Config() # type: ignore[call-arg] paths = OSAPaths() return make_async_container( @@ -30,7 +45,7 @@ def create_container() -> AsyncContainer: OciProvider(), IndexProvider(), SourceProvider(), - EventProvider(), + EventProvider(extra_handlers=extra_handlers), HttpProvider(), DepositionProvider(), FeatureProvider(), @@ -39,6 +54,7 @@ def create_container() -> AsyncContainer: AuthProvider(), AuthInfraProvider(), DiscoveryProvider(), + *extra_providers, context={Config: config, OSAPaths: paths}, scopes=Scope, # type: ignore[arg-type] # Custom scope class ) diff --git a/server/osa/config.py b/server/osa/config.py index f662119..c5158f3 100644 --- a/server/osa/config.py +++ b/server/osa/config.py @@ -208,6 +208,13 @@ def base_url(self) -> str: scheme = "http" if self.domain == "localhost" else "https" return f"{scheme}://{self.domain}" + @model_validator(mode="after") + def derive_frontend_url(self) -> Self: + """Derive frontend URL from domain if still the default localhost value.""" + if self.frontend.url == "http://localhost:3000": + self.frontend = Frontend(url=self.base_url) + return self + @model_validator(mode="after") def derive_callback_url(self) -> Self: """Derive OAuth callback URL from domain if not explicitly set. diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index b9fe0ed..d7ed4cd 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -27,30 +27,28 @@ # Type alias for handler list HandlerTypes = NewType("HandlerTypes", list[type[EventHandler[Any]]]) -# All event handlers for WorkerPool registration -HANDLERS: HandlerTypes = HandlerTypes( - [ - # Feature handlers (must run before source triggers) - CreateFeatureTables, - InsertRecordFeatures, - # Source handlers - TriggerInitialSourceRun, - PullFromSource, - # Validation handlers - ValidateDeposition, - # Deposition handlers - CreateDepositionFromSource, - ReturnToDraft, - # Curation handlers - AutoApproveCuration, - # Record handlers - ConvertDepositionToRecord, - ] -) +# Core event handlers shipped with OSA +_CORE_HANDLERS: list[type[EventHandler[Any]]] = [ + # Feature handlers (must run before source triggers) + CreateFeatureTables, + InsertRecordFeatures, + # Source handlers + TriggerInitialSourceRun, + PullFromSource, + # Validation handlers + ValidateDeposition, + # Deposition handlers + CreateDepositionFromSource, + ReturnToDraft, + # Curation handlers + AutoApproveCuration, + # Record handlers + ConvertDepositionToRecord, +] def build_subscription_registry(handlers: HandlerTypes) -> SubscriptionRegistry: - """Build a SubscriptionRegistry from the HANDLERS list. + """Build a SubscriptionRegistry from handler list. Maps each handler's __event_type__.__name__ → handler.__name__. """ @@ -68,8 +66,30 @@ class EventProvider(Provider): Handlers, Schedules, and Outbox are UOW-scoped (fresh per unit of work). WorkerPool and SubscriptionRegistry are APP-scoped singletons. + + To register additional event handlers (e.g. from an external package), + pass them to the constructor:: + + EventProvider(extra_handlers=[MeterUsage, SendNotification]) + + Extra handlers are merged with the core handlers for subscription + routing, WorkerPool registration, and DI resolution. """ + def __init__( + self, + *, + extra_handlers: list[type[EventHandler[Any]]] | None = None, + ) -> None: + super().__init__() + self._all_handlers = HandlerTypes([*_CORE_HANDLERS, *(extra_handlers or [])]) + + # Register DI bindings for every handler (core + extra). + # Each handler becomes a UOW-scoped dependency that Dishka can + # instantiate with its declared fields injected. + for handler_type in self._all_handlers: + self.provide(handler_type, scope=Scope.UOW) + # UOW-scoped Outbox (wraps EventRepository + SubscriptionRegistry) @provide(scope=Scope.UOW) def get_outbox(self, repo: EventRepository, registry: SubscriptionRegistry) -> Outbox: @@ -80,17 +100,13 @@ def get_outbox(self, repo: EventRepository, registry: SubscriptionRegistry) -> O def get_event_log(self, repo: EventRepository) -> EventLog: return EventLog(repo) - # UOW-scoped providers for handlers - for _handler_type in HANDLERS: - locals()[_handler_type.__name__] = provide(_handler_type, scope=Scope.UOW) - # UOW-scoped provider for SourceSchedule source_schedule = provide(SourceSchedule, scope=Scope.UOW) @provide(scope=Scope.APP) def get_handler_types(self) -> HandlerTypes: - """Return the handler types for WorkerPool registration.""" - return HANDLERS + """Return all handler types (core + extra) for WorkerPool registration.""" + return self._all_handlers @provide(scope=Scope.APP) def get_subscription_registry(self, handler_types: HandlerTypes) -> SubscriptionRegistry: From 0bf88cd16bbc90a39a84d1ebe2df1e16afdf1414 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 10:40:51 +0000 Subject: [PATCH 2/3] test: add tests for DI extension points, use uvicorn factory mode - Remove module-level app = create_app() to eliminate import side effects - Switch uvicorn to --factory mode in Dockerfile and Justfile - Add test_app_factory: tests create_app and create_container with provider overrides, extra event handlers, and both combined - Add test_event_provider: tests EventProvider default behaviour, extra handler merging, subscription registry, and DI resolution --- server/Dockerfile | 2 +- server/Justfile | 4 +- server/osa/application/api/rest/app.py | 7 - .../unit/application/test_app_factory.py | 253 ++++++++++++++++++ .../event/test_event_provider.py | 205 ++++++++++++++ 5 files changed, 461 insertions(+), 10 deletions(-) create mode 100644 server/tests/unit/application/test_app_factory.py create mode 100644 server/tests/unit/infrastructure/event/test_event_provider.py diff --git a/server/Dockerfile b/server/Dockerfile index bdbae0b..8da3778 100644 --- a/server/Dockerfile +++ b/server/Dockerfile @@ -74,4 +74,4 @@ HEALTHCHECK --interval=5s --timeout=10s --start-period=5s --retries=3 \ CMD curl --fail http://localhost:8000/api/v1/health || exit 1 ENTRYPOINT ["/app/entrypoint.sh"] -CMD ["uvicorn", "osa.application.api.rest.app:app", "--host", "0.0.0.0", "--port", "8000"] +CMD ["uvicorn", "--factory", "osa.application.api.rest.app:create_app", "--host", "0.0.0.0", "--port", "8000"] diff --git a/server/Justfile b/server/Justfile index 705066d..5310a3f 100644 --- a/server/Justfile +++ b/server/Justfile @@ -8,11 +8,11 @@ default: # Run development server with hot-reload dev: - uv run uvicorn osa.application.api.rest.app:app --reload --host 0.0.0.0 --port 8000 + uv run uvicorn --factory osa.application.api.rest.app:create_app --reload --host 0.0.0.0 --port 8000 # Run with debug logging dev-debug: - LOG_LEVEL=DEBUG uv run uvicorn osa.application.api.rest.app:app --reload --host 0.0.0.0 --port 8000 + LOG_LEVEL=DEBUG uv run uvicorn --factory osa.application.api.rest.app:create_app --reload --host 0.0.0.0 --port 8000 # === Testing === diff --git a/server/osa/application/api/rest/app.py b/server/osa/application/api/rest/app.py index c746551..4cbd273 100644 --- a/server/osa/application/api/rest/app.py +++ b/server/osa/application/api/rest/app.py @@ -143,10 +143,3 @@ async def unhandled_exception_handler(request: Request, exc: Exception): ) return app_instance - - -# Create app instance for uvicorn -# Note: Logfire must be configured before this module is imported -# In production: start_app.py handles this -# In tests: configure in conftest.py -app = create_app() diff --git a/server/tests/unit/application/test_app_factory.py b/server/tests/unit/application/test_app_factory.py new file mode 100644 index 0000000..6b276b4 --- /dev/null +++ b/server/tests/unit/application/test_app_factory.py @@ -0,0 +1,253 @@ +"""Unit tests for create_app and create_container extension points. + +Tests that create_app() and create_container() correctly wire extra +providers and event handlers, enabling infrastructure to be swapped +at startup. +""" + +import asyncio +import os +from pathlib import Path +from typing import ClassVar +from unittest.mock import patch + +import pytest +from dishka import provide +from fastapi.testclient import TestClient + +from osa.application.api.rest.app import create_app +from osa.application.di import create_container +from osa.domain.shared.event import Event, EventHandler, EventId +from osa.domain.shared.model.hook import HookDefinition +from osa.domain.shared.model.source import SourceDefinition +from osa.domain.shared.model.subscription_registry import SubscriptionRegistry +from osa.domain.source.port.source_runner import SourceInputs, SourceOutput, SourceRunner +from osa.domain.validation.model.hook_result import HookResult, HookStatus +from osa.domain.validation.port.hook_runner import HookInputs, HookRunner +from osa.infrastructure.event.di import HandlerTypes, _CORE_HANDLERS +from osa.infrastructure.oci.runner import OciHookRunner +from osa.infrastructure.oci.source_runner import OciSourceRunner +from osa.util.di.base import Provider +from osa.util.di.scope import Scope + +# Minimal env for Config() +os.environ.setdefault( + "OSA_AUTH__JWT__SECRET", + "test-secret-that-is-at-least-32-characters-long", +) + + +# --------------------------------------------------------------------------- +# Stub runners +# --------------------------------------------------------------------------- + + +class StubHookRunner: + """Stub HookRunner for testing provider overrides.""" + + async def run(self, hook: HookDefinition, inputs: HookInputs, work_dir: Path) -> HookResult: + return HookResult(hook_name=hook.name, status=HookStatus.PASSED, duration_seconds=0.0) + + +class StubSourceRunner: + """Stub SourceRunner for testing provider overrides.""" + + async def run( + self, + source: SourceDefinition, + inputs: SourceInputs, + files_dir: Path, + work_dir: Path, + ) -> SourceOutput: + return SourceOutput(records=[], session=None, files_dir=files_dir) + + +class StubRunnerProvider(Provider): + """Provides stub runners, overriding the default OCI ones.""" + + @provide(scope=Scope.UOW, override=True) + def get_hook_runner(self) -> HookRunner: + return StubHookRunner() + + @provide(scope=Scope.UOW, override=True) + def get_source_runner(self) -> SourceRunner: + return StubSourceRunner() + + +# --------------------------------------------------------------------------- +# Stub event handler +# --------------------------------------------------------------------------- + + +class CustomEvent(Event): + """Test event for extra handler registration.""" + + id: EventId + data: str + + +class CustomHandler(EventHandler[CustomEvent]): + """Test handler registered via extra_handlers.""" + + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: CustomEvent) -> None: + pass + + +# --------------------------------------------------------------------------- +# Tests: create_container — provider overrides +# --------------------------------------------------------------------------- + + +class TestProviderOverrides: + """Test that extra providers override default bindings.""" + + def test_runner_override(self): + """Extra provider replaces default OCI runners.""" + container = create_container(StubRunnerProvider()) + + async def resolve(): + async with container(scope=Scope.UOW) as uow: + hook = await uow.get(HookRunner) + source = await uow.get(SourceRunner) + return hook, source + + hook_runner, source_runner = asyncio.run(resolve()) + assert isinstance(hook_runner, StubHookRunner) + assert isinstance(source_runner, StubSourceRunner) + + def test_default_runners_without_override(self): + """Without extra providers, default OCI runners are used.""" + container = create_container() + + async def resolve(): + async with container(scope=Scope.UOW) as uow: + hook = await uow.get(HookRunner) + source = await uow.get(SourceRunner) + return hook, source + + hook_runner, source_runner = asyncio.run(resolve()) + assert isinstance(hook_runner, OciHookRunner) + assert isinstance(source_runner, OciSourceRunner) + + +# --------------------------------------------------------------------------- +# Tests: create_container — extra event handlers +# --------------------------------------------------------------------------- + + +class TestExtraHandlers: + """Test that extra_handlers are wired into the event system.""" + + def test_handler_resolvable_from_di(self): + """Extra handlers can be instantiated by the DI container.""" + container = create_container(extra_handlers=[CustomHandler]) + + async def resolve(): + async with container(scope=Scope.UOW) as uow: + return await uow.get(CustomHandler) + + handler = asyncio.run(resolve()) + assert isinstance(handler, CustomHandler) + + def test_handler_in_handler_types(self): + """Extra handlers appear in HandlerTypes for WorkerPool registration.""" + container = create_container(extra_handlers=[CustomHandler]) + handler_types = asyncio.run(container.get(HandlerTypes)) + + assert CustomHandler in handler_types + for core in _CORE_HANDLERS: + assert core in handler_types + + def test_handler_in_subscription_registry(self): + """Extra handlers are routed in the subscription registry.""" + container = create_container(extra_handlers=[CustomHandler]) + registry = asyncio.run(container.get(SubscriptionRegistry)) + + assert "CustomEvent" in registry + assert "CustomHandler" in registry["CustomEvent"] + + def test_no_extra_handlers_unchanged(self): + """Without extra_handlers, only core handlers are present.""" + container = create_container() + handler_types = asyncio.run(container.get(HandlerTypes)) + + assert list(handler_types) == list(_CORE_HANDLERS) + + +# --------------------------------------------------------------------------- +# Tests: create_container — both extension points +# --------------------------------------------------------------------------- + + +class TestCombined: + """Test providers and extra_handlers used simultaneously.""" + + def test_both_providers_and_extra_handlers(self): + """Provider overrides and extra handlers work together.""" + container = create_container( + StubRunnerProvider(), + extra_handlers=[CustomHandler], + ) + + async def resolve(): + handler_types = await container.get(HandlerTypes) + async with container(scope=Scope.UOW) as uow: + hook = await uow.get(HookRunner) + custom = await uow.get(CustomHandler) + return hook, custom, handler_types + + hook_runner, custom, handler_types = asyncio.run(resolve()) + + assert isinstance(hook_runner, StubHookRunner) + assert isinstance(custom, CustomHandler) + assert CustomHandler in handler_types + + +# --------------------------------------------------------------------------- +# Tests: create_app +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _skip_handler_auth_validation(): + """Patch out validate_all_handlers for create_app tests. + + Other test modules define CommandHandler/QueryHandler subclasses without + __auth__ declarations. Since __subclasses__() is process-wide, the startup + validation picks them up and fails. This is orthogonal to what we test here. + """ + with patch("osa.application.api.rest.app.validate_all_handlers"): + yield + + +class TestCreateApp: + """Test create_app passes extension points through to the container.""" + + def test_provider_override_via_create_app(self): + """Providers passed to create_app override default bindings.""" + app = create_app(providers=[StubRunnerProvider()]) + container = app.state.dishka_container + + async def resolve(): + async with container(scope=Scope.UOW) as uow: + return await uow.get(HookRunner) + + runner = asyncio.run(resolve()) + assert isinstance(runner, StubHookRunner) + + def test_extra_handlers_via_create_app(self): + """Extra handlers passed to create_app are wired in DI.""" + app = create_app(extra_handlers=[CustomHandler]) + container = app.state.dishka_container + + handler_types = asyncio.run(container.get(HandlerTypes)) + assert CustomHandler in handler_types + + def test_default_create_app_serves_health(self): + """Default create_app produces a working app.""" + app = create_app() + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/api/v1/health") + assert response.status_code == 200 diff --git a/server/tests/unit/infrastructure/event/test_event_provider.py b/server/tests/unit/infrastructure/event/test_event_provider.py new file mode 100644 index 0000000..fb235c3 --- /dev/null +++ b/server/tests/unit/infrastructure/event/test_event_provider.py @@ -0,0 +1,205 @@ +"""Unit tests for EventProvider extension points. + +Tests that EventProvider correctly merges extra handlers with core handlers +for DI resolution, subscription routing, and WorkerPool registration. +""" + +from typing import ClassVar + +import pytest +import pytest_asyncio +from dishka import make_async_container + +from osa.domain.shared.event import Event, EventHandler, EventId +from osa.infrastructure.event.di import ( + EventProvider, + HandlerTypes, + _CORE_HANDLERS, + build_subscription_registry, +) +from osa.util.di.scope import Scope + + +# --------------------------------------------------------------------------- +# Test fixtures: dummy events and handlers +# --------------------------------------------------------------------------- + + +class AlphaEvent(Event): + """Test event A.""" + + id: EventId + data: str + + +class BetaEvent(Event): + """Test event B.""" + + id: EventId + data: str + + +class AlphaHandler(EventHandler[AlphaEvent]): + """Handler for AlphaEvent.""" + + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: AlphaEvent) -> None: + pass + + +class BetaHandler(EventHandler[BetaEvent]): + """Handler for BetaEvent.""" + + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: BetaEvent) -> None: + pass + + +# --------------------------------------------------------------------------- +# EventProvider: default behaviour +# --------------------------------------------------------------------------- + + +class TestEventProviderDefaults: + def test_no_args_has_core_handlers(self): + """EventProvider() with no args has exactly the core handlers.""" + provider = EventProvider() + assert list(provider._all_handlers) == list(_CORE_HANDLERS) + + def test_none_extra_handlers_same_as_no_args(self): + """Passing extra_handlers=None is equivalent to no args.""" + provider = EventProvider(extra_handlers=None) + assert list(provider._all_handlers) == list(_CORE_HANDLERS) + + def test_empty_extra_handlers_same_as_no_args(self): + """Passing extra_handlers=[] is equivalent to no args.""" + provider = EventProvider(extra_handlers=[]) + assert list(provider._all_handlers) == list(_CORE_HANDLERS) + + +# --------------------------------------------------------------------------- +# EventProvider: extra handlers +# --------------------------------------------------------------------------- + + +class TestEventProviderExtraHandlers: + def test_extra_handlers_appended_to_core(self): + """Extra handlers appear after core handlers in the list.""" + provider = EventProvider(extra_handlers=[AlphaHandler]) + assert provider._all_handlers[-1] is AlphaHandler + for core in _CORE_HANDLERS: + assert core in provider._all_handlers + + def test_multiple_extra_handlers(self): + """Multiple extra handlers are all included.""" + provider = EventProvider(extra_handlers=[AlphaHandler, BetaHandler]) + assert AlphaHandler in provider._all_handlers + assert BetaHandler in provider._all_handlers + + def test_core_handlers_not_duplicated(self): + """Core handlers appear exactly once even with extras.""" + provider = EventProvider(extra_handlers=[AlphaHandler]) + core_count = sum(1 for h in provider._all_handlers if h in _CORE_HANDLERS) + assert core_count == len(_CORE_HANDLERS) + + +# --------------------------------------------------------------------------- +# Subscription registry +# --------------------------------------------------------------------------- + + +class TestSubscriptionRegistry: + def test_extra_handler_appears_in_registry(self): + """Extra handlers are routed in the subscription registry.""" + handlers = HandlerTypes([*_CORE_HANDLERS, AlphaHandler]) + registry = build_subscription_registry(handlers) + + assert "AlphaEvent" in registry + assert "AlphaHandler" in registry["AlphaEvent"] + + def test_multiple_handlers_for_same_event(self): + """Multiple handlers for the same event type are all registered.""" + + class AnotherAlphaHandler(EventHandler[AlphaEvent]): + __poll_interval__: ClassVar[float] = 0.01 + + async def handle(self, event: AlphaEvent) -> None: + pass + + handlers = HandlerTypes([AlphaHandler, AnotherAlphaHandler]) + registry = build_subscription_registry(handlers) + + assert registry["AlphaEvent"] == {"AlphaHandler", "AnotherAlphaHandler"} + + def test_core_events_unchanged_with_extras(self): + """Adding extra handlers doesn't affect core event routing.""" + core_registry = build_subscription_registry(HandlerTypes(_CORE_HANDLERS)) + extended_registry = build_subscription_registry( + HandlerTypes([*_CORE_HANDLERS, AlphaHandler]) + ) + + for event_type, consumers in core_registry.items(): + assert extended_registry[event_type] == consumers + + +# --------------------------------------------------------------------------- +# DI integration: full container resolution +# --------------------------------------------------------------------------- + + +class TestEventProviderDI: + @pytest_asyncio.fixture + async def container(self): + """Minimal container with EventProvider only (no persistence deps). + + Skips providers that need DB — we only test that extra handlers + are resolvable and appear in HandlerTypes. + """ + # EventProvider's @provide methods for Outbox/EventLog need + # EventRepository, which we don't have. But HandlerTypes and + # handler resolution are self-contained. Use a stripped-down + # container that only includes the handler bindings. + # + # We create a fresh provider that only exposes handler types + # and the handlers themselves (no Outbox/EventLog). + from dishka import Provider, provide + + extra = [AlphaHandler, BetaHandler] + + # Provide just HandlerTypes via a thin wrapper to avoid + # EventRepository dependency from Outbox binding + class HandlerOnlyProvider(Provider): + @provide(scope=Scope.APP) + def handler_types(self) -> HandlerTypes: + return HandlerTypes([*_CORE_HANDLERS, *extra]) + + # Register handler DI bindings + handler_provider = Provider() + for h in extra: + handler_provider.provide(h, scope=Scope.UOW) + + c = make_async_container( + HandlerOnlyProvider(), + handler_provider, + scopes=Scope, # type: ignore[arg-type] + ) + yield c + await c.close() + + @pytest.mark.asyncio + async def test_extra_handlers_resolvable(self, container): + """Extra handlers can be resolved from DI.""" + async with container(scope=Scope.UOW) as uow: + alpha = await uow.get(AlphaHandler) + beta = await uow.get(BetaHandler) + assert isinstance(alpha, AlphaHandler) + assert isinstance(beta, BetaHandler) + + @pytest.mark.asyncio + async def test_handler_types_includes_extras(self, container): + """HandlerTypes in DI includes both core and extra handlers.""" + handler_types = await container.get(HandlerTypes) + assert AlphaHandler in handler_types + assert BetaHandler in handler_types From d38eee91233650877919061feb8107e9e9b66be0 Mon Sep 17 00:00:00 2001 From: Rory Byrne Date: Sat, 21 Mar 2026 11:05:16 +0000 Subject: [PATCH 3/3] fix: guard against duplicate event handler registration Raise ValueError if the same handler type appears in both core and extra handlers, or is passed twice in extra_handlers. Prevents silent misbehaviour or DuplicateFactoryError depending on Dishka version. --- server/osa/infrastructure/event/di.py | 7 +++++++ .../unit/infrastructure/event/test_event_provider.py | 11 +++++++++++ 2 files changed, 18 insertions(+) diff --git a/server/osa/infrastructure/event/di.py b/server/osa/infrastructure/event/di.py index d7ed4cd..286b51e 100644 --- a/server/osa/infrastructure/event/di.py +++ b/server/osa/infrastructure/event/di.py @@ -87,7 +87,14 @@ def __init__( # Register DI bindings for every handler (core + extra). # Each handler becomes a UOW-scoped dependency that Dishka can # instantiate with its declared fields injected. + seen: set[type] = set() for handler_type in self._all_handlers: + if handler_type in seen: + raise ValueError( + f"Duplicate event handler registration: {handler_type.__name__!r}. " + "Remove it from extra_handlers — it is already a core handler." + ) + seen.add(handler_type) self.provide(handler_type, scope=Scope.UOW) # UOW-scoped Outbox (wraps EventRepository + SubscriptionRegistry) diff --git a/server/tests/unit/infrastructure/event/test_event_provider.py b/server/tests/unit/infrastructure/event/test_event_provider.py index fb235c3..bf9c4e9 100644 --- a/server/tests/unit/infrastructure/event/test_event_provider.py +++ b/server/tests/unit/infrastructure/event/test_event_provider.py @@ -104,6 +104,17 @@ def test_core_handlers_not_duplicated(self): core_count = sum(1 for h in provider._all_handlers if h in _CORE_HANDLERS) assert core_count == len(_CORE_HANDLERS) + def test_duplicate_extra_handler_raises(self): + """Passing the same handler twice raises ValueError.""" + with pytest.raises(ValueError, match="Duplicate event handler"): + EventProvider(extra_handlers=[AlphaHandler, AlphaHandler]) + + def test_core_handler_in_extra_raises(self): + """Passing a core handler as extra raises ValueError.""" + core = _CORE_HANDLERS[0] + with pytest.raises(ValueError, match="Duplicate event handler"): + EventProvider(extra_handlers=[core]) + # --------------------------------------------------------------------------- # Subscription registry