From eb3e46b708c4c0c5e5d570920faaedfccb37ff69 Mon Sep 17 00:00:00 2001
From: Vedant Patel
Date: Mon, 8 Jun 2026 12:02:53 -0700
Subject: [PATCH 1/3] =?UTF-8?q?feat(planner):=20graph-aware=20retrieval=20?=
 =?UTF-8?q?planner=20=E2=80=94=20opt-in=20capability?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Adds a one-LLM-call up-front retrieval planner that produces a
structured `RetrievalPlan` (expected answer type, priority predicates,
optional hop sequence) from a compressed view of the relevant
fact-graph slice. The plan biases — never replaces — the existing
PPR + beam + triple-ANN + Cohere Rerank pipeline.

This is the first capability in the OSS RAG starter space that pairs
HippoRAG-2-style PPR with an explicit pre-retrieval plan, decided in
one LLM call before any retrieval runs. Existing public KG-RAG systems
(HippoRAG 2, MS GraphRAG, LightRAG, DAVIS) either skip planning
entirely or do reactive agentic loops (latency cliffs).

New modules:
- src/engram/core/graph_view.py — CompressedGraphView,
  EntityNeighborhood, EdgeSummary + build_query_graph_view(). Pure
  logic over backend.fact_graph + backend.get_entity. Top-K by edge
  confidence, corpus-wide predicate histogram.
- src/engram/dialogue/prompts/retrieval_plan.py — RetrievalPlan,
  HopStep schemas + build_retrieval_plan_prompt with 4 worked
  examples. Confidence calibration guidance (0.8-1.0 high, <0.3
  rare-abstention).
- src/engram/dialogue/retrieval_planner.py — async plan_retrieval with
  confidence-floor abstention (default 0.5), graceful no-op on empty
  view or LLM error, optional raw_plan_sink for diagnostics.
- benchmarks/failure_tagger.py — Phase 0 LLM classifier that tagged
  120 n=200 failures by mode; 22.5% are planner-addressable (mostly
  answer-type mismatches).

Plumbing:
- benchmarks/retrieval.py — kg_hybrid_neighbors gains `plan` kwarg.
  Plan drives: predicate_boost in beam_search_facts, post-fusion
  fact-type filter (capped at 30% removal so a wrong plan can't starve
  the reader), plan-aware Cohere Rerank query suffix.
- src/engram/core/kg_retrieval.py — beam_search_facts gains
  predicate_boost + multiplier (default 1.5x).
- benchmarks/runner.py — answer_one builds view + plan once per
  question (cached across IRCoT rounds), threads to retrieval.
- benchmarks/musique.py — `--retrieval-planner` (default OFF) and
  `--trace-retrieval-plan PATH` flags.

Pre-existing main-branch hardening, ported in this commit:
- src/engram/backends/memory.py — _LMDB_MAX_KEY_BYTES (480) +
  _key_too_long() guards in entity / alias / fact upsert paths.
  Skip-with-warning when an LLM-extracted name would exceed LMDB's
  511-byte key cap. Prevents cold-path BadValsizeError that was
  silently killing graph builds on the n=100 fixture.
- src/engram/dialogue/orchestrator.py — exc_info=True on the swallowed
  background-task warning so future cold-path failures surface with
  tracebacks.

Tests: 366/366 pass. 9 unit tests for graph_view, 8 for the planner
dialogue, 3 integration tests for plan-biased retrieval (Plankton
voiced_by chain, plan=None passthrough, filter-cap safety).

n=100 ablation (kg-hybrid + IRCoT + synth OFF, same store):
- No-planner baseline: EM 0.40, F1 0.5475
- Planner-on (refined prompt): EM 0.39, F1 0.5389; 8/100 plans fire
  confidently, run-to-run variance ±0.04 EM exceeds plausible signal.

Verdict: shipping as opt-in capability, not metric-lift feature.
Default OFF. The planner is correct in isolation (tests pass, fires
when input is good) but the lift is bottlenecked by the upstream
entity extractor producing 30-40% query-slot noise and by n=100 sample
variance exceeding the +0.02 gate. Engram's selling point becomes "the
only OSS RAG starter with explicit graph-aware planning + structured
retrieval traces" — capability differentiation, not benchmark
dominance.

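Back-of-envelope check on the variance claim above. This is a rough
sketch, not part of the patch: it assumes each question is an
independent Bernoulli trial (the ablation does not verify this), and
`em_standard_error` / `questions_needed` are illustrative helper names
that do not exist in the codebase.

```python
import math


def em_standard_error(p: float, n: int) -> float:
    """Binomial standard error of an exact-match rate p measured on n questions."""
    return math.sqrt(p * (1.0 - p) / n)


def questions_needed(p: float, target_se: float) -> int:
    """Smallest n whose binomial standard error is at most target_se."""
    return math.ceil(p * (1.0 - p) / target_se**2)


# Numbers taken from the ablation above: baseline EM 0.40 at n=100,
# decision gate of +0.02 EM.
baseline_em = 0.40
gate = 0.02

se_100 = em_standard_error(baseline_em, 100)
n_for_gate = questions_needed(baseline_em, gate)

print(f"standard error at n=100: {se_100:.3f}")
print(f"questions needed for SE <= {gate}: {n_for_gate}")
```

Under that assumption the standard error at n=100 is about 0.049 EM,
more than twice the +0.02 gate, and roughly 600 questions would be
needed before the standard error alone shrinks to the gate. Comparing
two runs on the same questions is a paired design, so the true noise
on the difference may be smaller; treat this as an order-of-magnitude
sanity check only.
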
Co-Authored-By: Claude Opus 4.7 --- benchmarks/failure_tagger.py | 283 +++++++++++++ benchmarks/musique.py | 40 ++ benchmarks/retrieval.py | 159 ++++++- benchmarks/runner.py | 54 +++ src/engram/backends/memory.py | 38 ++ src/engram/core/graph_view.py | 259 ++++++++++++ src/engram/core/kg_retrieval.py | 20 + src/engram/dialogue/orchestrator.py | 6 +- src/engram/dialogue/prompts/__init__.py | 8 + src/engram/dialogue/prompts/retrieval_plan.py | 391 ++++++++++++++++++ src/engram/dialogue/retrieval_planner.py | 133 ++++++ .../test_kg_retrieval_with_plan.py | 244 +++++++++++ tests/unit/test_core_graph_view.py | 189 +++++++++ tests/unit/test_dialogue_retrieval_planner.py | 221 ++++++++++ 14 files changed, 2042 insertions(+), 3 deletions(-) create mode 100644 benchmarks/failure_tagger.py create mode 100644 src/engram/core/graph_view.py create mode 100644 src/engram/dialogue/prompts/retrieval_plan.py create mode 100644 src/engram/dialogue/retrieval_planner.py create mode 100644 tests/integration/test_kg_retrieval_with_plan.py create mode 100644 tests/unit/test_core_graph_view.py create mode 100644 tests/unit/test_dialogue_retrieval_planner.py diff --git a/benchmarks/failure_tagger.py b/benchmarks/failure_tagger.py new file mode 100644 index 0000000..efac83a --- /dev/null +++ b/benchmarks/failure_tagger.py @@ -0,0 +1,283 @@ +"""Phase 0 failure tagger for the graph-aware planner branch. + +One LLM call per failed prediction classifies the failure mode. Used +once, before any planner implementation, to confirm there's a real +addressable surface for the planner. Output drives the decision gate: +if wrong-hop-target + answer-type-mismatch failures together account +for less than ~15% of all failures, the planner approach is killed +before any code is written. 
+ +Run as a script: + + .venv/bin/python -m benchmarks.failure_tagger \\ + --input predictions_ircot_synth_on.jsonl \\ + --mode baseline \\ + --output tagged_failures.jsonl + +Failures are defined by exact-match (case + punct + article normalized); +the classifier only sees the cases that failed under that scorer. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import re +import sys +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from pydantic import BaseModel, ConfigDict, Field + +if TYPE_CHECKING: + from engram.core.protocol import LLMProvider + + +logger = logging.getLogger(__name__) + + +FailureCategory = Literal[ + "format-mismatch", + "wrong-hop-target", + "hallucination", + "missing-info", + "answer-type-mismatch", + "other", +] + + +class FailureTag(BaseModel): + """LLM verdict for one failed prediction.""" + + model_config = ConfigDict(extra="forbid") + + reasoning: str = Field( + description="Brief justification before the category. One or two sentences." + ) + category: FailureCategory = Field( + description="The single best-fit failure category." + ) + + +_SYSTEM_PROMPT = """\ +You are a benchmark failure-mode classifier. You see a question, the +gold answer(s) the benchmark expects, and a model's prediction that +scored 0 on exact-match. Your job: pick the SINGLE best-fit failure +category from the list below. + +Categories (pick exactly one): + +- format-mismatch: + The prediction CONTAINS the correct answer but adds extra words, + explanations, modifiers, or has different formatting. A human + reading the prediction would see the right answer is in there. + Examples: + Gold "blackmail" Pred "Japanese blackmail." + Gold "Amy Poehler" Pred "Amy Poehler (previously married to Will Arnett)." 
+ Gold "Association for Computing Machinery" Pred "ACM (Association for Computing Machinery)" + Gold "22" Pred "22 times" + +- wrong-hop-target: + Multi-hop question; the prediction is a wrong intermediate entity + or the wrong final entity from the reasoning chain. The model + landed on a nearby-but-wrong node. NOT a missing-info case. + Examples: + Q: "Who voices the character X?" Gold "Mr. Lawrence" Pred "Plankton" (character, not voice actor) + Q: "Who owns the record label of the performer of X?" + Gold "Warner Music Group" Pred "Shawn Lane owns Eye Reckon Records." + Q: "Spouse of the founder of X?" Gold "Real Spouse" Pred "The founder is Y" (gave intermediate) + +- hallucination: + Prediction is a confident wrong answer that doesn't appear in any + reasonable evidence and isn't a hop-confusion. Often a plausible + but unrelated entity, or invented detail. + Examples: + Gold "Michael Buble" Pred "Josh Groban" (similar singer, no chain to support it) + Gold "1929" Pred "1492" (invented date) + +- missing-info: + Prediction expresses inability to answer — "unknown", + "evidence does not provide", "no information found", "cannot be + determined", etc. The retrieval failed. + Examples: + Pred "The evidence does not provide information about ..." + Pred "Unknown" + +- answer-type-mismatch: + Prediction is the wrong TYPE of thing the question asked for. + Question asked WHO; prediction gave a number/value. Question asked + HOW LONG; prediction gave a date. Question asked WHERE; prediction + gave a person. + Examples: + Q: "Who has the lowest batting average?" Gold "Bill Bergen" Pred ".170" (gave the value, not the person) + Q: "How long?" Gold "4 years" Pred "2005" (gave a date, not a duration) + +- other: + Doesn't cleanly fit any above. Use sparingly. + +Rules: + +- Pick exactly one category. If two categories seem plausible, pick + the one that more strongly explains the failure. 
+- format-mismatch is the NARROWEST category: only use it when the + exact gold string (or a tiny variant) is clearly present in the + prediction. If the prediction has the wrong PRIMARY entity, prefer + wrong-hop-target even if some token overlap exists. +- wrong-hop-target is for multi-hop reasoning errors specifically. + Single-hop "wrong entity" with no chain reasoning is hallucination. +- Fill reasoning briefly before the category. +""" + + +def build_tag_prompt( + *, question: str, gold_answers: Sequence[str], prediction: str +) -> list[dict[str, str]]: + gold_block = " | ".join(f"{g!r}" for g in gold_answers) + user_msg = ( + f"Question: {question}\n" + f"Gold answer(s): {gold_block}\n" + f"Prediction: {prediction!r}\n\n" + "Classify this failure. Return a FailureTag." + ) + return [ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": user_msg}, + ] + + +async def tag_failure( + *, + question: str, + gold_answers: Sequence[str], + prediction: str, + llm: LLMProvider, + model: str | None = None, +) -> FailureTag | None: + """Return a single FailureTag or None on LLM error.""" + messages = build_tag_prompt( + question=question, gold_answers=gold_answers, prediction=prediction + ) + try: + return await llm.extract( + messages, FailureTag, model=model, temperature=0.0 + ) + except Exception as exc: + logger.warning("Tag failed for question %r: %s", question[:60], exc) + return None + + +_PUNCT_RE = re.compile(r"[^a-z0-9 ]") +_ARTICLE_RE = re.compile(r"\b(a|an|the)\b") + + +def _normalize(s: str) -> str: + s = s.lower() + s = _ARTICLE_RE.sub(" ", s) + s = _PUNCT_RE.sub(" ", s) + return " ".join(s.split()) + + +def is_em_failure(prediction: str, gold_answers: Sequence[str]) -> bool: + """SQuAD-style EM check; failure when no gold matches.""" + pred_n = _normalize(prediction) + return not any(pred_n == _normalize(g) for g in gold_answers) + + +async def main_async(args: argparse.Namespace) -> int: + from engram.llm.litellm_provider import 
LiteLLMProvider + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + preds = [json.loads(line) for line in args.input.open()] + if args.mode: + preds = [p for p in preds if p.get("mode") == args.mode] + + failures = [ + p for p in preds if is_em_failure(p["prediction"], p["gold_answers"]) + ] + logger.info( + "Loaded %d predictions; %d failures (mode=%s)", + len(preds), + len(failures), + args.mode or "all", + ) + + llm = LiteLLMProvider(default_model=args.model) + sem = asyncio.Semaphore(args.concurrency) + tagged: list[dict] = [] + + async def _bounded(p: dict) -> None: + async with sem: + tag = await tag_failure( + question=p["question"], + gold_answers=p["gold_answers"], + prediction=p["prediction"], + llm=llm, + model=args.model, + ) + tagged.append( + { + "question_id": p["question_id"], + "question": p["question"], + "gold_answers": p["gold_answers"], + "prediction": p["prediction"], + "category": tag.category if tag else "other", + "reasoning": tag.reasoning if tag else "(tag failed)", + } + ) + + await asyncio.gather(*[_bounded(p) for p in failures]) + + with args.output.open("w") as fh: + for record in tagged: + fh.write(json.dumps(record) + "\n") + logger.info("Wrote %d tagged records to %s", len(tagged), args.output) + + from collections import Counter + + counts = Counter(r["category"] for r in tagged) + total = len(tagged) or 1 + print("\nFailure category histogram:") + for cat, n in counts.most_common(): + print(f" {cat:25s} {n:4d} ({100*n/total:.1f}%)") + + planner_addressable = counts.get("wrong-hop-target", 0) + counts.get( + "answer-type-mismatch", 0 + ) + pct = 100 * planner_addressable / total + print( + f"\nPlanner-addressable (wrong-hop + answer-type-mismatch): " + f"{planner_addressable}/{total} = {pct:.1f}%" + ) + print(f"Gate (>=15%): {'PASS' if pct >= 15 else 'FAIL'}") + return 0 + + +def _parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + 
p.add_argument("--input", type=Path, required=True) + p.add_argument( + "--mode", + type=str, + default=None, + help="Filter predictions by mode ('baseline' or 'enriched'). Default: all.", + ) + p.add_argument("--output", type=Path, required=True) + p.add_argument("--model", type=str, default="openai/gpt-4o-mini") + p.add_argument("--concurrency", type=int, default=8) + return p.parse_args() + + +def main() -> int: + args = _parse_args() + return asyncio.run(main_async(args)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/benchmarks/musique.py b/benchmarks/musique.py index 8fae67a..5082a52 100644 --- a/benchmarks/musique.py +++ b/benchmarks/musique.py @@ -296,6 +296,29 @@ def _parse_args() -> argparse.Namespace: "layer — cleaner ablation." ), ) + parser.add_argument( + "--retrieval-planner", + action="store_true", + default=False, + help=( + "Graph-aware retrieval planner: one LLM call per question " + "produces an explicit hop sequence + expected answer type + " + "priority predicates that bias kg_hybrid_neighbors (beam " + "predicate boost, fact-object type filter, plan-aware " + "rerank query). Requires --kg-retrieval. Bypasses cleanly " + "when planner abstains (confidence < 0.5)." + ), + ) + parser.add_argument( + "--trace-retrieval-plan", + type=Path, + default=None, + help=( + "Diagnostic JSONL dump of (question, query_entities, view, " + "plan) per question when --retrieval-planner is on. Used " + "for manual review of plan quality during smoke runs." 
+ ), + ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging.") return parser.parse_args() @@ -374,6 +397,9 @@ async def _amain(args: argparse.Namespace) -> int: summaries: dict[str, dict] = {} all_predictions: list = [] backends: list[MemoryBackend] = [] + plan_trace: list[dict] | None = ( + [] if args.trace_retrieval_plan else None + ) try: for mode in modes: @@ -477,6 +503,10 @@ async def _amain(args: argparse.Namespace) -> int: decomposer_llm=reader if use_decomposition else None, decomposer_model=args.reader_model if use_decomposition else None, reranker=cohere_reranker, + use_retrieval_planner=args.retrieval_planner, + planner_llm=reader if args.retrieval_planner else None, + planner_model=args.reader_model if args.retrieval_planner else None, + plan_trace=plan_trace, ) summary = score_predictions(predictions_for_scoring(predictions)) summaries[mode] = summary.as_dict() @@ -513,6 +543,16 @@ async def _amain(args: argparse.Namespace) -> int: + "\n" ) logger.info("Wrote %d predictions to %s", len(all_predictions), args.output) + + if args.trace_retrieval_plan and plan_trace is not None: + with args.trace_retrieval_plan.open("w") as fh: + for record in plan_trace: + fh.write(json.dumps(record) + "\n") + logger.info( + "Wrote %d retrieval-plan trace records to %s", + len(plan_trace), + args.trace_retrieval_plan, + ) finally: for backend in backends: backend.close() diff --git a/benchmarks/retrieval.py b/benchmarks/retrieval.py index e92b1ca..ca2fcd6 100644 --- a/benchmarks/retrieval.py +++ b/benchmarks/retrieval.py @@ -35,6 +35,7 @@ from engram.backends.memory import MemoryBackend from engram.core.models import Chunk, Fact from engram.core.protocol import CorpusBackend + from engram.dialogue.prompts import RetrievalPlan logger = logging.getLogger(__name__) @@ -459,6 +460,7 @@ async def kg_hybrid_neighbors( ppr_top_k: int = 100, beam_top_k: int = 100, min_fact_confidence: float = 0.7, + plan: RetrievalPlan | None = None, ) -> 
tuple[list[Chunk], list[Fact]]: """KG-hybrid retrieval entry point — Phase 6 fusion of all paths. @@ -524,8 +526,22 @@ async def kg_hybrid_neighbors( ppr_ranked = ( two_stage_ppr_facts(entities, backend, top_k=ppr_top_k) if entities else [] ) + # Plan-aware bias: when a retrieval plan is provided, beam search + # boosts edges whose predicate matches the planner's priority list. + # PPR / triple_match stay untouched — only beam gets the bias, because + # beam is the precision-oriented arm of the fusion (PPR is recall). + beam_predicate_boost: Sequence[str] | None = ( + plan.priority_predicates if plan is not None and plan.priority_predicates else None + ) beam_ranked = ( - beam_search_facts(entities, backend, top_k=beam_top_k) if entities else [] + beam_search_facts( + entities, + backend, + top_k=beam_top_k, + predicate_boost=beam_predicate_boost, + ) + if entities + else [] ) # Three-source RRF fusion: triple-match (semantic), PPR (probabilistic @@ -551,6 +567,11 @@ async def kg_hybrid_neighbors( fact = backend.get_fact(fid) if fact is not None: fact_objects.append(fact) + # Plan-aware filter: drop facts whose object type doesn't match the + # planner's expected_answer_type. Capped at 30% removal so a wrong + # plan can't starve the reader. 
+ if plan is not None: + fact_objects = _apply_plan_filter(fact_objects, plan, backend) kg_chunk_ids = facts_to_chunk_ids(fact_objects) hybrid_chunk_ids = [c.id for c in hybrid_chunks] @@ -576,12 +597,146 @@ async def kg_hybrid_neighbors( capped = _cap_by_context_budget(out) if reranker is not None and capped: - reranked = await reranker.rerank(query, capped) + rerank_query = _plan_aware_rerank_query(query, plan) + reranked = await reranker.rerank(rerank_query, capped) deduped = _dedupe_after_rerank(reranked) capped = _cap_by_context_budget(deduped) return capped, fact_objects +_YEAR_RE = re.compile(r"^\s*\d{3,4}\s*$") +_NUMBER_RE = re.compile(r"^\s*[\d,.\s]+\s*(times|years|months|days|hours|minutes)?\s*$") +_MONEY_RE = re.compile( + r"(?:\$|usd\b|eur\b|gbp\b|\b\d+\s*(?:million|billion|thousand|m|bn|k)\b)", + re.IGNORECASE, +) +_DURATION_RE = re.compile( + r"^\s*[\d.]+\s*(?:years?|months?|weeks?|days?|hours?|minutes?|seconds?|" + r"decades?|centuries?|millennia)\s*(?:old)?\s*$", + re.IGNORECASE, +) +_DATE_RE = re.compile( + r"(?:january|february|march|april|may|june|july|august|september|october|" + r"november|december|jan|feb|mar|apr|jun|jul|aug|sep|oct|nov|dec)\b", + re.IGNORECASE, +) + +# Mapping of planner answer-type tokens to regexes for literal-value +# matching. Entity-typed tokens are not listed here; they are resolved via +# backend.get_entity() lookup against the EntityRecord entity_type. +_LITERAL_TYPE_MATCHERS: dict[str, re.Pattern[str]] = { + "year": _YEAR_RE, + "number": _NUMBER_RE, + "money": _MONEY_RE, + "duration": _DURATION_RE, + "date": _DATE_RE, +} + +# Tokens used in the planner's answer-type vocab that correspond to +# entity-record entity_type strings (case-insensitive match). +_ENTITY_TYPE_TOKENS = { + "person", + "organization", + "location", + "product", + "project", + "event", + "concept", +} + +# Maximum fraction of fused facts the answer-type filter is allowed to +# remove. 
Safety floor against an over-aggressive plan that would +# starve the reader of evidence. If more than this would be dropped, +# the filter is disabled for this query entirely. +_PLAN_FILTER_REMOVAL_CAP = 0.30 + + +def _fact_object_matches_answer_type( + obj: str, answer_type: str, backend: MemoryBackend +) -> bool: + """Heuristic: does this fact's object look like the planner's expected type? + + Returns True liberally — only returns False when we have positive + evidence the type does NOT match. Unknown / ambiguous cases default + to True so the filter never drops a fact we're unsure about. + """ + if not obj or not answer_type: + return True + answer_type_lower = answer_type.strip().lower() + if answer_type_lower in ("other", "title", "language"): + # These are too loose / corpus-specific to filter on reliably. + return True + + matcher = _LITERAL_TYPE_MATCHERS.get(answer_type_lower) + if matcher is not None: + return bool(matcher.search(obj)) + + if answer_type_lower in _ENTITY_TYPE_TOKENS: + record = backend.get_entity(obj) + if record is None: + # No EntityRecord = likely a literal value, not the entity + # type the planner expects. + return False + return record.entity_type.strip().lower() == answer_type_lower + + return True + + +def _apply_plan_filter( + facts: list[Fact], + plan: RetrievalPlan, + backend: MemoryBackend, +) -> list[Fact]: + """Drop facts whose object doesn't match the plan's expected_answer_type. + + Capped at :data:`_PLAN_FILTER_REMOVAL_CAP` of the input — if more + than that would be dropped, return the input unchanged (the planner + is likely wrong about the answer type and we'd rather take the + precision hit on this query than starve the reader). 
+ """ + answer_type = (plan.expected_answer_type or "").strip().lower() + if not answer_type or not facts: + return facts + if answer_type in ("other", "title", "language"): + return facts + + kept: list[Fact] = [] + dropped: list[Fact] = [] + for fact in facts: + if _fact_object_matches_answer_type(fact.object, answer_type, backend): + kept.append(fact) + else: + dropped.append(fact) + + if not kept: + return facts + removal_ratio = len(dropped) / max(1, len(facts)) + if removal_ratio > _PLAN_FILTER_REMOVAL_CAP: + logger.debug( + "plan filter disabled for this query: would drop %.0f%% (cap %.0f%%)", + removal_ratio * 100, + _PLAN_FILTER_REMOVAL_CAP * 100, + ) + return facts + return kept + + +def _plan_aware_rerank_query(query: str, plan: RetrievalPlan | None) -> str: + """Append plan hints to the rerank query so Cohere internalizes them.""" + if plan is None: + return query + hints: list[str] = [] + if plan.expected_answer_type: + hints.append(f"Expected answer type: {plan.expected_answer_type}") + if plan.priority_predicates: + hints.append( + "Important relations: " + ", ".join(plan.priority_predicates[:5]) + ) + if not hints: + return query + return query + " | " + " | ".join(hints) + + def _fact_to_dict(fact: Fact) -> dict: """Marshal a ``Fact`` into the dict shape ``merge_fact_strategies`` expects. diff --git a/benchmarks/runner.py b/benchmarks/runner.py index 0c98365..7e16bfe 100644 --- a/benchmarks/runner.py +++ b/benchmarks/runner.py @@ -386,6 +386,10 @@ async def answer_one( decomposer_llm: LLMProvider | None = None, decomposer_model: str | None = None, reranker: CohereReranker | None = None, + use_retrieval_planner: bool = False, + planner_llm: LLMProvider | None = None, + planner_model: str | None = None, + plan_trace: list[dict] | None = None, ) -> QuestionPrediction: """Retrieve adaptively, then verbose-CoT-answer-then-extract. 
@@ -451,6 +455,47 @@ async def answer_one( model=entity_extractor_model, ) + # Build the retrieval plan once per question, BEFORE any retrieval + # round. Both IRCoT rounds share the same plan — it's a property of + # the question, not of the retrieved evidence (that's IRCoT's job). + retrieval_plan = None + if ( + use_retrieval_planner + and use_kg_retrieval + and planner_llm is not None + and query_entities + ): + from engram.core.graph_view import build_query_graph_view + from engram.dialogue.retrieval_planner import plan_retrieval + + view = build_query_graph_view( + question.question, query_entities, backend # type: ignore[arg-type] + ) + # Capture the raw plan (pre-floor) for diagnostics; only the + # post-floor result is used downstream. + raw_plan_sink: list = [] + retrieval_plan = await plan_retrieval( + view=view, + provider=planner_llm, + model=planner_model, + raw_plan_sink=raw_plan_sink, + ) + if plan_trace is not None: + raw = raw_plan_sink[0] if raw_plan_sink else None + plan_trace.append( + { + "question_id": question.id, + "question": question.question, + "query_entities": list(query_entities), + "unresolved_mentions": list(view.unresolved_mentions), + "resolved_entities": [e.name for e in view.entities], + "plan": ( + retrieval_plan.model_dump() if retrieval_plan else None + ), + "raw_plan": raw.model_dump() if raw is not None else None, + } + ) + async def _retrieve(query_text: str) -> tuple[list[Chunk], list[Fact]]: """Retrieval dispatch. 
Returns ``(neighbors, matched_facts)``.""" if use_kg_retrieval: @@ -461,6 +506,7 @@ async def _retrieve(query_text: str) -> tuple[list[Chunk], list[Fact]]: bm25=bm25, extra_queries=extra_queries, reranker=reranker, + plan=retrieval_plan, ) if use_graph_retrieval: return await graph_aware_neighbors( @@ -592,6 +638,10 @@ async def run_many( decomposer_model: str | None = None, reranker: CohereReranker | None = None, checkpoint_every: int = 10, + use_retrieval_planner: bool = False, + planner_llm: LLMProvider | None = None, + planner_model: str | None = None, + plan_trace: list[dict] | None = None, ) -> list[QuestionPrediction]: """Score every question against the already-ingested ``backend``. @@ -627,6 +677,10 @@ async def _bounded(q: MusiqueQuestion) -> QuestionPrediction | None: decomposer_llm=decomposer_llm, decomposer_model=decomposer_model, reranker=reranker, + use_retrieval_planner=use_retrieval_planner, + planner_llm=planner_llm, + planner_model=planner_model, + plan_trace=plan_trace, ) except Exception: logger.exception("Question %s failed in %s mode", q.id, mode) diff --git a/src/engram/backends/memory.py b/src/engram/backends/memory.py index ea59415..5bd4132 100644 --- a/src/engram/backends/memory.py +++ b/src/engram/backends/memory.py @@ -71,6 +71,17 @@ _META_DIM = b"dim" _META_FACT_DIM = b"fact_dim" +# LMDB's compile-time MDB_MAXKEYSIZE is 511 bytes. Index keys and +# entity-name keys are constructed from LLM-extracted strings which +# can occasionally be runaway-long (e.g. a model emitting an entire +# sentence as an "entity name"). Going over the limit blows up the +# whole cold-path write transaction. We keep a 31-byte safety margin. 
+_LMDB_MAX_KEY_BYTES = 480 + + +def _key_too_long(key: bytes) -> bool: + return len(key) > _LMDB_MAX_KEY_BYTES + def _norm(text: str) -> str: return text.strip().lower() @@ -593,6 +604,14 @@ def _resolve_or_create_entity_blocking( norm = normalize_entity_name(name) if not norm: return name + if _key_too_long(norm.encode()): + logger.warning( + "Skipping entity resolution: normalized name exceeds LMDB max " + "(name[:80]=%r, len=%d)", + name[:80], + len(norm.encode()), + ) + return name with self._env.begin(write=True) as txn: ebn_db = self._dbs[b"entity_by_name"] @@ -688,6 +707,14 @@ def _register_aliases( alias_norm = normalize_entity_name(cand) if not alias_norm or alias_norm == canon_norm: continue + if _key_too_long(alias_norm.encode()): + logger.warning( + "Skipping alias write: normalized alias exceeds LMDB max " + "(alias[:80]=%r, len=%d)", + cand[:80], + len(alias_norm.encode()), + ) + continue txn.put(alias_norm.encode(), canon_norm.encode(), db=ea_db) def get_entity(self, name: str) -> EntityRecord | None: @@ -917,6 +944,17 @@ def _upsert_facts_blocking( meta_db = self._dbs[b"meta"] for i, fact in enumerate(facts): + index_keys = self._index_keys_for(fact) + if any(_key_too_long(k) for k in index_keys): + logger.warning( + "Skipping fact %s: index key exceeds LMDB max (subject=%r, " + "predicate=%r, object=%r)", + fact.id, + fact.subject[:80], + fact.predicate[:80], + fact.object[:80], + ) + continue existing = txn.get(fact.id.encode(), db=facts_db) if existing is not None: prior = Fact.model_validate_json(existing) diff --git a/src/engram/core/graph_view.py b/src/engram/core/graph_view.py new file mode 100644 index 0000000..d13d657 --- /dev/null +++ b/src/engram/core/graph_view.py @@ -0,0 +1,259 @@ +"""Compressed graph view for the query-time retrieval planner. 
+ +Given a list of query mentions, build a small structured snapshot of the +relevant slice of the fact graph: each mention's 1-hop neighborhood +(top-K edges by confidence) plus a corpus-wide predicate histogram for +context. The LLM planner uses this snapshot to write a structured +retrieval plan (which predicates to weight, what answer type to expect) +without ever seeing raw triples or the full graph. + +This module is pure logic over the existing ``MemoryBackend.fact_graph`` +(a ``networkx.MultiDiGraph``) and ``backend.get_entity``. No new storage, +no LLM calls — it is the cheap-to-compute input that makes the planner +graph-aware. + +Design notes: + +- Entity-type lookup falls through to ``"unknown"`` when a node isn't a + tracked ``EntityRecord``. The ingest path doesn't guarantee every + graph node has a record (e.g. literal objects like dates / numbers + surface as nodes with no record). +- Edge sampling is top-K by confidence, not random. High-confidence + edges are what the planner should reason about; tail edges add noise. +- Hub entities (degree > 100) are common — they're the entities mentioned + in many facts. We still take top-K by confidence; the cap keeps the + prompt token-budget bounded. +- The corpus-wide predicate histogram is O(fact_count) on every call. At + Engram's MuSiQue scale (~2000 facts per n=100 corpus) that's fast. + If it shows up as a hot loop on bigger corpora, cache it on the backend. +""" + +from __future__ import annotations + +from collections import Counter +from collections.abc import Iterable, Sequence +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from engram.core.entities import normalize_entity_name + +if TYPE_CHECKING: + from engram.backends.memory import MemoryBackend + + +DEFAULT_MAX_EDGES_PER_ENTITY = 6 +"""Cap on edges shown per query mention. Each edge is one prompt line: +six is enough to convey what predicates surround the entity without +blowing up token cost. 
Tuned to keep a 5-entity view under ~600 tokens.""" + +DEFAULT_PREDICATE_TOP_N = 30 +"""Corpus-wide predicate histogram cap surfaced to the planner. Top-30 +covers the long tail of MuSiQue predicates without flooding the prompt.""" + + +@dataclass(frozen=True) +class EdgeSummary: + """One graph edge condensed for the planner prompt.""" + + predicate: str + """The relation label (e.g. ``founded_by``, ``capital_of``).""" + + other_name: str + """The opposite endpoint's display name (subject for in_edges, + object for out_edges).""" + + other_type: str + """The opposite endpoint's ``EntityRecord.entity_type``, or + ``"unknown"`` when no record exists (typical for literal values + like dates and numbers).""" + + confidence: float + """Edge confidence from the underlying ``Fact``. Used for top-K + sorting; surfaced so the planner can downweight low-confidence + evidence.""" + + +@dataclass(frozen=True) +class EntityNeighborhood: + """A single mention's structural fingerprint in the graph.""" + + name: str + """The display name of the entity in the graph. May differ from the + raw query mention if the resolver matched a variant.""" + + entity_type: str + """``EntityRecord.entity_type`` for this entity, or ``"unknown"`` + when no record exists.""" + + out_edges: list[EdgeSummary] = field(default_factory=list) + """Outgoing edges: this entity is the subject. Top-K by confidence.""" + + in_edges: list[EdgeSummary] = field(default_factory=list) + """Incoming edges: this entity is the object. Top-K by confidence.""" + + total_degree: int = 0 + """Total in + out degree before sampling. Lets the planner see when + a node is a hub even though we only show top-K edges.""" + + +@dataclass(frozen=True) +class CompressedGraphView: + """The full prompt input for the retrieval planner.""" + + query: str + """The original question. 
Echoed back so the prompt can be a single + self-contained string.""" + + entities: list[EntityNeighborhood] + """One :class:`EntityNeighborhood` per query mention that resolved + to a graph node. Mentions that didn't resolve are dropped silently — + the planner gets nothing to plan from for them anyway.""" + + available_predicates: list[tuple[str, int]] + """Top-N predicates in the corpus by count, descending. Pairs of + ``(predicate, fact_count)`` so the planner can weight rare vs hub + predicates differently when proposing priority_predicates.""" + + unresolved_mentions: list[str] + """Mentions the resolver couldn't link to a graph node. Surfaced + for diagnostics so the planner can choose to fall back if its + confidence is low.""" + + +def _resolve_to_graph_node(mention: str, graph) -> str | None: + """Cheap deterministic resolution: exact -> normalized. + + We deliberately do NOT run the substring fallback here — substring + matches are too ambiguous to plan against (see the parked + feat/slm-voices entity-link voice attempt for why). Mentions that + don't cleanly resolve are reported as ``unresolved_mentions`` and + the planner can decide what to do. + """ + if not mention or not mention.strip(): + return None + if mention in graph: + return mention + norm = normalize_entity_name(mention) + if not norm: + return None + for node in graph.nodes(): + if normalize_entity_name(node) == norm: + return node + return None + + +def _top_k_edges_by_confidence( + edges: Iterable[tuple[str, str, str, dict]], + *, + k: int, + direction: str, + backend: MemoryBackend, +) -> list[EdgeSummary]: + """Build EdgeSummary list from raw networkx edges, top-K by confidence. + + ``direction`` is ``"out"`` (this entity is u) or ``"in"`` (this + entity is v). Determines which endpoint becomes ``other_name``. 
+ """ + summaries: list[tuple[float, EdgeSummary]] = [] + for u, v, _key, data in edges: + other = v if direction == "out" else u + record = backend.get_entity(other) + other_type = record.entity_type if record is not None else "unknown" + confidence = float(data.get("confidence", 0.0)) + summaries.append( + ( + confidence, + EdgeSummary( + predicate=str(data.get("predicate", "related_to")), + other_name=other, + other_type=other_type, + confidence=confidence, + ), + ) + ) + summaries.sort(key=lambda x: x[0], reverse=True) + return [es for _, es in summaries[:k]] + + +def _predicate_histogram( + backend: MemoryBackend, top_n: int +) -> list[tuple[str, int]]: + """Top-N corpus predicates by fact count.""" + counts: Counter[str] = Counter() + for fact in backend.iter_facts(): + if fact.predicate: + counts[fact.predicate] += 1 + return counts.most_common(top_n) + + +def build_query_graph_view( + query: str, + mentions: Sequence[str], + backend: MemoryBackend, + *, + max_edges_per_entity: int = DEFAULT_MAX_EDGES_PER_ENTITY, + predicate_top_n: int = DEFAULT_PREDICATE_TOP_N, +) -> CompressedGraphView: + """Assemble a planner-ready snapshot of the graph slice around ``mentions``. + + ``mentions`` is the output of ``benchmarks.entities.extract_query_entities`` + (or any other source of candidate query entities). Each mention is + resolved by exact / normalize match against the graph; unresolved + mentions are reported separately rather than guessed at. + + Returns a :class:`CompressedGraphView` ready to be rendered into the + planner prompt. Safe to call when the graph is empty — returns an + empty view (planner will abstain). 
+ """ + graph = backend.fact_graph + entities: list[EntityNeighborhood] = [] + unresolved: list[str] = [] + + for mention in mentions: + node = _resolve_to_graph_node(mention, graph) + if node is None: + if mention and mention.strip(): + unresolved.append(mention) + continue + + record = backend.get_entity(node) + entity_type = record.entity_type if record is not None else "unknown" + + out_raw = list(graph.out_edges(node, keys=True, data=True)) + in_raw = list(graph.in_edges(node, keys=True, data=True)) + + out_summaries = _top_k_edges_by_confidence( + out_raw, k=max_edges_per_entity, direction="out", backend=backend + ) + in_summaries = _top_k_edges_by_confidence( + in_raw, k=max_edges_per_entity, direction="in", backend=backend + ) + + entities.append( + EntityNeighborhood( + name=node, + entity_type=entity_type, + out_edges=out_summaries, + in_edges=in_summaries, + total_degree=len(out_raw) + len(in_raw), + ) + ) + + predicates = _predicate_histogram(backend, predicate_top_n) + + return CompressedGraphView( + query=query, + entities=entities, + available_predicates=predicates, + unresolved_mentions=unresolved, + ) + + +__all__ = [ + "DEFAULT_MAX_EDGES_PER_ENTITY", + "DEFAULT_PREDICATE_TOP_N", + "CompressedGraphView", + "EdgeSummary", + "EntityNeighborhood", + "build_query_graph_view", +] diff --git a/src/engram/core/kg_retrieval.py b/src/engram/core/kg_retrieval.py index 02df231..72b7bcf 100644 --- a/src/engram/core/kg_retrieval.py +++ b/src/engram/core/kg_retrieval.py @@ -303,6 +303,13 @@ def two_stage_ppr_facts( return out +DEFAULT_PREDICATE_BOOST_MULTIPLIER = 1.5 +"""Score multiplier applied to beam edges whose predicate matches one +the retrieval planner flagged as a priority. 1.5x keeps the boosted +edges competitive against high-confidence non-priority edges without +swamping the score distribution. 
Tune in Phase 2 ablation.""" + + def beam_search_facts( seed_entities: Sequence[str], backend: MemoryBackend, @@ -310,6 +317,8 @@ def beam_search_facts( config: TraversalConfig | None = None, max_hops: int = 10, top_k: int = 250, + predicate_boost: Sequence[str] | None = None, + predicate_boost_multiplier: float = DEFAULT_PREDICATE_BOOST_MULTIPLIER, ) -> list[Fact]: """Phase 5.3 — confidence-decayed multi-hop beam search over the fact graph. @@ -369,6 +378,13 @@ def beam_search_facts( all_scored: list[tuple[float, int, str]] = [] # (hop_score, hop, fact_id) fact_ids_seen: set[str] = set() + # Normalize the planner's priority predicates for case-insensitive + # matching against edge data. None / empty set bypasses the boost. + boost_set: set[str] = ( + {p.strip().lower() for p in predicate_boost if p and p.strip()} + if predicate_boost + else set() + ) def _process_hop( edge_tuples: list[tuple[str, str, str]], @@ -386,6 +402,10 @@ def _process_hop( if apply_confidence_filter and edge_conf < cfg.min_edge_confidence: continue hop_score = edge_conf * (cfg.decay_factor**hop) + if boost_set: + predicate = str(edge_data.get("predicate", "")).strip().lower() + if predicate in boost_set: + hop_score *= predicate_boost_multiplier scored.append((hop_score, key, u, v)) scored.sort(key=lambda x: x[0], reverse=True) diff --git a/src/engram/dialogue/orchestrator.py b/src/engram/dialogue/orchestrator.py index 99c6b1a..c358543 100644 --- a/src/engram/dialogue/orchestrator.py +++ b/src/engram/dialogue/orchestrator.py @@ -263,7 +263,11 @@ async def flush(self) -> None: ) for result in results: if isinstance(result, BaseException): - logger.warning("Background graph-build task raised: %r", result) + logger.warning( + "Background graph-build task raised: %r", + result, + exc_info=(type(result), result, result.__traceback__), + ) async def aclose(self) -> None: """Wait for pending background work and release resources.""" diff --git a/src/engram/dialogue/prompts/__init__.py 
b/src/engram/dialogue/prompts/__init__.py index 7d5fd59..43194e1 100644 --- a/src/engram/dialogue/prompts/__init__.py +++ b/src/engram/dialogue/prompts/__init__.py @@ -60,6 +60,11 @@ build_fact_extraction_prompt, build_fact_extraction_prompt_loose, ) +from engram.dialogue.prompts.retrieval_plan import ( + HopStep, + RetrievalPlan, + build_retrieval_plan_prompt, +) from engram.dialogue.prompts.synthesis import ( ChunkSynthesis, build_synthesis_prompt, @@ -84,7 +89,9 @@ "ExtractedEntity", "ExtractedFact", "FactExtractionResult", + "HopStep", "PromptCall", + "RetrievalPlan", "build_batch_bridging_prompt", "build_batch_derivation_prompt", "build_batch_entity_extraction_prompt", @@ -95,5 +102,6 @@ "build_entity_extraction_prompt", "build_fact_extraction_prompt", "build_fact_extraction_prompt_loose", + "build_retrieval_plan_prompt", "build_synthesis_prompt", ] diff --git a/src/engram/dialogue/prompts/retrieval_plan.py b/src/engram/dialogue/prompts/retrieval_plan.py new file mode 100644 index 0000000..371b66f --- /dev/null +++ b/src/engram/dialogue/prompts/retrieval_plan.py @@ -0,0 +1,391 @@ +"""Retrieval-planning prompt: turn a compressed graph view into an explicit plan. + +Given the user's question and a structured snapshot of the relevant +graph slice (per-mention 1-hop neighborhoods + corpus-wide predicate +histogram), the LLM produces a typed :class:`RetrievalPlan`: + +- ``hop_sequence``: ordered hops the answer requires, each tagged with + the predicate to traverse and the expected target entity type. +- ``expected_answer_type``: the literal type the answer should be — + the single biggest planner-addressable lever per Phase 0 failure + tagging (19% of failures are wrong-answer-type errors that no other + pipeline stage catches). +- ``priority_predicates``: the predicates the retriever should weight + up. The downstream beam search consumes this as a predicate-boost + to surface multi-hop chains it would otherwise miss. 
+- ``confidence``: the LLM's own confidence in the plan. Below the + caller's floor (default 0.5), the planner abstains and the pipeline + falls through to its existing reactive retrieval — no plan-bias. + +This is the first capability Engram ships that has no analogue in +HippoRAG 2, Microsoft GraphRAG, LightRAG, or DAVIS as of early 2026. +PPR-baseline + an explicit graph-aware up-front plan is a genuinely +open lane. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Annotated + +from pydantic import BaseModel, ConfigDict, Field + +from engram.dialogue.prompts.base import ( + ENGRAM_PERSONA, + GROUNDING_RULES, + PromptCall, +) + +if TYPE_CHECKING: + from engram.core.graph_view import CompressedGraphView + + +_ANSWER_TYPE_VOCAB = ( + "person", + "organization", + "location", + "product", + "project", + "event", + "concept", + "date", + "year", + "duration", + "number", + "money", + "language", + "title", + "other", +) +"""Allowed values for ``expected_answer_type``. The first seven mirror +Engram's entity-type vocabulary; the rest cover the literal-value +shapes MuSiQue questions commonly expect (years, durations, monetary +amounts, language names, work titles). Kept small so the LLM picks one +crisp value the downstream filter can match against fact objects.""" + + +class HopStep(BaseModel): + """A single planned hop in the retrieval chain.""" + + model_config = ConfigDict(extra="forbid") + + predicate: str = Field( + description=( + "The relation label to traverse for this hop. Should match " + "one of the predicates in the compressed graph view's " + "available_predicates list when possible. Free-form when " + "the planner is generalizing beyond what's present." + ) + ) + expected_target_type: str = Field( + description=( + "What entity type or value type the hop should land on. " + f"Pick one of: {', '.join(_ANSWER_TYPE_VOCAB)}." 
+ ) + ) + reason: str = Field( + description=( + "One short sentence explaining why this hop is needed for " + "the question." + ) + ) + + +class RetrievalPlan(BaseModel): + """The full planner verdict for one question.""" + + model_config = ConfigDict(extra="forbid") + + reasoning: str = Field( + description=( + "Overview of how you read the question against the graph " + "view. Filled before the structured fields so the model " + "reasons before committing." + ) + ) + hop_sequence: list[HopStep] = Field( + description=( + "Ordered hops the answer requires. Empty list for " + "single-hop questions that need no planning. At most 4 " + "hops — beyond that the question is malformed or out of " + "scope for the corpus." + ) + ) + expected_answer_type: str = Field( + description=( + "The type of thing the final answer is. Pick one of: " + f"{', '.join(_ANSWER_TYPE_VOCAB)}. The downstream pipeline " + "uses this to filter fact objects whose type doesn't match " + "and to bias the reader's output shape." + ) + ) + priority_predicates: list[str] = Field( + description=( + "Predicates the retriever should weight up during beam " + "search. Order matters: the first is the most important. " + "Empty list when no predicate clearly dominates." + ) + ) + confidence: Annotated[float, Field(ge=0.0, le=1.0)] = Field( + description=( + "Your confidence the plan is correct (0..1). Below 0.5 the " + "orchestrator drops the plan entirely — the existing " + "reactive retrieval handles the question. Be honest: " + "abstaining is better than mis-biasing the retriever." + ) + ) + + +_SYSTEM_PROMPT = f"""\ +{ENGRAM_PERSONA} + +You are running the retrieval-planning step of a knowledge-graph +QA pipeline. You see (a) the user's question and (b) a compressed +snapshot of the relevant slice of the fact graph: each query +mention's 1-hop neighborhood (top edges by confidence), plus a +corpus-wide predicate histogram for context. 
+ +Your job: emit a structured retrieval plan whose three signals +BIAS (not replace) the existing PPR + beam + triple-ANN + Cohere +Rerank retrieval. The downstream pipeline is robust; your plan +shifts what it weights up, not what it sees. + +{GROUNDING_RULES} + +How to read the situation: + +- Your plan has THREE signals, in descending order of importance: + + 1. **expected_answer_type** — almost always determinable from the + question wording alone. "What year ...?" -> "year". "Who ...?" + -> "person". This is the highest-leverage output and downstream + uses it to filter facts. Confidence in your plan should be + HIGH whenever you are confident in this single field, even if + the rest is uncertain. + + 2. **priority_predicates** — the relations the retriever should + weight up in beam search. You don't need all of them; even one + well-chosen predicate from the graph view is useful. List them + in priority order, capped at 5. + + 3. **hop_sequence** — optional structural hint. Empty list is the + correct answer for single-hop questions or when you can't see + the full chain. Do not invent hops you can't justify from the + graph view. + +- The corpus contains query slot concepts ("birthplace", "spouse", + "county", "creator") that won't appear as graph node names — they + describe RELATIONS, not entities. When such mentions are listed + in unresolved_mentions, this is normal — focus on the entities + that DID resolve. Unresolved slot concepts are not evidence of + a bad plan. + +- Predicates in this corpus are LLM-extracted and sometimes verbose + or odd ("originated_during_era_of_..."). Don't panic — match by + semantic intent. If the question asks "who founded X?" and the + view shows predicates "founded_by" or "established_by", either is + a good priority predicate. + +How to set confidence (calibrated, not conservative): + +- 0.8-1.0: expected_answer_type is clear; at least one priority + predicate is grounded in the view (or in available_predicates). 
+- 0.5-0.8: expected_answer_type is clear but priority predicates + are weakly grounded. Plan is still useful; bias is correct in + direction. +- 0.3-0.5: expected_answer_type is your best guess; little or no + predicate signal in the view. +- <0.3: graph view is empty AND you can't even confidently guess + the answer type. RARE — most questions yield at least a + defensible answer type. + +Output a RetrievalPlan. Fill reasoning before the structured +fields. ABSTAIN ONLY when both the answer type is genuinely +ambiguous AND there's no useful predicate signal — this is the +exception, not the default. +""" + + +_CORPUS_PREAMBLE = """\ +Worked example 1 — single hop, answer-type focus. + +Question: "What year did Apple Inc. release the first iPhone?" + +Compressed graph view (abridged): + Entity: Apple Inc. (type=organization, degree=12) + out: produced -> iPhone (type=product, conf=0.95) + out: founded_by -> Steve Jobs (type=person, conf=0.95) + out: headquartered_in -> Cupertino (type=location, conf=0.90) + Entity: iPhone (type=product, degree=5) + out: released_in -> 2007 (type=year, conf=0.90) + in: produced <- Apple Inc. + Corpus predicates: produced(120), founded_by(95), released_in(80), + headquartered_in(70), capital_of(45), ... + +Defensible plan: + reasoning: "Question asks for a year. iPhone has a direct + 'released_in -> 2007' edge in the graph slice. Single hop." + hop_sequence: [] + expected_answer_type: "year" + priority_predicates: ["released_in"] + confidence: 0.95 + +Worked example 2 — multi-hop, priority predicates from corpus. + +Question: "Who is the spouse of the actor who voices Batman in the +Lego Batman Movie?" 
+ +Compressed graph view (abridged): + Entity: Lego Batman Movie (type=event, degree=8) + out: voiced_by -> Will Arnett (type=person, conf=0.92) + out: directed_by -> Chris McKay (type=person, conf=0.90) + Entity: Batman (type=concept, degree=15) + out: appears_in -> Lego Batman Movie (type=event, conf=0.85) + Corpus predicates: voiced_by(40), directed_by(35), spouse_of(28), + appears_in(120), born_in(95), ... + +Defensible plan: + reasoning: "Question asks for a person — the spouse of a voice + actor. Voiced_by is in the view; spouse_of is in the corpus + predicates. Two-hop chain." + hop_sequence: + - {predicate: "voiced_by", expected_target_type: "person", + reason: "Get the voice actor"} + - {predicate: "spouse_of", expected_target_type: "person", + reason: "From the actor, find their spouse"} + expected_answer_type: "person" + priority_predicates: ["voiced_by", "spouse_of"] + confidence: 0.9 + +Worked example 3 — slot mentions unresolved but plan is still confident. + +Question: "What county is Gerald T. Whelan's birthplace located in?" + +Compressed graph view (abridged): + Entity: Gerald T. Whelan (type=person, degree=4) + out: born_in -> Lincoln, Nebraska (type=location, conf=0.92) + out: profession -> politician (type=concept, conf=0.85) + (unresolved: "birthplace", "county" — these are query SLOTS, normal) + Corpus predicates: born_in(120), located_in(85), county_of(45), ... + +Defensible plan: + reasoning: "Question asks for a county — a location. Gerald T. + Whelan resolves; the unresolved 'birthplace' and 'county' are + query slots, not entities, and that's expected. Born_in is the + answer to 'birthplace'; located_in / county_of will get us + from the city to its county." 
+ hop_sequence: + - {predicate: "born_in", expected_target_type: "location", + reason: "Get the birthplace city"} + - {predicate: "located_in", expected_target_type: "location", + reason: "From the city, find the containing county"} + expected_answer_type: "location" + priority_predicates: ["born_in", "located_in", "county_of"] + confidence: 0.85 + +Worked example 4 — abstention (rare, only when truly stuck). + +Question: "What is the average salary of a working citizen of the +same nationality as the author of The Feminine Mystique?" + +Compressed graph view (abridged): + Entity: The Feminine Mystique (type=other, degree=2) + out: written_by -> Betty Friedan (type=person, conf=0.85) + Corpus predicates: written_by(45), born_in(95), nationality(12), ... + +Defensible plan: + reasoning: "The question requires statistical data (average salary + by nationality) the graph clearly doesn't model. Even the + expected answer type 'money' would mislead retrieval here. This + is one of the rare cases where the corpus genuinely can't + answer." + hop_sequence: [] + expected_answer_type: "money" + priority_predicates: ["written_by", "nationality"] + confidence: 0.25 + +Note: example 4 is rare. Most MuSiQue-style questions yield at +least a confident answer type plus 1-2 priority predicates from the +view or the corpus histogram. Default toward confidence >= 0.6. +""" + + +def _render_view(view: CompressedGraphView) -> str: + """Format the CompressedGraphView for the prompt user message. + + Compact line-per-edge layout that keeps the structure scannable + for the LLM without ballooning token cost. 
+ """ + if not view.entities: + entity_block = "(no query mentions resolved to a graph node)" + else: + lines: list[str] = [] + for ent in view.entities: + lines.append( + f"Entity: {ent.name} (type={ent.entity_type}, " + f"total_degree={ent.total_degree})" + ) + for edge in ent.out_edges: + lines.append( + f" out: {edge.predicate} -> {edge.other_name} " + f"(type={edge.other_type}, conf={edge.confidence:.2f})" + ) + for edge in ent.in_edges: + lines.append( + f" in: {edge.predicate} <- {edge.other_name} " + f"(type={edge.other_type}, conf={edge.confidence:.2f})" + ) + entity_block = "\n".join(lines) + + if view.available_predicates: + pred_summary = ", ".join( + f"{p}({c})" for p, c in view.available_predicates + ) + else: + pred_summary = "(empty corpus)" + + unresolved_block = ( + f"\nUnresolved query mentions: {', '.join(view.unresolved_mentions)}" + if view.unresolved_mentions + else "" + ) + + return ( + f"Question: {view.query}\n\n" + f"Compressed graph view:\n{entity_block}\n\n" + f"Corpus predicates (count): {pred_summary}" + f"{unresolved_block}" + ) + + +def build_retrieval_plan_prompt( + *, view: CompressedGraphView +) -> PromptCall: + """Build the planner LLM call for one query. + + The caller is responsible for materializing the + :class:`CompressedGraphView` via + ``engram.core.graph_view.build_query_graph_view`` first. + """ + user_msg = ( + f"{_render_view(view)}\n\n" + "Produce a RetrievalPlan. Remember: ground every predicate in " + "the graph view, fill expected_answer_type carefully, and " + "set confidence honestly (< 0.5 means abstain)." 
+ ) + messages = [ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": _CORPUS_PREAMBLE}, + {"role": "user", "content": user_msg}, + ] + return PromptCall( + messages=messages, + schema=RetrievalPlan, + cache_breakpoints=[0, 1], + ) + + +__all__ = [ + "HopStep", + "RetrievalPlan", + "build_retrieval_plan_prompt", +] diff --git a/src/engram/dialogue/retrieval_planner.py b/src/engram/dialogue/retrieval_planner.py new file mode 100644 index 0000000..f296776 --- /dev/null +++ b/src/engram/dialogue/retrieval_planner.py @@ -0,0 +1,133 @@ +"""Query-time retrieval planner: one LLM call, one structured plan. + +Wraps the retrieval-plan prompt with the bookkeeping the query path +needs: + +- ``llm_span("retrieval.plan", ...)`` tracing so each plan shows up + alongside the other dialogue LLM calls in observability. +- Confidence-floor abstention: when the LLM returns ``confidence`` below + ``confidence_floor`` (default 0.5), this function returns ``None`` and + the caller treats the question as unplanned — the existing reactive + retrieval handles it untouched. Abstention is a feature, not a + failure: a confidently wrong plan biases the retriever worse than no + plan at all. +- Graceful no-op on empty input. If the compressed graph view has no + resolved entities, we skip the LLM call entirely and return ``None``. + +This is the first capability Engram ships that has no analogue in +HippoRAG 2, MS GraphRAG, LightRAG, or DAVIS as of early 2026. PPR-baseline +plus an explicit graph-aware up-front plan, decided in one cheap LLM +call before any retrieval runs. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from engram.dialogue.prompts import ( + RetrievalPlan, + build_retrieval_plan_prompt, +) +from engram.observability.tracing import llm_span + +if TYPE_CHECKING: + from engram.core.graph_view import CompressedGraphView + from engram.core.protocol import LLMProvider + + +logger = logging.getLogger(__name__) + + +DEFAULT_CONFIDENCE_FLOOR = 0.5 +"""Plans below this confidence are discarded by ``plan_retrieval``. + +Why 0.5 and not 0.6 (like the entity-link voice on +``feat/slm-voices``): the planner output is a SOFT bias — predicate +boost in beam search + a small fact-filter — not a hard pick. The cost +of a wrong plan is at most a few facts surfaced in suboptimal order; +the cost of an over-aggressive abstention is missing the lift entirely. +Tune up if false-positive plans are observed during Phase 2 ablation. +""" + + +async def plan_retrieval( + *, + view: CompressedGraphView, + provider: LLMProvider, + model: str | None = None, + confidence_floor: float = DEFAULT_CONFIDENCE_FLOOR, + raw_plan_sink: list[RetrievalPlan | None] | None = None, +) -> RetrievalPlan | None: + """Produce a structured retrieval plan for the query. + + ``view`` should come from + :func:`engram.core.graph_view.build_query_graph_view`. The view's + ``query`` field is the question; its ``entities`` are the resolved + 1-hop neighborhoods. + + Returns ``None`` when: + - the view has no resolved entities (nothing to plan from), + - the LLM call fails (logged, never raises), + - the returned plan's ``confidence`` is below ``confidence_floor``. + + A ``None`` return means "fall through to reactive retrieval" — the + caller MUST handle this case as no-op, not as a hard error. 
+ """ + if not view.entities: + logger.debug( + "plan_retrieval: skipping LLM call — no resolved entities for %r", + view.query[:60], + ) + if raw_plan_sink is not None: + raw_plan_sink.append(None) + return None + + call = build_retrieval_plan_prompt(view=view) + async with llm_span( + "retrieval.plan", + model=model or "default", + message_count=len(call.messages), + cache_breakpoints=call.cache_breakpoints, + attributes={ + "engram.entity_count": len(view.entities), + "engram.unresolved_count": len(view.unresolved_mentions), + }, + ): + try: + result: RetrievalPlan = await provider.extract( + call.messages, + call.schema, + model=model, + cache_breakpoints=call.cache_breakpoints, + temperature=0.0, + ) + except Exception as exc: + logger.warning( + "Retrieval planner LLM call failed for query %r: %s", + view.query[:60], + exc, + ) + if raw_plan_sink is not None: + raw_plan_sink.append(None) + return None + + if raw_plan_sink is not None: + raw_plan_sink.append(result) + + if result.confidence < confidence_floor: + logger.debug( + "plan_retrieval: abstaining (conf=%.2f < floor=%.2f) for %r", + result.confidence, + confidence_floor, + view.query[:60], + ) + return None + + return result + + +__all__ = [ + "DEFAULT_CONFIDENCE_FLOOR", + "plan_retrieval", +] diff --git a/tests/integration/test_kg_retrieval_with_plan.py b/tests/integration/test_kg_retrieval_with_plan.py new file mode 100644 index 0000000..2727dce --- /dev/null +++ b/tests/integration/test_kg_retrieval_with_plan.py @@ -0,0 +1,244 @@ +"""Integration test: retrieval plan biases the beam search toward the planned predicate. + +Exercises the full ``kg_hybrid_neighbors`` flow with a real in-memory +``MemoryBackend``, real ``fact_graph``, and a real ``RetrievalPlan`` +(no LLM call — the plan is constructed directly). 
Asserts: + +- A plan that lists ``voiced_by`` as a priority predicate keeps the + lower-confidence ``voiced_by`` fact in the surfaced set alongside + higher-confidence but irrelevant facts. +- A plan whose answer-type filter would remove more than 30% of the + facts trips the safety cap, which disables the filter. +- Passing ``plan=None`` still returns the matching facts; the + plan-aware Cohere Rerank query suffix is not exercised here. +""" + +from __future__ import annotations + +import hashlib +import struct +from collections.abc import Sequence +from typing import Any + +import pytest + +from benchmarks.retrieval import BM25Index, kg_hybrid_neighbors + +from engram.backends.memory import MemoryBackend +from engram.core.models import Chunk, Fact +from engram.core.protocol import EmbeddingKind +from engram.dialogue.prompts import HopStep, RetrievalPlan + + +class _HashEmbedder: + """Deterministic hash embedder for testing without real OpenAI calls. + + Mirrors the pattern in tests/unit/test_benchmarks_retrieval.py. + """ + + def __init__(self, dim: int = 16) -> None: + self._dim = dim + + async def embed( + self, + texts: Sequence[str], + *, + kind: EmbeddingKind = "document", + ) -> Sequence[Sequence[float]]: + del kind + out = [] + for t in texts: + h = hashlib.sha256(t.encode()).digest() + vec = [] + for i in range(self._dim): + sl = h[(i * 2) % len(h) : ((i * 2) + 4) % len(h) or len(h)] + if len(sl) < 4: + sl = (sl + h)[:4] + vec.append(struct.unpack("f", sl)[0]) + norm = sum(v * v for v in vec) ** 0.5 or 1.0 + out.append([v / norm for v in vec]) + return out + + +@pytest.mark.asyncio +async def test_plan_predicate_boost_surfaces_planned_voiced_by_edge( + tmp_path, +) -> None: + """The Plankton -> voiced_by -> Mr. Lawrence pattern from failure analysis. + + Without the plan, beam search picks facts purely by edge confidence.
+ Here we set up a graph where higher-confidence facts about Plankton + exist (appears_in, type_of) so without a planner hint the voiced_by + fact wouldn't be ranked highest. With the plan's + ``priority_predicates=["voiced_by"]``, beam multiplies the + voiced_by score by 1.5x so the voiced_by fact surfaces. + """ + backend = MemoryBackend(embedder=_HashEmbedder(), path=tmp_path / "store") + try: + chunks = [ + Chunk( + id="c1", + text="Plankton appears in SpongeBob SquarePants every episode.", + source_id="d1", + ), + Chunk( + id="c2", + text="Plankton is a one-celled organism in the show.", + source_id="d2", + ), + Chunk( + id="c3", + text="Plankton is voiced by Mr. Lawrence.", + source_id="d3", + ), + ] + await backend.upsert_chunks(chunks) + await backend.upsert_facts( + [ + # High-confidence non-target edges, the noise. + Fact( + id="f1", + subject="Plankton", + predicate="appears_in", + object="SpongeBob SquarePants", + confidence=0.95, + source_chunk_ids=["c1"], + ), + Fact( + id="f2", + subject="Plankton", + predicate="type_of", + object="organism", + confidence=0.95, + source_chunk_ids=["c2"], + ), + # The target edge — lower confidence than the noise. + Fact( + id="f3", + subject="Plankton", + predicate="voiced_by", + object="Mr. 
Lawrence", + confidence=0.80, + source_chunk_ids=["c3"], + ), + ] + ) + + bm25 = BM25Index(chunks) + plan = RetrievalPlan( + reasoning="Question asks for voice actor; route via voiced_by.", + hop_sequence=[ + HopStep( + predicate="voiced_by", + expected_target_type="person", + reason="Get the voice actor", + ) + ], + expected_answer_type="person", + priority_predicates=["voiced_by"], + confidence=0.9, + ) + + _, matched_facts = await kg_hybrid_neighbors( + query="Who voices Plankton in SpongeBob SquarePants?", + entities=["Plankton"], + backend=backend, + bm25=bm25, + plan=plan, + ) + + # The voiced_by fact (f3) must be in the surfaced facts — + # without the plan boost it could be dropped from the top of + # the fused ranking because of its lower base confidence. + fact_ids = {f.id for f in matched_facts} + assert "f3" in fact_ids + finally: + backend.close() + + +@pytest.mark.asyncio +async def test_kg_hybrid_neighbors_works_with_no_plan(tmp_path) -> None: + """Sanity: passing plan=None must give byte-identical behavior to before.""" + backend = MemoryBackend(embedder=_HashEmbedder(), path=tmp_path / "store") + try: + chunks = [ + Chunk(id="c1", text="Paris is the capital of France.", source_id="d1"), + ] + await backend.upsert_chunks(chunks) + await backend.upsert_facts( + [ + Fact( + id="f1", + subject="Paris", + predicate="capital_of", + object="France", + confidence=0.95, + source_chunk_ids=["c1"], + ), + ] + ) + bm25 = BM25Index(chunks) + _, matched_facts = await kg_hybrid_neighbors( + query="Where is Paris?", + entities=["Paris"], + backend=backend, + bm25=bm25, + plan=None, + ) + assert any(f.id == "f1" for f in matched_facts) + finally: + backend.close() + + +@pytest.mark.asyncio +async def test_plan_filter_disabled_when_removal_exceeds_cap(tmp_path) -> None: + """An over-aggressive plan that would drop >30% of facts must be disabled. + + Build a backend with only person-typed objects, then submit a plan + expecting answer_type=year. 
The filter would drop 100% of facts; + the safety cap must kick in and leave them intact. + """ + backend = MemoryBackend(embedder=_HashEmbedder(), path=tmp_path / "store") + try: + chunks = [ + Chunk(id="c1", text="X is friends with Alice.", source_id="d1"), + Chunk(id="c2", text="Y is friends with Bob.", source_id="d2"), + Chunk(id="c3", text="Z is friends with Carol.", source_id="d3"), + ] + await backend.upsert_chunks(chunks) + await backend.upsert_facts( + [ + Fact( + id=f"f{i}", + subject=f"Entity{i}", + predicate="friends_with", + object=name, + confidence=0.9, + source_chunk_ids=[f"c{i}"], + ) + for i, name in enumerate(["Alice", "Bob", "Carol"], start=1) + ] + ) + bm25 = BM25Index(chunks) + # Plan expects "year" — every fact's object is a person name, + # so the regex year filter would drop 100% of facts. The cap + # at 30% removal should disable the filter and keep them. + plan = RetrievalPlan( + reasoning="-", + hop_sequence=[], + expected_answer_type="year", + priority_predicates=[], + confidence=0.9, + ) + _, matched_facts = await kg_hybrid_neighbors( + query="When did X happen?", + entities=["Entity1", "Entity2", "Entity3"], + backend=backend, + bm25=bm25, + plan=plan, + ) + # All three facts must survive — the cap disables the filter. + fact_ids = {f.id for f in matched_facts} + assert {"f1", "f2", "f3"}.issubset(fact_ids) + finally: + backend.close() diff --git a/tests/unit/test_core_graph_view.py b/tests/unit/test_core_graph_view.py new file mode 100644 index 0000000..4e43298 --- /dev/null +++ b/tests/unit/test_core_graph_view.py @@ -0,0 +1,189 @@ +"""Unit tests for engram.core.graph_view. + +Covers the deterministic, no-LLM logic that builds the compressed +graph view the planner consumes: entity resolution, top-K edge +sampling, predicate histogram, missing-entity fall-through, and the +hub case where a node has more edges than the cap. 
+""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from unittest.mock import patch + +import networkx as nx + +from engram.core.graph_view import ( + DEFAULT_MAX_EDGES_PER_ENTITY, + DEFAULT_PREDICATE_TOP_N, + CompressedGraphView, + EntityNeighborhood, + build_query_graph_view, +) +from engram.core.models import EntityRecord, Fact + + +class _StubBackend: + """Minimal stand-in for MemoryBackend, exposing only what graph_view needs.""" + + def __init__( + self, + graph: nx.MultiDiGraph, + *, + entity_records: dict[str, EntityRecord] | None = None, + facts: list[Fact] | None = None, + ) -> None: + self.fact_graph = graph + self._entities = entity_records or {} + self._facts = facts or [] + + def get_entity(self, name: str) -> EntityRecord | None: + return self._entities.get(name) + + def iter_facts(self) -> Any: + return iter(self._facts) + + +def _add_edge( + g: nx.MultiDiGraph, + u: str, + v: str, + *, + predicate: str, + confidence: float, + fact_id: str | None = None, +) -> None: + g.add_node(u) + g.add_node(v) + g.add_edge( + u, + v, + key=fact_id or f"{u}-{predicate}-{v}", + predicate=predicate, + confidence=confidence, + ) + + +def _fact(subject: str, predicate: str, object_: str, *, confidence: float = 0.9) -> Fact: + return Fact(subject=subject, predicate=predicate, object=object_, confidence=confidence) + + +# --------------------------------------------------------------------------- +# build_query_graph_view +# --------------------------------------------------------------------------- + + +def test_build_view_empty_graph_returns_empty_view() -> None: + backend = _StubBackend(nx.MultiDiGraph()) + view = build_query_graph_view("any query", ["X"], backend) + assert isinstance(view, CompressedGraphView) + assert view.entities == [] + assert view.available_predicates == [] + assert view.unresolved_mentions == ["X"] + + +def test_build_view_resolves_exact_then_normalized() -> None: + g = nx.MultiDiGraph() + 
g.add_node("Apple Inc.") + g.add_node("Steve Jobs") + _add_edge(g, "Apple Inc.", "Steve Jobs", predicate="founded_by", confidence=0.95) + backend = _StubBackend( + g, + entity_records={ + "Apple Inc.": EntityRecord(canonical_name="Apple Inc.", entity_type="organization"), + "Steve Jobs": EntityRecord(canonical_name="Steve Jobs", entity_type="person"), + }, + ) + + # Exact hit + view = build_query_graph_view("Where", ["Apple Inc."], backend) + assert [e.name for e in view.entities] == ["Apple Inc."] + assert view.entities[0].entity_type == "organization" + assert view.entities[0].out_edges[0].predicate == "founded_by" + assert view.entities[0].out_edges[0].other_type == "person" + assert view.unresolved_mentions == [] + + # Normalized hit (lowercase) — still resolves. + view2 = build_query_graph_view("Where", ["apple inc"], backend) + assert [e.name for e in view2.entities] == ["Apple Inc."] + + +def test_build_view_unknown_entity_type_falls_through_to_unknown() -> None: + g = nx.MultiDiGraph() + _add_edge(g, "Apple Inc.", "1976", predicate="founded_in", confidence=0.9) + # No EntityRecord for "1976" — literal value + backend = _StubBackend(g) + view = build_query_graph_view("when", ["Apple Inc."], backend) + assert view.entities[0].entity_type == "unknown" + assert view.entities[0].out_edges[0].other_type == "unknown" + + +def test_build_view_top_k_by_confidence() -> None: + """When an entity has more out_edges than max_edges, keep top-K by confidence.""" + g = nx.MultiDiGraph() + for i, conf in enumerate([0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25]): + _add_edge(g, "Hub", f"N{i}", predicate=f"rel{i}", confidence=conf, fact_id=f"f{i}") + backend = _StubBackend(g) + view = build_query_graph_view("Hub", ["Hub"], backend, max_edges_per_entity=3) + confs = [e.confidence for e in view.entities[0].out_edges] + assert confs == sorted(confs, reverse=True) + assert len(view.entities[0].out_edges) == 3 + assert confs == [0.95, 0.85, 0.75] + # total_degree should reflect 
the full degree, not the K cap. + assert view.entities[0].total_degree == 8 + + +def test_build_view_collects_in_and_out_edges_separately() -> None: + g = nx.MultiDiGraph() + _add_edge(g, "A", "B", predicate="rel1", confidence=0.9) + _add_edge(g, "C", "A", predicate="rel2", confidence=0.85) + backend = _StubBackend(g) + view = build_query_graph_view("test", ["A"], backend) + assert [e.predicate for e in view.entities[0].out_edges] == ["rel1"] + assert [e.predicate for e in view.entities[0].in_edges] == ["rel2"] + assert view.entities[0].in_edges[0].other_name == "C" + + +def test_build_view_predicate_histogram_top_n() -> None: + g = nx.MultiDiGraph() + facts = [ + _fact("X", "rel_a", "Y"), + _fact("X", "rel_a", "Z"), + _fact("X", "rel_b", "Y"), + _fact("X", "rel_c", "Z"), + ] + for f in facts: + _add_edge(g, f.subject, f.object, predicate=f.predicate, confidence=f.confidence, fact_id=f.id) + backend = _StubBackend(g, facts=facts) + view = build_query_graph_view("test", [], backend, predicate_top_n=2) + # Top-2: rel_a (count 2) then either rel_b or rel_c (tied at 1). 
+ assert len(view.available_predicates) == 2 + assert view.available_predicates[0] == ("rel_a", 2) + assert view.available_predicates[1][1] == 1 + + +def test_build_view_unresolved_mentions_listed() -> None: + g = nx.MultiDiGraph() + g.add_node("Apple Inc.") + backend = _StubBackend(g) + view = build_query_graph_view( + "Where", ["Apple Inc.", "Banana Republic", " ", ""], backend + ) + assert [e.name for e in view.entities] == ["Apple Inc."] + assert view.unresolved_mentions == ["Banana Republic"] + + +def test_build_view_skip_blank_mentions_silently() -> None: + g = nx.MultiDiGraph() + g.add_node("Apple Inc.") + backend = _StubBackend(g) + view = build_query_graph_view("Where", ["", " ", "\t"], backend) + assert view.entities == [] + assert view.unresolved_mentions == [] + + +def test_build_view_query_field_echoed_through() -> None: + backend = _StubBackend(nx.MultiDiGraph()) + view = build_query_graph_view("a specific question", [], backend) + assert view.query == "a specific question" diff --git a/tests/unit/test_dialogue_retrieval_planner.py b/tests/unit/test_dialogue_retrieval_planner.py new file mode 100644 index 0000000..e7a6e52 --- /dev/null +++ b/tests/unit/test_dialogue_retrieval_planner.py @@ -0,0 +1,221 @@ +"""Unit tests for engram.dialogue.retrieval_planner. + +Covers the wrapper logic around the LLM call: confidence-floor +abstention, empty-view skip, LLM-error graceful fallback, and the +trace-attribute plumbing. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, TypeVar + +import pytest +from pydantic import BaseModel + +from engram.core.graph_view import ( + CompressedGraphView, + EdgeSummary, + EntityNeighborhood, +) +from engram.dialogue.prompts import HopStep, RetrievalPlan +from engram.dialogue.retrieval_planner import ( + DEFAULT_CONFIDENCE_FLOOR, + plan_retrieval, +) + +T = TypeVar("T", bound=BaseModel) + + +class _CannedProvider: + """Mock LLMProvider that returns pre-canned RetrievalPlan responses.""" + + def __init__(self, responses: list[BaseModel | Exception]) -> None: + self._responses = list(responses) + self.calls: list[dict[str, Any]] = [] + + async def complete( + self, messages: Sequence[dict[str, str]], **kwargs: Any + ) -> str: + del messages, kwargs + return "" + + async def extract( + self, + messages: Sequence[dict[str, str]], + schema: type[T], + *, + model: str | None = None, + cache_breakpoints: Sequence[int] | None = None, + **kwargs: Any, + ) -> T: + del kwargs + self.calls.append( + {"messages": list(messages), "schema": schema, "model": model} + ) + response = self._responses.pop(0) + if isinstance(response, Exception): + raise response + assert isinstance(response, schema), ( + f"unexpected schema {schema.__name__}; got {type(response).__name__}" + ) + return response # type: ignore[return-value] + + +def _view_with_one_entity() -> CompressedGraphView: + """A non-empty view with one resolved entity — the typical input shape.""" + return CompressedGraphView( + query="Where was Apple founded?", + entities=[ + EntityNeighborhood( + name="Apple Inc.", + entity_type="organization", + total_degree=3, + out_edges=[ + EdgeSummary( + predicate="founded_in", + other_name="Cupertino", + other_type="location", + confidence=0.9, + ) + ], + in_edges=[], + ) + ], + available_predicates=[("founded_in", 80)], + unresolved_mentions=[], + ) + + +def _empty_view() -> CompressedGraphView: + return 
CompressedGraphView( + query="any", + entities=[], + available_predicates=[], + unresolved_mentions=["X"], + ) + + +# --------------------------------------------------------------------------- +# plan_retrieval +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_plan_retrieval_returns_high_confidence_plan() -> None: + plan = RetrievalPlan( + reasoning="Single-hop founded_in.", + hop_sequence=[], + expected_answer_type="location", + priority_predicates=["founded_in"], + confidence=0.85, + ) + provider = _CannedProvider([plan]) + out = await plan_retrieval(view=_view_with_one_entity(), provider=provider) # type: ignore[arg-type] + assert out is not None + assert out.expected_answer_type == "location" + assert out.priority_predicates == ["founded_in"] + assert len(provider.calls) == 1 + + +@pytest.mark.asyncio +async def test_plan_retrieval_abstains_below_confidence_floor() -> None: + plan = RetrievalPlan( + reasoning="Ambiguous.", + hop_sequence=[], + expected_answer_type="other", + priority_predicates=[], + confidence=0.3, # below default 0.5 floor + ) + provider = _CannedProvider([plan]) + out = await plan_retrieval(view=_view_with_one_entity(), provider=provider) # type: ignore[arg-type] + assert out is None + # The LLM was still called — abstention is downstream. 
+ assert len(provider.calls) == 1 + + +@pytest.mark.asyncio +async def test_plan_retrieval_custom_floor_passes_low_conf() -> None: + plan = RetrievalPlan( + reasoning="-", + hop_sequence=[], + expected_answer_type="location", + priority_predicates=["founded_in"], + confidence=0.4, + ) + provider = _CannedProvider([plan]) + out = await plan_retrieval( + view=_view_with_one_entity(), + provider=provider, # type: ignore[arg-type] + confidence_floor=0.3, + ) + assert out is not None + assert out.confidence == pytest.approx(0.4) + + +@pytest.mark.asyncio +async def test_plan_retrieval_skips_llm_on_empty_view() -> None: + """No resolved entities -> never invoke the LLM.""" + provider = _CannedProvider([]) # would error if called + out = await plan_retrieval(view=_empty_view(), provider=provider) # type: ignore[arg-type] + assert out is None + assert len(provider.calls) == 0 + + +@pytest.mark.asyncio +async def test_plan_retrieval_swallows_llm_errors_gracefully() -> None: + provider = _CannedProvider([RuntimeError("network died")]) + # Must not raise — abstention is the correct response to LLM failure. 
+ out = await plan_retrieval(view=_view_with_one_entity(), provider=provider) # type: ignore[arg-type] + assert out is None + assert len(provider.calls) == 1 + + +@pytest.mark.asyncio +async def test_plan_retrieval_passes_model_through() -> None: + plan = RetrievalPlan( + reasoning="-", + hop_sequence=[], + expected_answer_type="location", + priority_predicates=["founded_in"], + confidence=0.9, + ) + provider = _CannedProvider([plan]) + await plan_retrieval( + view=_view_with_one_entity(), + provider=provider, # type: ignore[arg-type] + model="openai/gpt-4o-mini", + ) + assert provider.calls[0]["model"] == "openai/gpt-4o-mini" + + +@pytest.mark.asyncio +async def test_plan_retrieval_returns_multi_hop_plan() -> None: + plan = RetrievalPlan( + reasoning="Two-hop chain: voiced_by then spouse_of.", + hop_sequence=[ + HopStep( + predicate="voiced_by", + expected_target_type="person", + reason="Find voice actor", + ), + HopStep( + predicate="spouse_of", + expected_target_type="person", + reason="Find spouse of voice actor", + ), + ], + expected_answer_type="person", + priority_predicates=["voiced_by", "spouse_of"], + confidence=0.8, + ) + provider = _CannedProvider([plan]) + out = await plan_retrieval(view=_view_with_one_entity(), provider=provider) # type: ignore[arg-type] + assert out is not None + assert len(out.hop_sequence) == 2 + assert out.hop_sequence[0].predicate == "voiced_by" + assert out.hop_sequence[1].predicate == "spouse_of" + + +def test_default_confidence_floor_is_set() -> None: + assert DEFAULT_CONFIDENCE_FLOOR == 0.5 From a894037010d0bfdf5d5c1e9c56a0a102a4add5e0 Mon Sep 17 00:00:00 2001 From: Vedant Patel Date: Mon, 8 Jun 2026 12:06:29 -0700 Subject: [PATCH 2/3] chore(lint): add memory.py logger, ruff format + import sort CI fixes after the planner commit: - src/engram/backends/memory.py: add module-level `logger = logging.getLogger(__name__)`. 
The LMDB key-length guards (cherry-picked from feat/slm-voices) called logger.warning but the original main-branch module had no logger import. - ruff check --fix: remove unused imports in the new test modules (EntityNeighborhood / DEFAULT_PREDICATE_TOP_N from test_core_graph_view.py; HopStep was unused in one test). - ruff format: standardize formatting on new files + a few existing benchmarks files that had drifted (decomposition.py, reranker.py whitespace). Verified: ruff check clean, ruff format --check clean, 366/366 tests pass. Co-Authored-By: Claude Opus 4.7 --- benchmarks/decomposition.py | 9 ++++---- benchmarks/failure_tagger.py | 22 +++++-------------- benchmarks/musique.py | 19 +++++----------- benchmarks/reranker.py | 7 ++---- benchmarks/retrieval.py | 12 +++------- benchmarks/runner.py | 22 ++++++------------- src/engram/backends/memory.py | 3 +++ src/engram/core/graph_view.py | 4 +--- src/engram/dialogue/prompts/retrieval_plan.py | 16 ++++---------- .../test_kg_retrieval_with_plan.py | 2 -- tests/unit/test_core_graph_view.py | 13 ++++------- tests/unit/test_dialogue_retrieval_planner.py | 8 ++----- 12 files changed, 41 insertions(+), 96 deletions(-) diff --git a/benchmarks/decomposition.py b/benchmarks/decomposition.py index 88402e2..86caeb7 100644 --- a/benchmarks/decomposition.py +++ b/benchmarks/decomposition.py @@ -103,10 +103,10 @@ class QueryDecomposition(BaseModel): "- A sub-question's answer must be a short factual span (name, " " year, number, place) — not a description.\n" "- Use ``depends_on`` to record the reasoning chain: if Q2's " - " phrasing requires Q1's answer (\"What year was [X] born?\" " + ' phrasing requires Q1\'s answer ("What year was [X] born?" ' " where X is found by Q0), set depends_on=[0].\n" - "- For SIMPLE single-hop questions (\"What is the capital of " - " France?\"), return ONE sub-question matching the original. 
" + '- For SIMPLE single-hop questions ("What is the capital of ' + ' France?"), return ONE sub-question matching the original. ' " Over-decomposing single-hop questions hurts retrieval.\n" "- Maximum 5 sub-questions. Prefer fewer.\n" "- target_entity is what you expect to discover; may be empty.\n\n" @@ -161,8 +161,7 @@ async def decompose_query( ) except Exception as exc: logger.warning( - "Query decomposition failed for question, falling back to " - "single-query retrieval: %s", + "Query decomposition failed for question, falling back to single-query retrieval: %s", exc, ) return [question] diff --git a/benchmarks/failure_tagger.py b/benchmarks/failure_tagger.py index efac83a..5607a0d 100644 --- a/benchmarks/failure_tagger.py +++ b/benchmarks/failure_tagger.py @@ -57,9 +57,7 @@ class FailureTag(BaseModel): reasoning: str = Field( description="Brief justification before the category. One or two sentences." ) - category: FailureCategory = Field( - description="The single best-fit failure category." 
- ) + category: FailureCategory = Field(description="The single best-fit failure category.") _SYSTEM_PROMPT = """\ @@ -157,13 +155,9 @@ async def tag_failure( model: str | None = None, ) -> FailureTag | None: """Return a single FailureTag or None on LLM error.""" - messages = build_tag_prompt( - question=question, gold_answers=gold_answers, prediction=prediction - ) + messages = build_tag_prompt(question=question, gold_answers=gold_answers, prediction=prediction) try: - return await llm.extract( - messages, FailureTag, model=model, temperature=0.0 - ) + return await llm.extract(messages, FailureTag, model=model, temperature=0.0) except Exception as exc: logger.warning("Tag failed for question %r: %s", question[:60], exc) return None @@ -198,9 +192,7 @@ async def main_async(args: argparse.Namespace) -> int: if args.mode: preds = [p for p in preds if p.get("mode") == args.mode] - failures = [ - p for p in preds if is_em_failure(p["prediction"], p["gold_answers"]) - ] + failures = [p for p in preds if is_em_failure(p["prediction"], p["gold_answers"])] logger.info( "Loaded %d predictions; %d failures (mode=%s)", len(preds), @@ -245,11 +237,9 @@ async def _bounded(p: dict) -> None: total = len(tagged) or 1 print("\nFailure category histogram:") for cat, n in counts.most_common(): - print(f" {cat:25s} {n:4d} ({100*n/total:.1f}%)") + print(f" {cat:25s} {n:4d} ({100 * n / total:.1f}%)") - planner_addressable = counts.get("wrong-hop-target", 0) + counts.get( - "answer-type-mismatch", 0 - ) + planner_addressable = counts.get("wrong-hop-target", 0) + counts.get("answer-type-mismatch", 0) pct = 100 * planner_addressable / total print( f"\nPlanner-addressable (wrong-hop + answer-type-mismatch): " diff --git a/benchmarks/musique.py b/benchmarks/musique.py index 5082a52..7762a28 100644 --- a/benchmarks/musique.py +++ b/benchmarks/musique.py @@ -397,9 +397,7 @@ async def _amain(args: argparse.Namespace) -> int: summaries: dict[str, dict] = {} all_predictions: list = [] backends: 
list[MemoryBackend] = [] - plan_trace: list[dict] | None = ( - [] if args.trace_retrieval_plan else None - ) + plan_trace: list[dict] | None = [] if args.trace_retrieval_plan else None try: for mode in modes: @@ -419,9 +417,7 @@ async def _amain(args: argparse.Namespace) -> int: backends.append(backend) if mode == "baseline": - bm25 = await ingest_baseline( - questions, backend=backend, reingest=args.reingest - ) + bm25 = await ingest_baseline(questions, backend=backend, reingest=args.reingest) else: engram_llm = LiteLLMProvider( default_model=args.reader_model, @@ -453,19 +449,14 @@ async def _amain(args: argparse.Namespace) -> int: logger.info("rate-control stats after %s ingest: %s", mode, adaptive.stats) - use_graph_retrieval = ( - mode == "enriched" and args.graph_retrieval and args.build_graph - ) - use_kg_retrieval = ( - mode == "enriched" and args.kg_retrieval and args.build_graph - ) + use_graph_retrieval = mode == "enriched" and args.graph_retrieval and args.build_graph + use_kg_retrieval = mode == "enriched" and args.kg_retrieval and args.build_graph # Decomposition runs in enriched mode only — it's Engram's # multi-hop structural intervention, paired with the fact graph. use_decomposition = args.decompose and mode == "enriched" if use_graph_retrieval: logger.info( - "%s: graph-aware retrieval enabled " - "(entity extraction + 2-hop fact lookup)", + "%s: graph-aware retrieval enabled (entity extraction + 2-hop fact lookup)", mode, ) if use_kg_retrieval: diff --git a/benchmarks/reranker.py b/benchmarks/reranker.py index 66570c1..10a6eed 100644 --- a/benchmarks/reranker.py +++ b/benchmarks/reranker.py @@ -37,9 +37,7 @@ logger = logging.getLogger(__name__) -DEFAULT_RERANK_MODEL_ARN = ( - "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0" -) +DEFAULT_RERANK_MODEL_ARN = "arn:aws:bedrock:us-east-1::foundation-model/cohere.rerank-v3-5:0" """Bedrock model ARN for Cohere Rerank 3.5. 
Default in us-east-1; supports >100 languages and per-document text up to 4000 chars per source.""" @@ -80,8 +78,7 @@ def __init__( import boto3 except ImportError as exc: raise ImportError( - "CohereReranker (Bedrock) requires boto3. " - "Install with: pip install boto3" + "CohereReranker (Bedrock) requires boto3. Install with: pip install boto3" ) from exc import boto3 diff --git a/benchmarks/retrieval.py b/benchmarks/retrieval.py index ca2fcd6..90940cd 100644 --- a/benchmarks/retrieval.py +++ b/benchmarks/retrieval.py @@ -523,9 +523,7 @@ async def kg_hybrid_neighbors( # alpha=0.45 narrow seeded from Stage 1's top entities). Same call # site as the prior single-stage ppr_facts so the rest of the fusion # stays identical and the ablation isolates the PPR change. - ppr_ranked = ( - two_stage_ppr_facts(entities, backend, top_k=ppr_top_k) if entities else [] - ) + ppr_ranked = two_stage_ppr_facts(entities, backend, top_k=ppr_top_k) if entities else [] # Plan-aware bias: when a retrieval plan is provided, beam search # boosts edges whose predicate matches the planner's priority list. # PPR / triple_match stay untouched — only beam gets the bias, because @@ -651,9 +649,7 @@ async def kg_hybrid_neighbors( _PLAN_FILTER_REMOVAL_CAP = 0.30 -def _fact_object_matches_answer_type( - obj: str, answer_type: str, backend: MemoryBackend -) -> bool: +def _fact_object_matches_answer_type(obj: str, answer_type: str, backend: MemoryBackend) -> bool: """Heuristic: does this fact's object look like the planner's expected type? 
Returns True liberally — only returns False when we have positive @@ -729,9 +725,7 @@ def _plan_aware_rerank_query(query: str, plan: RetrievalPlan | None) -> str: if plan.expected_answer_type: hints.append(f"Expected answer type: {plan.expected_answer_type}") if plan.priority_predicates: - hints.append( - "Important relations: " + ", ".join(plan.priority_predicates[:5]) - ) + hints.append("Important relations: " + ", ".join(plan.priority_predicates[:5])) if not hints: return query return query + " | " + " | ".join(hints) diff --git a/benchmarks/runner.py b/benchmarks/runner.py index 7e16bfe..8c7af58 100644 --- a/benchmarks/runner.py +++ b/benchmarks/runner.py @@ -225,8 +225,7 @@ def _build_evidence_block( if passages: doc_lines = [ - f"Document {i + 1} ({title}): {text}" - for i, (title, text) in enumerate(passages) + f"Document {i + 1} ({title}): {text}" for i, (title, text) in enumerate(passages) ] parts.append("Relevant Documents:\n" + "\n".join(doc_lines)) @@ -459,17 +458,14 @@ async def answer_one( # round. Both IRCoT rounds share the same plan — it's a property of # the question, not of the retrieved evidence (that's IRCoT's job). retrieval_plan = None - if ( - use_retrieval_planner - and use_kg_retrieval - and planner_llm is not None - and query_entities - ): + if use_retrieval_planner and use_kg_retrieval and planner_llm is not None and query_entities: from engram.core.graph_view import build_query_graph_view from engram.dialogue.retrieval_planner import plan_retrieval view = build_query_graph_view( - question.question, query_entities, backend # type: ignore[arg-type] + question.question, + query_entities, + backend, # type: ignore[arg-type] ) # Capture the raw plan (pre-floor) for diagnostics; only the # post-floor result is used downstream. 
@@ -489,9 +485,7 @@ async def answer_one( "query_entities": list(query_entities), "unresolved_mentions": list(view.unresolved_mentions), "resolved_entities": [e.name for e in view.entities], - "plan": ( - retrieval_plan.model_dump() if retrieval_plan else None - ), + "plan": (retrieval_plan.model_dump() if retrieval_plan else None), "raw_plan": raw.model_dump() if raw is not None else None, } ) @@ -562,9 +556,7 @@ def _to_passages(chunks: Sequence[Chunk]) -> list[tuple[str, str]]: # ----- IRCoT Round 2: re-retrieve using the Round 1 thought ----- thought = _extract_thought_span(raw_r1) augmented_query = ( - f"{question.question}\n\nReasoning so far: {thought}" - if thought - else question.question + f"{question.question}\n\nReasoning so far: {thought}" if thought else question.question ) neighbors_r2, facts_r2 = await _retrieve(augmented_query) diff --git a/src/engram/backends/memory.py b/src/engram/backends/memory.py index 5bd4132..7aaa3fd 100644 --- a/src/engram/backends/memory.py +++ b/src/engram/backends/memory.py @@ -27,6 +27,7 @@ from __future__ import annotations import asyncio +import logging import struct import tempfile from datetime import datetime @@ -46,6 +47,8 @@ from engram.core.protocol import Embedder +logger = logging.getLogger(__name__) + _DEFAULT_MAP_SIZE = 1 << 30 # 1 GiB _DEFAULT_HNSW_M = 16 _DEFAULT_HNSW_EF_CONSTRUCTION = 200 diff --git a/src/engram/core/graph_view.py b/src/engram/core/graph_view.py index d13d657..1ef54c8 100644 --- a/src/engram/core/graph_view.py +++ b/src/engram/core/graph_view.py @@ -175,9 +175,7 @@ def _top_k_edges_by_confidence( return [es for _, es in summaries[:k]] -def _predicate_histogram( - backend: MemoryBackend, top_n: int -) -> list[tuple[str, int]]: +def _predicate_histogram(backend: MemoryBackend, top_n: int) -> list[tuple[str, int]]: """Top-N corpus predicates by fact count.""" counts: Counter[str] = Counter() for fact in backend.iter_facts(): diff --git a/src/engram/dialogue/prompts/retrieval_plan.py 
b/src/engram/dialogue/prompts/retrieval_plan.py index 371b66f..eb27dd1 100644 --- a/src/engram/dialogue/prompts/retrieval_plan.py +++ b/src/engram/dialogue/prompts/retrieval_plan.py @@ -83,10 +83,7 @@ class HopStep(BaseModel): ) ) reason: str = Field( - description=( - "One short sentence explaining why this hop is needed for " - "the question." - ) + description=("One short sentence explaining why this hop is needed for the question.") ) @@ -321,8 +318,7 @@ def _render_view(view: CompressedGraphView) -> str: lines: list[str] = [] for ent in view.entities: lines.append( - f"Entity: {ent.name} (type={ent.entity_type}, " - f"total_degree={ent.total_degree})" + f"Entity: {ent.name} (type={ent.entity_type}, total_degree={ent.total_degree})" ) for edge in ent.out_edges: lines.append( @@ -337,9 +333,7 @@ def _render_view(view: CompressedGraphView) -> str: entity_block = "\n".join(lines) if view.available_predicates: - pred_summary = ", ".join( - f"{p}({c})" for p, c in view.available_predicates - ) + pred_summary = ", ".join(f"{p}({c})" for p, c in view.available_predicates) else: pred_summary = "(empty corpus)" @@ -357,9 +351,7 @@ def _render_view(view: CompressedGraphView) -> str: ) -def build_retrieval_plan_prompt( - *, view: CompressedGraphView -) -> PromptCall: +def build_retrieval_plan_prompt(*, view: CompressedGraphView) -> PromptCall: """Build the planner LLM call for one query. 
The caller is responsible for materializing the diff --git a/tests/integration/test_kg_retrieval_with_plan.py b/tests/integration/test_kg_retrieval_with_plan.py index 2727dce..6be3bfc 100644 --- a/tests/integration/test_kg_retrieval_with_plan.py +++ b/tests/integration/test_kg_retrieval_with_plan.py @@ -18,10 +18,8 @@ import hashlib import struct from collections.abc import Sequence -from typing import Any import pytest - from benchmarks.retrieval import BM25Index, kg_hybrid_neighbors from engram.backends.memory import MemoryBackend diff --git a/tests/unit/test_core_graph_view.py b/tests/unit/test_core_graph_view.py index 4e43298..546a3df 100644 --- a/tests/unit/test_core_graph_view.py +++ b/tests/unit/test_core_graph_view.py @@ -8,17 +8,12 @@ from __future__ import annotations -from datetime import datetime from typing import Any -from unittest.mock import patch import networkx as nx from engram.core.graph_view import ( - DEFAULT_MAX_EDGES_PER_ENTITY, - DEFAULT_PREDICATE_TOP_N, CompressedGraphView, - EntityNeighborhood, build_query_graph_view, ) from engram.core.models import EntityRecord, Fact @@ -154,7 +149,9 @@ def test_build_view_predicate_histogram_top_n() -> None: _fact("X", "rel_c", "Z"), ] for f in facts: - _add_edge(g, f.subject, f.object, predicate=f.predicate, confidence=f.confidence, fact_id=f.id) + _add_edge( + g, f.subject, f.object, predicate=f.predicate, confidence=f.confidence, fact_id=f.id + ) backend = _StubBackend(g, facts=facts) view = build_query_graph_view("test", [], backend, predicate_top_n=2) # Top-2: rel_a (count 2) then either rel_b or rel_c (tied at 1). 
@@ -167,9 +164,7 @@ def test_build_view_unresolved_mentions_listed() -> None: g = nx.MultiDiGraph() g.add_node("Apple Inc.") backend = _StubBackend(g) - view = build_query_graph_view( - "Where", ["Apple Inc.", "Banana Republic", " ", ""], backend - ) + view = build_query_graph_view("Where", ["Apple Inc.", "Banana Republic", " ", ""], backend) assert [e.name for e in view.entities] == ["Apple Inc."] assert view.unresolved_mentions == ["Banana Republic"] diff --git a/tests/unit/test_dialogue_retrieval_planner.py b/tests/unit/test_dialogue_retrieval_planner.py index e7a6e52..9123a77 100644 --- a/tests/unit/test_dialogue_retrieval_planner.py +++ b/tests/unit/test_dialogue_retrieval_planner.py @@ -34,9 +34,7 @@ def __init__(self, responses: list[BaseModel | Exception]) -> None: self._responses = list(responses) self.calls: list[dict[str, Any]] = [] - async def complete( - self, messages: Sequence[dict[str, str]], **kwargs: Any - ) -> str: + async def complete(self, messages: Sequence[dict[str, str]], **kwargs: Any) -> str: del messages, kwargs return "" @@ -50,9 +48,7 @@ async def extract( **kwargs: Any, ) -> T: del kwargs - self.calls.append( - {"messages": list(messages), "schema": schema, "model": model} - ) + self.calls.append({"messages": list(messages), "schema": schema, "model": model}) response = self._responses.pop(0) if isinstance(response, Exception): raise response From ff9cff80e9281cd1c0c3f833ce3346e21491ddd2 Mon Sep 17 00:00:00 2001 From: Vedant Patel Date: Mon, 8 Jun 2026 12:19:46 -0700 Subject: [PATCH 3/3] fix(ci): auto-mark tests/integration/ with the integration marker CI runs pytest -m "not integration and not slow", but the new integration test test_kg_retrieval_with_plan.py wasn't marked, so CI collected it and the test failed importing BM25Index (which requires the `benchmarks` extra not installed in CI). 
Add tests/integration/conftest.py that scopes pytest_collection_modifyitems to items whose file path is under tests/integration/, then adds the integration marker. Path-scoped (not session-global) so tests outside this directory keep their existing markers. Verified: pytest tests/ -m "not integration and not slow" selects 363 unit tests (deselects 3 integration); all 366 still pass when run without the filter. Co-Authored-By: Claude Opus 4.7 --- tests/integration/conftest.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/integration/conftest.py diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 0000000..978a5f0 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,30 @@ +"""Auto-mark every test in ``tests/integration/`` with the ``integration`` marker. + +CI runs ``pytest -m "not integration and not slow"`` to skip tests +that touch external services or local disk-backed stores. Tests in +this directory typically build a real ``MemoryBackend`` on +``tmp_path`` and exercise the LMDB/hnswlib stack, so they all get +the marker by directory. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +_THIS_DIR = Path(__file__).resolve().parent + + +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + del config + for item in items: + try: + item_path = Path(str(item.fspath)).resolve() + except (AttributeError, OSError): + continue + try: + item_path.relative_to(_THIS_DIR) + except ValueError: + continue + item.add_marker(pytest.mark.integration)
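
Review note on the path check above: a minimal sketch of why the hook relies on `relative_to` raising `ValueError` instead of a string-prefix comparison. The helper name `is_under` and the example paths are hypothetical and not part of the patch. The sketch uses `PurePosixPath` so it does not touch the filesystem, whereas the conftest resolves real paths first.

```python
from pathlib import PurePosixPath


def is_under(path: str, root: str) -> bool:
    """Return True when ``path`` lies inside ``root`` (hypothetical helper).

    ``relative_to`` compares whole path components, so a sibling
    directory that merely shares a name prefix does not match.
    """
    try:
        PurePosixPath(path).relative_to(PurePosixPath(root))
    except ValueError:
        return False
    return True


root = "/repo/tests/integration"
inside = is_under("/repo/tests/integration/test_kg_retrieval_with_plan.py", root)
sibling = is_under("/repo/tests/integration_extra/test_other.py", root)
unit = is_under("/repo/tests/unit/test_core_graph_view.py", root)
```

A plain `str.startswith` check would also accept the `integration_extra` sibling, which is presumably the kind of over-marking the path-scoped approach is meant to avoid.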