From 2c7977aa61fa804b5952c00358fb4559e3b720e2 Mon Sep 17 00:00:00 2001 From: david Date: Wed, 27 May 2026 14:35:55 +0200 Subject: [PATCH 1/2] Add enqueueing of handleresults in TestClient --- src/knowledge_mapper/testing/fake_client.py | 57 +++++++++++- tests/test_handling_loop.py | 99 +++++++++++++++++++++ 2 files changed, 152 insertions(+), 4 deletions(-) create mode 100644 tests/test_handling_loop.py diff --git a/src/knowledge_mapper/testing/fake_client.py b/src/knowledge_mapper/testing/fake_client.py index d15f66c..6861481 100644 --- a/src/knowledge_mapper/testing/fake_client.py +++ b/src/knowledge_mapper/testing/fake_client.py @@ -1,8 +1,9 @@ """In-memory FakeClient that satisfies ClientProtocol for use in tests.""" +from collections import deque from datetime import UTC, datetime -from knowledge_mapper.ke.client import ClientProtocol, PollResult +from knowledge_mapper.ke.client import ClientProtocol, HandleRequest, PollResult from knowledge_mapper.ke.models import ( AskResult, BindingSet, @@ -26,6 +27,8 @@ def __init__(self, fake_url) -> None: # Maps ki_name -> BindingSet to return from execute_post_interaction self._mock_interaction_results: dict[str, BindingSet] = {} self._handle_responses: list[tuple[str, str, int, BindingSet]] = [] + self._incoming_calls: deque[tuple[PollResult, HandleRequest | None]] = deque() + self._next_handle_request_id: int = 1 def ke_is_available(self) -> bool: return True @@ -65,9 +68,9 @@ def register_ki( self._knowledge_interactions.setdefault(kb_id, []).append(registered) return registered - def poll_ki_call(self, kb_id: str) -> tuple[PollResult, None]: - # This fake client never returns any KI calls to handle, but always asks to - # repoll. + def poll_ki_call(self, kb_id: str) -> tuple[PollResult, HandleRequest | None]: + if self._incoming_calls: + return self._incoming_calls.popleft() return (PollResult.REPOLL, None) def post_handle_response( @@ -87,6 +90,52 @@ def mock_result_binding_set(self, ki_name: str, binding_set: BindingSet) -> None is called for the KI with the given name.""" self._mock_interaction_results[ki_name] = binding_set + def enqueue_handle_request( + self, + ki_name: str, + binding_set: BindingSet, + requesting_kb_id: str = "http://example.org/requesting-kb", + ) -> None: + """Queue an incoming KI call so ``poll_ki_call`` returns HANDLE for it. + + Args: + ki_name: Name of a KI that has already been registered via + ``register_ki``. + binding_set: The incoming binding set to pass to the handler. + requesting_kb_id: The ID of the requesting knowledge base + (defaults to a test sentinel). + + Raises: + KeyError: If no registered KI with *ki_name* exists. + """ + ki = next( + ( + ki + for kis in self._knowledge_interactions.values() + for ki in kis + if ki.name == ki_name + ), + None, + ) + if ki is None or ki.id is None: + raise KeyError( + f"No registered KI named '{ki_name}' found in TestClient. " + "Register the KI before enqueueing a handle request." + ) + + handle_request = HandleRequest( + knowledge_interaction_id=ki.id, + handle_request_id=self._next_handle_request_id, + binding_set=binding_set, + requesting_knowledge_base_id=requesting_kb_id, + ) + self._next_handle_request_id += 1 + self._incoming_calls.append((PollResult.HANDLE, handle_request)) + + def enqueue_exit(self) -> None: + """Queue an EXIT signal so ``poll_ki_call`` terminates the handling loop.""" + self._incoming_calls.append((PollResult.EXIT, None)) + def ask( self, kb_id: str, diff --git a/tests/test_handling_loop.py b/tests/test_handling_loop.py new file mode 100644 index 0000000..de57adb --- /dev/null +++ b/tests/test_handling_loop.py @@ -0,0 +1,99 @@ +"""Tests for the handling loop using TestClient's enqueue methods.""" + +import pytest + +from knowledge_mapper import KnowledgeBase +from knowledge_mapper.ke.models import ( + BindingSet, + KnowledgeInteractionInfo, +) +from knowledge_mapper.testing import TestClient + + +@pytest.fixture +def client() -> TestClient: + return TestClient(fake_url="http://fake-ke") + + +@pytest.fixture +def kb(client: TestClient) -> KnowledgeBase: + kb = KnowledgeBase( + id="http://example.org/test#kb", + name="test-kb", + description="A KB for testing the handling loop.", + ke_url="http://fake-ke", + ) + kb.client = client + + captured: list[BindingSet] = [] + + @kb.answer_ki( + name="echo-ki", + graph_pattern="?s ?p ?o .", + ) + def echo_handler( + binding_set: BindingSet, info: KnowledgeInteractionInfo + ) -> BindingSet: + captured.append(binding_set) + return binding_set + + kb.register() + kb._test_captured = captured # type: ignore[attr-defined] + return kb + + +def test_handle_dispatches_to_handler(kb: KnowledgeBase, client: TestClient): + """Enqueueing a HANDLE request dispatches to the handler and posts a response.""" + input_bs: BindingSet = [{"s": "ex:A", "p": "ex:rel", "o": "ex:B"}] + client.enqueue_handle_request("echo-ki", input_bs) + + kb.start_handling_loop(loops=1) + + assert kb._test_captured == [input_bs] # type: ignore[attr-defined] + assert client.last_handle_response == input_bs + + +def test_exit_stops_loop(kb: KnowledgeBase, client: TestClient): + """An EXIT signal terminates the loop without requiring a loops limit.""" + client.enqueue_exit() + kb.start_handling_loop() # would hang without the EXIT signal + + +def test_handle_then_exit(kb: KnowledgeBase, client: TestClient): + """A HANDLE followed by EXIT processes the request and then stops.""" + input_bs: BindingSet = [{"s": "ex:X"}] + client.enqueue_handle_request("echo-ki", input_bs) + client.enqueue_exit() + + kb.start_handling_loop() + + assert kb._test_captured == [input_bs] # type: ignore[attr-defined] + assert client.last_handle_response == input_bs + + +def test_multiple_handle_requests(kb: KnowledgeBase, client: TestClient): + """Multiple HANDLE requests are processed in order.""" + bs1: BindingSet = [{"s": "ex:1"}] + bs2: BindingSet = [{"s": "ex:2"}] + client.enqueue_handle_request("echo-ki", bs1) + client.enqueue_handle_request("echo-ki", bs2) + client.enqueue_exit() + + kb.start_handling_loop() + + assert kb._test_captured == [bs1, bs2] # type: ignore[attr-defined] + assert len(client._handle_responses) == 2 + assert client._handle_responses[0][3] == bs1 + assert client._handle_responses[1][3] == bs2 + + +def test_repoll_fallback(kb: KnowledgeBase, client: TestClient): + """With nothing enqueued, a single loop iteration REPOLLs without error.""" + kb.start_handling_loop(loops=1) + assert kb._test_captured == [] # type: ignore[attr-defined] + + +def test_enqueue_unknown_ki_raises(client: TestClient): + """Enqueueing a handle request for an unregistered KI raises KeyError.""" + with pytest.raises(KeyError, match="No registered KI named 'nonexistent'"): + client.enqueue_handle_request("nonexistent", []) From 15a42535e87ce369ef6b63159c2a3339f05a374a Mon Sep 17 00:00:00 2001 From: david Date: Wed, 27 May 2026 15:10:59 +0200 Subject: [PATCH 2/2] Add overriding of dependencies in handlers --- CONTEXT.md | 28 +++++ src/knowledge_mapper/dependency_injection.py | 19 +++- src/knowledge_mapper/kb/knowledge_base.py | 6 +- src/knowledge_mapper/knowledge_interaction.py | 10 +- tests/test_dependency_injection.py | 100 ++++++++++++++++++ tests/test_dispatch.py | 8 +- 6 files changed, 157 insertions(+), 14 deletions(-) diff --git a/CONTEXT.md b/CONTEXT.md index e8183bb..7ac7ce5 100644 --- a/CONTEXT.md +++ b/CONTEXT.md @@ -349,6 +349,33 @@ kb.start_handling_loop() assert client.last_handle_response == [...] ``` +### `dependency_overrides` — Overriding Dependencies in Tests + +`KnowledgeBase.dependency_overrides` is a `dict[Callable, Callable]` that lets you replace dependency factories at test time, mirroring FastAPI's `app.dependency_overrides`. + +```python +def get_db() -> RealDatabase: + return RealDatabase(url="postgresql://...") + +# In production — handler receives RealDatabase +@kb.answer_ki(name="my-ki", graph_pattern="...") +def handler( + binding_set, info, + db: Annotated[RealDatabase, Depends(get_db)], +): ... + +# In tests — swap the factory +kb.dependency_overrides[get_db] = lambda: FakeDatabase() + +# Clear when done +kb.dependency_overrides.clear() +``` + +**Behaviour:** +- Overrides are **transitive**: overriding a leaf factory (e.g. `get_config`) propagates to all factories that depend on it. +- Override factories **inherit the `cache` setting** from the original `Depends()` declaration. +- Overrides apply to all KI handlers on the KB (not per-KI). + Run tests with: ```bash @@ -418,3 +445,4 @@ These are excluded from linting (`ruff`) and are kept for historical reference o - **Handler introspection**: `KnowledgeInteractionContext.__post_init__` inspects handler signatures to auto-detect binding models, enabling transparent (de)serialization without manual type dispatch. Dispatch logic (validate → call → serialize for ANSWER/REACT; prepare_outgoing + parse_result for ASK/POST) lives in `KnowledgeInteractionContext`, not in `KnowledgeBase`. - **`KnowledgeBaseBuilder` wraps `KnowledgeBase`**: Settings-based KI registration belongs to `KnowledgeBaseBuilder`, not to `KnowledgeBase`. `KnowledgeBase.from_settings()` returns a builder; `builder.build()` returns the finished `KnowledgeBase`. `KnowledgeBase` itself has no knowledge of settings. ASK/POST KIs are auto-registered at `build()` time; ANSWER/REACT KIs require a handler attached via `builder.handler(name, func)` before `build()` is called. - **Dependency injection via `Depends`**: `KnowledgeInteractionContext.dispatch()` calls `resolve_dependencies(handler)` before invoking the handler, passing resolved values as kwargs. The resolver (`src/dependency_injection.py`) uses `get_type_hints(include_extras=True)` to find `Annotated[T, Depends(factory)]` params, recursively resolves factory deps (transitive), and caches results per invocation when `cache=True`. `@wraps` on the decorator wrapper preserves `__annotations__`, so the resolver sees the original handler's hints. +- **`dependency_overrides`**: `KnowledgeBase.dependency_overrides` is a `dict[Callable, Callable]` (à la FastAPI) that substitutes dependency factories at resolution time. Overrides are checked transitively at every level and inherit the original `Depends(cache=...)` setting. The dict is passed explicitly from `KnowledgeBase.call()` → `dispatch()` → `resolve_dependencies()` to keep `KnowledgeInteractionContext` decoupled from `KnowledgeBase`. diff --git a/src/knowledge_mapper/dependency_injection.py b/src/knowledge_mapper/dependency_injection.py index 3ad63f4..48212aa 100644 --- a/src/knowledge_mapper/dependency_injection.py +++ b/src/knowledge_mapper/dependency_injection.py @@ -61,6 +61,7 @@ def _get_dep_params(func: Callable[..., Any]) -> dict[str, Depends]: def resolve_dependencies( func: Callable[..., Any], cache: dict[Callable[..., Any], Any] | None = None, + overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None, ) -> dict[str, Any]: """Resolve all ``Annotated[T, Depends(...)]`` parameters of *func*. @@ -70,6 +71,11 @@ def resolve_dependencies( same dict for all calls within a single KI invocation so that ``cache=True`` factories are called at most once. Pass ``None`` to start fresh (a new empty dict will be created). + overrides: An optional mapping of original factory → replacement + factory. When a ``Depends`` factory appears as a key in this + dict, the corresponding override callable is invoked instead. + Overrides are checked transitively at every level of the + dependency tree. Returns: A dict mapping parameter name → resolved value for every @@ -82,13 +88,16 @@ def resolve_dependencies( resolved: dict[str, Any] = {} for param_name, dep in dep_params.items(): factory = dep.factory - if dep.cache and factory in cache: - resolved[param_name] = cache[factory] + actual_factory = ( + overrides[factory] if overrides and factory in overrides else factory + ) + if dep.cache and actual_factory in cache: + resolved[param_name] = cache[actual_factory] else: # Recursively resolve factory's own dependencies first - factory_kwargs = resolve_dependencies(factory, cache) - value = factory(**factory_kwargs) + factory_kwargs = resolve_dependencies(actual_factory, cache, overrides) + value = actual_factory(**factory_kwargs) if dep.cache: - cache[factory] = value + cache[actual_factory] = value resolved[param_name] = value return resolved diff --git a/src/knowledge_mapper/kb/knowledge_base.py b/src/knowledge_mapper/kb/knowledge_base.py index 22efc73..46b56ce 100644 --- a/src/knowledge_mapper/kb/knowledge_base.py +++ b/src/knowledge_mapper/kb/knowledge_base.py @@ -52,6 +52,7 @@ def __init__(self, id: str, name: str, description: str, ke_url: str): name=name, description=description, ) + self.dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = {} @classmethod def from_settings(cls, settings: KnowledgeBaseSettings) -> KnowledgeBaseBuilder: @@ -391,7 +392,10 @@ def call(self, binding_set: BindingSet, ki_name: str) -> BindingSet: Raises: KeyError: If ``ki_name`` is not found in the local KI registry. """ - return self.ki_registry[ki_name].dispatch(binding_set) + return self.ki_registry[ki_name].dispatch( + binding_set, + dependency_overrides=self.dependency_overrides or None, + ) def post( self, binding_set: Sequence[BindingModel] | BindingSet, ki_name: str diff --git a/src/knowledge_mapper/knowledge_interaction.py b/src/knowledge_mapper/knowledge_interaction.py index 268f343..ecd43fa 100644 --- a/src/knowledge_mapper/knowledge_interaction.py +++ b/src/knowledge_mapper/knowledge_interaction.py @@ -36,7 +36,13 @@ def __post_init__(self): self.handler ) - def dispatch(self, binding_set: BindingSet) -> BindingSet: + def dispatch( + self, + binding_set: BindingSet, + dependency_overrides: ( + dict[Callable[..., Any], Callable[..., Any]] | None + ) = None, + ) -> BindingSet: """Validate incoming bindings, call the handler (with DI), and serialize the result back to a raw BindingSet. @@ -44,7 +50,7 @@ def dispatch(self, binding_set: BindingSet) -> BindingSet: """ assert self.handler is not None - dep_kwargs = resolve_dependencies(self.handler) + dep_kwargs = resolve_dependencies(self.handler, overrides=dependency_overrides) if self.validation_model: validated = [self.validation_model.model_validate(b) for b in binding_set] diff --git a/tests/test_dependency_injection.py b/tests/test_dependency_injection.py index 42570e0..6838059 100644 --- a/tests/test_dependency_injection.py +++ b/tests/test_dependency_injection.py @@ -135,3 +135,103 @@ def handler( result = kb.call([], "transitive-ki") assert result == [{"url": "sqlite://:memory:"}] + + +# --------------------------------------------------------------------------- +# dependency_overrides: FastAPI-style override mechanism +# --------------------------------------------------------------------------- + + +def test_dependency_override_replaces_factory(kb: KnowledgeBase): + """A factory listed in dependency_overrides is replaced at resolution time.""" + + class RealDb: + name = "real" + + class FakeDb: + name = "fake" + + def get_db() -> RealDb: + return RealDb() + + @kb.answer_ki(name="override-ki", graph_pattern="?s ?p ?o .") + def handler( + binding_set: BindingSet, + info, + db: Annotated[RealDb, Depends(get_db)], + ) -> BindingSet: + return [{"db": db.name}] + + # Without override — uses real factory + assert kb.call([], "override-ki") == [{"db": "real"}] + + # With override — uses fake factory + kb.dependency_overrides[get_db] = lambda: FakeDb() + assert kb.call([], "override-ki") == [{"db": "fake"}] + + # Clear override — back to real + kb.dependency_overrides.clear() + assert kb.call([], "override-ki") == [{"db": "real"}] + + +def test_dependency_override_transitive(kb: KnowledgeBase): + """Overriding a transitive (nested) factory propagates through the chain.""" + + class Config: + url = "prod://db" + + class TestConfig: + url = "test://db" + + class Db: + def __init__(self, config): + self.url = config.url + + def get_config() -> Config: + return Config() + + def get_db(config: Annotated[Config, Depends(get_config)]) -> Db: + return Db(config) + + @kb.answer_ki(name="transitive-override-ki", graph_pattern="?s ?p ?o .") + def handler( + binding_set: BindingSet, + info, + db: Annotated[Db, Depends(get_db)], + ) -> BindingSet: + return [{"url": db.url}] + + # Override the leaf dependency — get_db still runs but receives TestConfig + kb.dependency_overrides[get_config] = lambda: TestConfig() + assert kb.call([], "transitive-override-ki") == [{"url": "test://db"}] + + +def test_dependency_override_respects_cache(kb: KnowledgeBase): + """Override factory inherits the cache=True setting from the Depends declaration.""" + call_count = 0 + + def get_value(): + return "real" + + def fake_get_value(): + nonlocal call_count + call_count += 1 + return "fake" + + def get_service(val: Annotated[str, Depends(get_value)]): + return val + + @kb.answer_ki(name="cache-override-ki", graph_pattern="?s ?p ?o .") + def handler( + binding_set: BindingSet, + info, + val: Annotated[str, Depends(get_value)], + svc: Annotated[str, Depends(get_service)], + ) -> BindingSet: + assert val is svc # same cached instance + return [{"val": val}] + + kb.dependency_overrides[get_value] = fake_get_value + kb.call([], "cache-override-ki") + # fake_get_value should be called only once due to cache=True + assert call_count == 1 diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 1634293..c40a8ca 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -70,12 +70,8 @@ def handler(binding_set: list[SensorBinding], info) -> list[SensorBinding]: def test_dispatch_react_typed(): """dispatch() works for REACT KIs with typed handlers.""" - def handler( - binding_set: list[MeasurementBinding], info - ) -> list[ResultBinding]: - return [ - ResultBinding(measurement=b.measurement) for b in binding_set - ] + def handler(binding_set: list[MeasurementBinding], info) -> list[ResultBinding]: + return [ResultBinding(measurement=b.measurement) for b in binding_set] ctx = KnowledgeInteractionContext( info=PostReactInteractionInfo(