Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions benchmarks/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ class QueryDecomposition(BaseModel):
"- A sub-question's answer must be a short factual span (name, "
" year, number, place) — not a description.\n"
"- Use ``depends_on`` to record the reasoning chain: if Q2's "
" phrasing requires Q1's answer (\"What year was [X] born?\" "
' phrasing requires Q1\'s answer ("What year was [X] born?" '
" where X is found by Q0), set depends_on=[0].\n"
"- For SIMPLE single-hop questions (\"What is the capital of "
" France?\"), return ONE sub-question matching the original. "
'- For SIMPLE single-hop questions ("What is the capital of '
' France?"), return ONE sub-question matching the original. '
" Over-decomposing single-hop questions hurts retrieval.\n"
"- Maximum 5 sub-questions. Prefer fewer.\n"
"- target_entity is what you expect to discover; may be empty.\n\n"
Expand Down Expand Up @@ -161,8 +161,7 @@ async def decompose_query(
)
except Exception as exc:
logger.warning(
"Query decomposition failed for question, falling back to "
"single-query retrieval: %s",
"Query decomposition failed for question, falling back to single-query retrieval: %s",
exc,
)
return [question]
Expand Down
273 changes: 273 additions & 0 deletions benchmarks/failure_tagger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
"""Phase 0 failure tagger for the graph-aware planner branch.

One LLM call per failed prediction classifies the failure mode. Used
once, before any planner implementation, to confirm there's a real
addressable surface for the planner. Output drives the decision gate:
if wrong-hop-target + answer-type-mismatch failures together account
for less than ~15% of all failures, the planner approach is killed
before any code is written.

Run as a script:

.venv/bin/python -m benchmarks.failure_tagger \\
--input predictions_ircot_synth_on.jsonl \\
--mode baseline \\
--output tagged_failures.jsonl

Failures are defined by exact-match (case + punct + article normalized);
the classifier only sees the cases that failed under that scorer.
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import re
import sys
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Literal

from pydantic import BaseModel, ConfigDict, Field

if TYPE_CHECKING:
from engram.core.protocol import LLMProvider


# Module-level logger; level and format are configured by logging.basicConfig()
# inside main_async when the module is run as a script.
logger = logging.getLogger(__name__)


# Closed set of failure labels the classifier may return. The same six names
# are spelled out in _SYSTEM_PROMPT, and main_async sums "wrong-hop-target"
# and "answer-type-mismatch" for the decision gate — keep all three in sync.
FailureCategory = Literal[
    "format-mismatch",
    "wrong-hop-target",
    "hallucination",
    "missing-info",
    "answer-type-mismatch",
    "other",
]


# Structured-output model passed to ``llm.extract`` in ``tag_failure``.
# NOTE(review): pydantic normally exposes the class docstring and the Field
# descriptions in the generated JSON schema, so that text is presumably part
# of what the LLM sees — confirm before rewording any of it.
class FailureTag(BaseModel):
    """LLM verdict for one failed prediction."""

    # Reject payloads that carry keys other than the two fields below.
    model_config = ConfigDict(extra="forbid")

    # Declared ahead of ``category``; the system prompt likewise asks the
    # model to fill the reasoning before choosing the category.
    reasoning: str = Field(
        description="Brief justification before the category. One or two sentences."
    )
    # Must be one of the FailureCategory literals.
    category: FailureCategory = Field(description="The single best-fit failure category.")


# System message for the failure classifier; build_tag_prompt sends it as the
# "system" turn of every request. The category names listed in the text must
# match the FailureCategory literal, because the reply is parsed into
# FailureTag.category. The backslash right after the opening quotes keeps the
# prompt from starting with an empty line.
_SYSTEM_PROMPT = """\
You are a benchmark failure-mode classifier. You see a question, the
gold answer(s) the benchmark expects, and a model's prediction that
scored 0 on exact-match. Your job: pick the SINGLE best-fit failure
category from the list below.

Categories (pick exactly one):

- format-mismatch:
The prediction CONTAINS the correct answer but adds extra words,
explanations, modifiers, or has different formatting. A human
reading the prediction would see the right answer is in there.
Examples:
Gold "blackmail" Pred "Japanese blackmail."
Gold "Amy Poehler" Pred "Amy Poehler (previously married to Will Arnett)."
Gold "Association for Computing Machinery" Pred "ACM (Association for Computing Machinery)"
Gold "22" Pred "22 times"

- wrong-hop-target:
Multi-hop question; the prediction is a wrong intermediate entity
or the wrong final entity from the reasoning chain. The model
landed on a nearby-but-wrong node. NOT a missing-info case.
Examples:
Q: "Who voices the character X?" Gold "Mr. Lawrence" Pred "Plankton" (character, not voice actor)
Q: "Who owns the record label of the performer of X?"
Gold "Warner Music Group" Pred "Shawn Lane owns Eye Reckon Records."
Q: "Spouse of the founder of X?" Gold "Real Spouse" Pred "The founder is Y" (gave intermediate)

- hallucination:
Prediction is a confident wrong answer that doesn't appear in any
reasonable evidence and isn't a hop-confusion. Often a plausible
but unrelated entity, or invented detail.
Examples:
Gold "Michael Buble" Pred "Josh Groban" (similar singer, no chain to support it)
Gold "1929" Pred "1492" (invented date)

- missing-info:
Prediction expresses inability to answer — "unknown",
"evidence does not provide", "no information found", "cannot be
determined", etc. The retrieval failed.
Examples:
Pred "The evidence does not provide information about ..."
Pred "Unknown"

- answer-type-mismatch:
Prediction is the wrong TYPE of thing the question asked for.
Question asked WHO; prediction gave a number/value. Question asked
HOW LONG; prediction gave a date. Question asked WHERE; prediction
gave a person.
Examples:
Q: "Who has the lowest batting average?" Gold "Bill Bergen" Pred ".170" (gave the value, not the person)
Q: "How long?" Gold "4 years" Pred "2005" (gave a date, not a duration)

- other:
Doesn't cleanly fit any above. Use sparingly.

Rules:

- Pick exactly one category. If two categories seem plausible, pick
the one that more strongly explains the failure.
- format-mismatch is the LOOSEST category — only use it when the
exact gold string (or a tiny variant) is clearly present in the
prediction. If the prediction has the wrong PRIMARY entity, prefer
wrong-hop-target even if some token overlap exists.
- wrong-hop-target is for multi-hop reasoning errors specifically.
Single-hop "wrong entity" with no chain reasoning is hallucination.
- Fill reasoning briefly before the category.
"""


def build_tag_prompt(
    *, question: str, gold_answers: Sequence[str], prediction: str
) -> list[dict[str, str]]:
    """Assemble the two chat messages for one classification request.

    The system turn is ``_SYSTEM_PROMPT``. The user turn shows the question
    verbatim, every gold answer rendered with ``repr`` and joined by
    `` | ``, and the prediction rendered with ``repr``.

    Args:
        question: The benchmark question text.
        gold_answers: Accepted answers for the question.
        prediction: The model output that failed exact-match.

    Returns:
        ``[system_message, user_message]`` as role/content dicts.
    """
    rendered_golds = " | ".join(map(repr, gold_answers))
    user_parts = (
        f"Question: {question}\n",
        f"Gold answer(s): {rendered_golds}\n",
        f"Prediction: {prediction!r}\n\n",
        "Classify this failure. Return a FailureTag.",
    )
    system_turn = {"role": "system", "content": _SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": "".join(user_parts)}
    return [system_turn, user_turn]


async def tag_failure(
    *,
    question: str,
    gold_answers: Sequence[str],
    prediction: str,
    llm: LLMProvider,
    model: str | None = None,
) -> FailureTag | None:
    """Classify one failed prediction with a single ``llm.extract`` call.

    Args:
        question: The benchmark question text.
        gold_answers: Accepted answers for the question.
        prediction: The model output that failed exact-match.
        llm: Provider whose ``extract`` coroutine is awaited with the
            messages, the ``FailureTag`` schema, ``model`` and
            ``temperature=0.0``.
        model: Model name forwarded to the provider as-is; may be ``None``.

    Returns:
        Whatever ``llm.extract`` returns, or ``None`` when it raises. The
        exception is logged as a warning together with the first 60
        characters of the question rather than propagated.
    """
    prompt = build_tag_prompt(
        question=question,
        gold_answers=gold_answers,
        prediction=prediction,
    )
    try:
        verdict = await llm.extract(prompt, FailureTag, model=model, temperature=0.0)
    except Exception as exc:
        logger.warning("Tag failed for question %r: %s", question[:60], exc)
        return None
    return verdict


# Anything that is neither a word character nor whitespace, plus "_" (which
# ``\w`` would otherwise keep). For ASCII input this blanks exactly the same
# characters as the former ``[^a-z0-9 ]`` once whitespace is collapsed, but
# non-ASCII letters and digits now survive. Previously every non-ASCII
# character was blanked, so e.g. "東京" and "大阪" both normalized to "" and
# were scored as an exact match.
_PUNCT_RE = re.compile(r"[^\w\s]|_")
# English articles as standalone words; applied after lowercasing.
_ARTICLE_RE = re.compile(r"\b(a|an|the)\b")


def _normalize(s: str) -> str:
    """Return *s* lowercased, without articles or punctuation, single-spaced.

    Articles and punctuation are replaced by a space (not deleted) and runs
    of whitespace are then collapsed, so "U.S." becomes "u s".
    """
    s = s.lower()
    s = _ARTICLE_RE.sub(" ", s)
    s = _PUNCT_RE.sub(" ", s)
    return " ".join(s.split())


def is_em_failure(prediction: str, gold_answers: Sequence[str]) -> bool:
    """Return True when the prediction exact-matches none of the gold answers.

    Both sides are compared after ``_normalize`` (case, punctuation and
    article insensitive). An empty ``gold_answers`` is always a failure.
    """
    pred_n = _normalize(prediction)
    return all(_normalize(gold) != pred_n for gold in gold_answers)


async def main_async(args: argparse.Namespace) -> int:
    """Tag every exact-match failure in a predictions file and print the gate.

    Steps:
      1. Read ``args.input`` as JSONL (blank lines are ignored) and, when
         ``args.mode`` is set, keep only records whose ``mode`` equals it.
      2. Keep the records for which ``is_em_failure`` is true.
      3. Classify each failure with ``tag_failure``, with at most
         ``args.concurrency`` calls in flight (clamped to at least 1).
      4. Write one JSON record per failure to ``args.output``, in the same
         order as the input file.
      5. Print a category histogram and the >=15% planner-addressable gate.

    Failures whose LLM call errored are still written with category
    ``"other"`` and reasoning ``"(tag failed)"``; their count is logged so a
    flaky provider cannot silently skew the gate percentage.

    Args:
        args: Namespace with ``input``, ``output``, ``mode``, ``model`` and
            ``concurrency`` as produced by ``_parse_args``.

    Returns:
        Process exit code; always 0.
    """
    from collections import Counter

    # Imported lazily, as in the original, so importing this module does not
    # require the provider package.
    from engram.llm.litellm_provider import LiteLLMProvider

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Context manager closes the handle; skipping blank lines keeps a trailing
    # newline or hand-edited file from crashing json.loads.
    with args.input.open(encoding="utf-8") as fh:
        preds = [json.loads(line) for line in fh if line.strip()]
    if args.mode:
        preds = [p for p in preds if p.get("mode") == args.mode]

    failures = [p for p in preds if is_em_failure(p["prediction"], p["gold_answers"])]
    logger.info(
        "Loaded %d predictions; %d failures (mode=%s)",
        len(preds),
        len(failures),
        args.mode or "all",
    )

    llm = LiteLLMProvider(default_model=args.model)
    # Semaphore(0) would block every task forever and a negative value raises,
    # so never go below one concurrent call.
    sem = asyncio.Semaphore(max(1, args.concurrency))
    untagged = 0

    async def _bounded(p: dict) -> dict:
        """Tag one failure under the concurrency limit and build its record."""
        nonlocal untagged
        async with sem:
            tag = await tag_failure(
                question=p["question"],
                gold_answers=p["gold_answers"],
                prediction=p["prediction"],
                llm=llm,
                model=args.model,
            )
        if tag is None:
            untagged += 1
        return {
            "question_id": p["question_id"],
            "question": p["question"],
            "gold_answers": p["gold_answers"],
            "prediction": p["prediction"],
            "category": tag.category if tag else "other",
            "reasoning": tag.reasoning if tag else "(tag failed)",
        }

    # gather() returns results in argument order, so the output file follows
    # the input order instead of task completion order.
    tagged: list[dict] = list(await asyncio.gather(*(_bounded(p) for p in failures)))

    if untagged:
        logger.warning(
            "%d of %d failures could not be tagged and were recorded as 'other'",
            untagged,
            len(tagged),
        )

    with args.output.open("w", encoding="utf-8") as fh:
        for record in tagged:
            fh.write(json.dumps(record) + "\n")
    logger.info("Wrote %d tagged records to %s", len(tagged), args.output)

    counts = Counter(r["category"] for r in tagged)
    # Avoid division by zero when there are no failures at all.
    total = len(tagged) or 1
    print("\nFailure category histogram:")
    for cat, n in counts.most_common():
        print(f"  {cat:25s} {n:4d}  ({100 * n / total:.1f}%)")

    planner_addressable = counts.get("wrong-hop-target", 0) + counts.get("answer-type-mismatch", 0)
    pct = 100 * planner_addressable / total
    print(
        f"\nPlanner-addressable (wrong-hop + answer-type-mismatch): "
        f"{planner_addressable}/{total} = {pct:.1f}%"
    )
    print(f"Gate (>=15%): {'PASS' if pct >= 15 else 'FAIL'}")
    return 0


def _parse_args() -> argparse.Namespace:
    """Parse the tagger's command-line flags from ``sys.argv``.

    ``--input`` and ``--output`` are required paths. ``--mode`` defaults to
    ``None`` (no filtering), ``--model`` to ``"openai/gpt-4o-mini"`` and
    ``--concurrency`` to 8. The module docstring is used as the description.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--input", required=True, type=Path)
    parser.add_argument(
        "--mode",
        default=None,
        type=str,
        help="Filter predictions by mode ('baseline' or 'enriched'). Default: all.",
    )
    parser.add_argument("--output", required=True, type=Path)
    parser.add_argument("--model", default="openai/gpt-4o-mini", type=str)
    parser.add_argument("--concurrency", default=8, type=int)
    namespace = parser.parse_args()
    return namespace


def main() -> int:
    """Synchronous entry point: parse the CLI flags and run ``main_async``.

    Returns:
        The integer returned by ``main_async``, for use as the exit code.
    """
    exit_code = asyncio.run(main_async(_parse_args()))
    return exit_code


# Script entry point: propagate main()'s return value as the process exit code.
if __name__ == "__main__":
    sys.exit(main())
53 changes: 42 additions & 11 deletions benchmarks/musique.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,29 @@ def _parse_args() -> argparse.Namespace:
"layer — cleaner ablation."
),
)
parser.add_argument(
"--retrieval-planner",
action="store_true",
default=False,
help=(
"Graph-aware retrieval planner: one LLM call per question "
"produces an explicit hop sequence + expected answer type + "
"priority predicates that bias kg_hybrid_neighbors (beam "
"predicate boost, fact-object type filter, plan-aware "
"rerank query). Requires --kg-retrieval. Bypasses cleanly "
"when planner abstains (confidence < 0.5)."
),
)
parser.add_argument(
"--trace-retrieval-plan",
type=Path,
default=None,
help=(
"Diagnostic JSONL dump of (question, query_entities, view, "
"plan) per question when --retrieval-planner is on. Used "
"for manual review of plan quality during smoke runs."
),
)
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose logging.")
return parser.parse_args()

Expand Down Expand Up @@ -374,6 +397,7 @@ async def _amain(args: argparse.Namespace) -> int:
summaries: dict[str, dict] = {}
all_predictions: list = []
backends: list[MemoryBackend] = []
plan_trace: list[dict] | None = [] if args.trace_retrieval_plan else None

try:
for mode in modes:
Expand All @@ -393,9 +417,7 @@ async def _amain(args: argparse.Namespace) -> int:
backends.append(backend)

if mode == "baseline":
bm25 = await ingest_baseline(
questions, backend=backend, reingest=args.reingest
)
bm25 = await ingest_baseline(questions, backend=backend, reingest=args.reingest)
else:
engram_llm = LiteLLMProvider(
default_model=args.reader_model,
Expand Down Expand Up @@ -427,19 +449,14 @@ async def _amain(args: argparse.Namespace) -> int:

logger.info("rate-control stats after %s ingest: %s", mode, adaptive.stats)

use_graph_retrieval = (
mode == "enriched" and args.graph_retrieval and args.build_graph
)
use_kg_retrieval = (
mode == "enriched" and args.kg_retrieval and args.build_graph
)
use_graph_retrieval = mode == "enriched" and args.graph_retrieval and args.build_graph
use_kg_retrieval = mode == "enriched" and args.kg_retrieval and args.build_graph
# Decomposition runs in enriched mode only — it's Engram's
# multi-hop structural intervention, paired with the fact graph.
use_decomposition = args.decompose and mode == "enriched"
if use_graph_retrieval:
logger.info(
"%s: graph-aware retrieval enabled "
"(entity extraction + 2-hop fact lookup)",
"%s: graph-aware retrieval enabled (entity extraction + 2-hop fact lookup)",
mode,
)
if use_kg_retrieval:
Expand Down Expand Up @@ -477,6 +494,10 @@ async def _amain(args: argparse.Namespace) -> int:
decomposer_llm=reader if use_decomposition else None,
decomposer_model=args.reader_model if use_decomposition else None,
reranker=cohere_reranker,
use_retrieval_planner=args.retrieval_planner,
planner_llm=reader if args.retrieval_planner else None,
planner_model=args.reader_model if args.retrieval_planner else None,
plan_trace=plan_trace,
)
summary = score_predictions(predictions_for_scoring(predictions))
summaries[mode] = summary.as_dict()
Expand Down Expand Up @@ -513,6 +534,16 @@ async def _amain(args: argparse.Namespace) -> int:
+ "\n"
)
logger.info("Wrote %d predictions to %s", len(all_predictions), args.output)

if args.trace_retrieval_plan and plan_trace is not None:
with args.trace_retrieval_plan.open("w") as fh:
for record in plan_trace:
fh.write(json.dumps(record) + "\n")
logger.info(
"Wrote %d retrieval-plan trace records to %s",
len(plan_trace),
args.trace_retrieval_plan,
)
finally:
for backend in backends:
backend.close()
Expand Down
Loading
Loading