Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions CONTEXT.md
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,33 @@ kb.start_handling_loop()
assert client.last_handle_response == [...]
```

### `dependency_overrides` — Overriding Dependencies in Tests

`KnowledgeBase.dependency_overrides` is a `dict[Callable, Callable]` that lets you replace dependency factories at test time, mirroring FastAPI's `app.dependency_overrides`.

```python
def get_db() -> RealDatabase:
return RealDatabase(url="postgresql://...")

# In production — handler receives RealDatabase
@kb.answer_ki(name="my-ki", graph_pattern="...")
def handler(
binding_set, info,
db: Annotated[RealDatabase, Depends(get_db)],
): ...

# In tests — swap the factory
kb.dependency_overrides[get_db] = lambda: FakeDatabase()

# Clear when done
kb.dependency_overrides.clear()
```

**Behaviour:**
- Overrides are **transitive**: overriding a leaf factory (e.g. `get_config`) propagates to all factories that depend on it.
- Override factories **inherit the `cache` setting** from the original `Depends()` declaration.
- Overrides apply to all KI handlers on the KB (not per-KI).

Run tests with:

```bash
Expand Down Expand Up @@ -418,3 +445,4 @@ These are excluded from linting (`ruff`) and are kept for historical reference o
- **Handler introspection**: `KnowledgeInteractionContext.__post_init__` inspects handler signatures to auto-detect binding models, enabling transparent (de)serialization without manual type dispatch. Dispatch logic (validate → call → serialize for ANSWER/REACT; prepare_outgoing + parse_result for ASK/POST) lives in `KnowledgeInteractionContext`, not in `KnowledgeBase`.
- **`KnowledgeBaseBuilder` wraps `KnowledgeBase`**: Settings-based KI registration belongs to `KnowledgeBaseBuilder`, not to `KnowledgeBase`. `KnowledgeBase.from_settings()` returns a builder; `builder.build()` returns the finished `KnowledgeBase`. `KnowledgeBase` itself has no knowledge of settings. ASK/POST KIs are auto-registered at `build()` time; ANSWER/REACT KIs require a handler attached via `builder.handler(name, func)` before `build()` is called.
- **Dependency injection via `Depends`**: `KnowledgeInteractionContext.dispatch()` calls `resolve_dependencies(handler)` before invoking the handler, passing resolved values as kwargs. The resolver (`src/dependency_injection.py`) uses `get_type_hints(include_extras=True)` to find `Annotated[T, Depends(factory)]` params, recursively resolves factory deps (transitive), and caches results per invocation when `cache=True`. `@wraps` on the decorator wrapper preserves `__annotations__`, so the resolver sees the original handler's hints.
- **`dependency_overrides`**: `KnowledgeBase.dependency_overrides` is a `dict[Callable, Callable]` (à la FastAPI) that substitutes dependency factories at resolution time. Overrides are checked transitively at every level and inherit the original `Depends(cache=...)` setting. The dict is passed explicitly from `KnowledgeBase.call()` → `dispatch()` → `resolve_dependencies()` to keep `KnowledgeInteractionContext` decoupled from `KnowledgeBase`.
19 changes: 14 additions & 5 deletions src/knowledge_mapper/dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def _get_dep_params(func: Callable[..., Any]) -> dict[str, Depends]:
def resolve_dependencies(
func: Callable[..., Any],
cache: dict[Callable[..., Any], Any] | None = None,
overrides: dict[Callable[..., Any], Callable[..., Any]] | None = None,
) -> dict[str, Any]:
"""Resolve all ``Annotated[T, Depends(...)]`` parameters of *func*.

Expand All @@ -70,6 +71,11 @@ def resolve_dependencies(
same dict for all calls within a single KI invocation so that
``cache=True`` factories are called at most once. Pass ``None``
to start fresh (a new empty dict will be created).
overrides: An optional mapping of original factory → replacement
factory. When a ``Depends`` factory appears as a key in this
dict, the corresponding override callable is invoked instead.
Overrides are checked transitively at every level of the
dependency tree.

Returns:
A dict mapping parameter name → resolved value for every
Expand All @@ -82,13 +88,16 @@ def resolve_dependencies(
resolved: dict[str, Any] = {}
for param_name, dep in dep_params.items():
factory = dep.factory
if dep.cache and factory in cache:
resolved[param_name] = cache[factory]
actual_factory = (
overrides[factory] if overrides and factory in overrides else factory
)
if dep.cache and actual_factory in cache:
resolved[param_name] = cache[actual_factory]
else:
# Recursively resolve factory's own dependencies first
factory_kwargs = resolve_dependencies(factory, cache)
value = factory(**factory_kwargs)
factory_kwargs = resolve_dependencies(actual_factory, cache, overrides)
value = actual_factory(**factory_kwargs)
if dep.cache:
cache[factory] = value
cache[actual_factory] = value
resolved[param_name] = value
return resolved
6 changes: 5 additions & 1 deletion src/knowledge_mapper/kb/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, id: str, name: str, description: str, ke_url: str):
name=name,
description=description,
)
self.dependency_overrides: dict[Callable[..., Any], Callable[..., Any]] = {}

@classmethod
def from_settings(cls, settings: KnowledgeBaseSettings) -> KnowledgeBaseBuilder:
Expand Down Expand Up @@ -391,7 +392,10 @@ def call(self, binding_set: BindingSet, ki_name: str) -> BindingSet:
Raises:
KeyError: If ``ki_name`` is not found in the local KI registry.
"""
return self.ki_registry[ki_name].dispatch(binding_set)
return self.ki_registry[ki_name].dispatch(
binding_set,
dependency_overrides=self.dependency_overrides or None,
)

def post(
self, binding_set: Sequence[BindingModel] | BindingSet, ki_name: str
Expand Down
10 changes: 8 additions & 2 deletions src/knowledge_mapper/knowledge_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,21 @@ def __post_init__(self):
self.handler
)

def dispatch(self, binding_set: BindingSet) -> BindingSet:
def dispatch(
self,
binding_set: BindingSet,
dependency_overrides: (
dict[Callable[..., Any], Callable[..., Any]] | None
) = None,
) -> BindingSet:
"""Validate incoming bindings, call the handler (with DI), and serialize
the result back to a raw BindingSet.

Used by the handling loop for incoming ANSWER/REACT KI calls.
"""
assert self.handler is not None

dep_kwargs = resolve_dependencies(self.handler)
dep_kwargs = resolve_dependencies(self.handler, overrides=dependency_overrides)

if self.validation_model:
validated = [self.validation_model.model_validate(b) for b in binding_set]
Expand Down
57 changes: 53 additions & 4 deletions src/knowledge_mapper/testing/fake_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""In-memory FakeClient that satisfies ClientProtocol for use in tests."""

from collections import deque
from datetime import UTC, datetime

from knowledge_mapper.ke.client import ClientProtocol, PollResult
from knowledge_mapper.ke.client import ClientProtocol, HandleRequest, PollResult
from knowledge_mapper.ke.models import (
AskResult,
BindingSet,
Expand All @@ -26,6 +27,8 @@ def __init__(self, fake_url) -> None:
# Maps ki_name -> BindingSet to return from execute_post_interaction
self._mock_interaction_results: dict[str, BindingSet] = {}
self._handle_responses: list[tuple[str, str, int, BindingSet]] = []
self._incoming_calls: deque[tuple[PollResult, HandleRequest | None]] = deque()
self._next_handle_request_id: int = 1

def ke_is_available(self) -> bool:
return True
Expand Down Expand Up @@ -65,9 +68,9 @@ def register_ki(
self._knowledge_interactions.setdefault(kb_id, []).append(registered)
return registered

def poll_ki_call(self, kb_id: str) -> tuple[PollResult, None]:
# This fake client never returns any KI calls to handle, but always asks to
# repoll.
def poll_ki_call(self, kb_id: str) -> tuple[PollResult, HandleRequest | None]:
if self._incoming_calls:
return self._incoming_calls.popleft()
return (PollResult.REPOLL, None)

def post_handle_response(
Expand All @@ -87,6 +90,52 @@ def mock_result_binding_set(self, ki_name: str, binding_set: BindingSet) -> None
is called for the KI with the given name."""
self._mock_interaction_results[ki_name] = binding_set

def enqueue_handle_request(
self,
ki_name: str,
binding_set: BindingSet,
requesting_kb_id: str = "http://example.org/requesting-kb",
) -> None:
"""Queue an incoming KI call so ``poll_ki_call`` returns HANDLE for it.

Args:
ki_name: Name of a KI that has already been registered via
``register_ki``.
binding_set: The incoming binding set to pass to the handler.
requesting_kb_id: The ID of the requesting knowledge base
(defaults to a test sentinel).

Raises:
KeyError: If no registered KI with *ki_name* exists.
"""
ki = next(
(
ki
for kis in self._knowledge_interactions.values()
for ki in kis
if ki.name == ki_name
),
None,
)
if ki is None or ki.id is None:
raise KeyError(
f"No registered KI named '{ki_name}' found in TestClient. "
"Register the KI before enqueueing a handle request."
)

handle_request = HandleRequest(
knowledge_interaction_id=ki.id,
handle_request_id=self._next_handle_request_id,
binding_set=binding_set,
requesting_knowledge_base_id=requesting_kb_id,
)
self._next_handle_request_id += 1
self._incoming_calls.append((PollResult.HANDLE, handle_request))

def enqueue_exit(self) -> None:
"""Queue an EXIT signal so ``poll_ki_call`` terminates the handling loop."""
self._incoming_calls.append((PollResult.EXIT, None))

def ask(
self,
kb_id: str,
Expand Down
100 changes: 100 additions & 0 deletions tests/test_dependency_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,103 @@ def handler(

result = kb.call([], "transitive-ki")
assert result == [{"url": "sqlite://:memory:"}]


# ---------------------------------------------------------------------------
# dependency_overrides: FastAPI-style override mechanism
# ---------------------------------------------------------------------------


def test_dependency_override_replaces_factory(kb: KnowledgeBase):
"""A factory listed in dependency_overrides is replaced at resolution time."""

class RealDb:
name = "real"

class FakeDb:
name = "fake"

def get_db() -> RealDb:
return RealDb()

@kb.answer_ki(name="override-ki", graph_pattern="?s ?p ?o .")
def handler(
binding_set: BindingSet,
info,
db: Annotated[RealDb, Depends(get_db)],
) -> BindingSet:
return [{"db": db.name}]

# Without override — uses real factory
assert kb.call([], "override-ki") == [{"db": "real"}]

# With override — uses fake factory
kb.dependency_overrides[get_db] = lambda: FakeDb()
assert kb.call([], "override-ki") == [{"db": "fake"}]

# Clear override — back to real
kb.dependency_overrides.clear()
assert kb.call([], "override-ki") == [{"db": "real"}]


def test_dependency_override_transitive(kb: KnowledgeBase):
"""Overriding a transitive (nested) factory propagates through the chain."""

class Config:
url = "prod://db"

class TestConfig:
url = "test://db"

class Db:
def __init__(self, config):
self.url = config.url

def get_config() -> Config:
return Config()

def get_db(config: Annotated[Config, Depends(get_config)]) -> Db:
return Db(config)

@kb.answer_ki(name="transitive-override-ki", graph_pattern="?s ?p ?o .")
def handler(
binding_set: BindingSet,
info,
db: Annotated[Db, Depends(get_db)],
) -> BindingSet:
return [{"url": db.url}]

# Override the leaf dependency — get_db still runs but receives TestConfig
kb.dependency_overrides[get_config] = lambda: TestConfig()
assert kb.call([], "transitive-override-ki") == [{"url": "test://db"}]


def test_dependency_override_respects_cache(kb: KnowledgeBase):
"""Override factory inherits the cache=True setting from the Depends declaration."""
call_count = 0

def get_value():
return "real"

def fake_get_value():
nonlocal call_count
call_count += 1
return "fake"

def get_service(val: Annotated[str, Depends(get_value)]):
return val

@kb.answer_ki(name="cache-override-ki", graph_pattern="?s ?p ?o .")
def handler(
binding_set: BindingSet,
info,
val: Annotated[str, Depends(get_value)],
svc: Annotated[str, Depends(get_service)],
) -> BindingSet:
assert val is svc # same cached instance
return [{"val": val}]

kb.dependency_overrides[get_value] = fake_get_value
kb.call([], "cache-override-ki")
# fake_get_value should be called only once due to cache=True
assert call_count == 1
8 changes: 2 additions & 6 deletions tests/test_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ def handler(binding_set: list[SensorBinding], info) -> list[SensorBinding]:
def test_dispatch_react_typed():
"""dispatch() works for REACT KIs with typed handlers."""

def handler(
binding_set: list[MeasurementBinding], info
) -> list[ResultBinding]:
return [
ResultBinding(measurement=b.measurement) for b in binding_set
]
def handler(binding_set: list[MeasurementBinding], info) -> list[ResultBinding]:
return [ResultBinding(measurement=b.measurement) for b in binding_set]

ctx = KnowledgeInteractionContext(
info=PostReactInteractionInfo(
Expand Down
Loading
Loading