diff --git a/.gitignore b/.gitignore index 211f988..c752342 100644 --- a/.gitignore +++ b/.gitignore @@ -18,11 +18,12 @@ env/ *.swp *.swo -# Testing +# Testing / tooling caches .pytest_cache/ .coverage htmlcov/ .mypy_cache/ +.ruff_cache/ # Models (downloaded at runtime) models/ @@ -39,3 +40,10 @@ Thumbs.db # Environment .env + +# Claude Code local settings +.claude/ +.odin/ +scratchpad/ +# Recording build artifact (GIF is committed; cast is regenerated) +assets/*.cast diff --git a/README.md b/README.md index 322fc26..29395ae 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,25 @@ No document ingestion. No chunking. No agents. No database. Works identically on ![License MIT](https://img.shields.io/badge/license-MIT-green) ![Version](https://img.shields.io/badge/version-0.1.0-orange) +## Stop hallucinations before they cascade + +In a multi-step agent, each step's output feeds the next — a single fabricated +figure propagates straight into the final answer. `verify_step()` is a circuit +breaker that halts the chain the moment a claim stops being grounded in the +evidence: + +![Agent circuit-breaker demo](assets/circuit_breaker.gif) + +```python +from athena_verify import verify_step + +step = verify_step(claim=reasoning_step, evidence=retrieved_chunks, threshold=0.5) +if step.action == "halt": + raise RuntimeError(f"Ungrounded claim blocked (trust={step.trust_score:.2f})") +``` + +Run it yourself: [`examples/agent_circuit_breaker.py`](examples/agent_circuit_breaker.py). + ## How It Works ``` @@ -81,21 +100,46 @@ pip install "athena-verify[all]" Evaluated on 100 synthetic cases across 6 hallucination categories (legal, medical, technical, general). Real-world benchmarks against RAGTruth and HaluEval are in progress — download instructions are in [`benchmarks/RESULTS.md`](benchmarks/RESULTS.md). 
-### Per-Category Performance (NLI-only, synthetic, nli-deberta-v3-base) +### Hallucination Detection (NLI-only, synthetic, nli-deberta-v3-base) + +Each row is the per-category F1 for *catching hallucinations*. The faithful-text +row is intentionally excluded here — it contains no hallucinations, so its F1 is +undefined; we report its false-positive rate separately below, which is the +number that actually matters for clean text. | Category | Precision | Recall | **F1** | |----------|-----------|--------|--------| -| **Fabricated claims** | 100% | 97% | **98.6%** ✓ | +| **Fabricated claims** | 100% | 96% | **97.9%** ✓ | | **Out-of-context** | 100% | 97% | **98.3%** ✓ | | **Subtle contradictions** | 100% | 97% | **98.3%** ✓ | -| **Number substitutions** | 79% | 96% | **86.8%** | -| **Partial support** | 78% | 95% | **85.7%** | -| **Faithful statements** | 0% | 0% | **0.0%** ✗ | -| **Overall** | 87% | 97% | **91.3%** (synthetic) | - -### Where We Lose - -Athena has a **high false positive rate on truly faithful statements** (31% of genuinely faithful sentences are incorrectly flagged). This is a known NLI-model limitation — conservative thresholds bias toward catching hallucinations at the cost of flagging clean sentences. +| **Partial support** | 95% | 91% | **93.0%** | +| **Number substitutions** | 82% | 96% | **88.5%** | +| **Overall** | 95% | 96% | **95.0%** (synthetic) | + +**False-positive rate on faithful text: 4.6%** (4 of 87 genuinely-supported +sentences flagged) on the base model, **3.4%** on the large model — down from 17% +before calibration. Latency: **p50 22.5 ms, p95 34.5 ms** per verification on the +base model. Numbers are reproducible with `python benchmarks/run_full_eval.py`. + +### How false positives are kept low + +Standalone NLI scores many faithful paraphrases as "neutral" (entailment ≈ 0) +even when the claim is fully supported. 
Athena recovers these without letting +hallucinations through, using three guarded signals: + +- **Anaphora windowing** — a sentence opening with a referent ("This cap…", "It + also…") is scored together with its predecessor, restoring the antecedent. +- **Contradiction-aware rescue** — a not-entailed claim is only rescued when the + most on-topic context unit does *not* contradict it, so reversals and subtle + contradictions stay flagged. +- **Numeric gate** — rescue requires every number in the claim to appear in the + context, so number-substitution hallucinations ("$5M" vs a "$2M" context) are + never rescued. + +The remaining false positives are heavily-paraphrased claims with little lexical +overlap (e.g. "olive oil is drizzled on top"); enable the optional LLM-judge +escalation (`use_llm_judge=True`) for those. Athena still biases toward catching +hallucinations over passing every clean sentence — treat it as a guardrail. **LettuceDetect beats athena on span-level F1** on real-world benchmarks (LettuceDetect 79.2% F1 on annotated spans vs. athena's unvalidated real-world score). Athena wins on latency bounds, provider-neutrality, offline execution, and the spans-in-library integration story — not raw F1. diff --git a/assets/circuit_breaker.gif b/assets/circuit_breaker.gif new file mode 100644 index 0000000..5394c8a Binary files /dev/null and b/assets/circuit_breaker.gif differ diff --git a/assets/circuit_breaker.tape b/assets/circuit_breaker.tape new file mode 100644 index 0000000..666c25e --- /dev/null +++ b/assets/circuit_breaker.tape @@ -0,0 +1,25 @@ +# VHS tape — renders the agent circuit-breaker demo to a GIF. +# Regenerate with: vhs assets/circuit_breaker.tape +Output assets/circuit_breaker.gif + +Set Shell "bash" +Set FontSize 18 +Set Width 1180 +Set Height 760 +Set Padding 24 +Set Theme "Catppuccin Mocha" + +# Activate the project venv off-screen so the visible command is just `python …`. 
+Hide +Type "source .venv/bin/activate" Enter +Type "clear" Enter +Show + +Type "python examples/agent_circuit_breaker.py" +Sleep 600ms +Enter + +# First call loads the NLI model (~3s), then one step streams every ~0.5s. +Sleep 8s +# Linger on the circuit-breaker result. +Sleep 2s diff --git a/athena_verify/__init__.py b/athena_verify/__init__.py index 95e11c7..2a97613 100644 --- a/athena_verify/__init__.py +++ b/athena_verify/__init__.py @@ -26,7 +26,14 @@ verify_stream, ) from athena_verify.llm_judge import LLMClient -from athena_verify.models import Chunk, SentenceScore, StepResult, StreamingResult, SupportingSpan, VerificationResult +from athena_verify.models import ( + Chunk, + SentenceScore, + StepResult, + StreamingResult, + SupportingSpan, + VerificationResult, +) __all__ = [ "verify", diff --git a/athena_verify/calibration.py b/athena_verify/calibration.py index 73b3acc..89558d6 100644 --- a/athena_verify/calibration.py +++ b/athena_verify/calibration.py @@ -20,6 +20,16 @@ PARTIAL_THRESHOLD = 0.50 UNSUPPORTED_THRESHOLD = 0.30 +# Grounding-rescue thresholds. Cross-encoder NLI frequently scores a faithful +# paraphrase as "neutral" (entailment ~0) even though the claim is fully +# supported. When the claim is *not* contradicted, is heavily lexically +# grounded, and all its numbers appear in the context, we lift it out of the +# unsupported band — recovering false positives without passing contradictions +# or number swaps (which fail the contradiction / numeric guards). +RESCUE_CONTRADICTION_CEILING = 0.45 +RESCUE_CONTAINMENT_FLOOR = 0.50 +RESCUE_TRUST = 0.55 + def compute_trust_score( nli_score: float, @@ -57,6 +67,43 @@ def compute_trust_score( return min(1.0, max(0.0, trust)) +def apply_grounding_rescue( + trust: float, + *, + entailment: float, + contradiction: float, + containment: float, + numeric_ok: bool, +) -> float: + """Lift trust for neutral-but-grounded paraphrases NLI scores too low. 
+ + Only ever raises the score, and only when all guards pass: + - the claim is not contradicted by its most on-topic context unit, + - it is not already strongly entailed (nothing to rescue), + - its content words are heavily present in the context, and + - every number in it appears in the context. + + Args: + trust: The ensemble trust score before rescue. + entailment: Max NLI entailment probability for the sentence. + contradiction: NLI contradiction probability of the sentence's on-topic context unit. + containment: Fraction of content words found in the context. + numeric_ok: Whether all numbers in the sentence appear in the context. + + Returns: + The (possibly raised) trust score. + """ + if contradiction >= RESCUE_CONTRADICTION_CEILING: + return trust + if entailment >= SUPPORTED_THRESHOLD: + return trust + if not numeric_ok: + return trust + if containment >= RESCUE_CONTAINMENT_FLOOR: + return max(trust, RESCUE_TRUST) + return trust + + def classify_support(trust_score: float) -> str: """Classify a sentence's support status based on trust score. 
diff --git a/athena_verify/cli.py b/athena_verify/cli.py index 3a83107..eb51372 100644 --- a/athena_verify/cli.py +++ b/athena_verify/cli.py @@ -8,6 +8,7 @@ from pathlib import Path from athena_verify import verify +from athena_verify.models import VerificationResult def color_score(score: float) -> str: @@ -30,7 +31,7 @@ def format_trust_score(score: float, width: int = 6) -> str: return f"{color_score(score)}{score:.2f}{reset_color()}" -def print_table(result) -> None: +def print_table(result: VerificationResult) -> None: """Print colored sentence-by-sentence trust score table.""" print() print("Verification Results") diff --git a/athena_verify/core.py b/athena_verify/core.py index b6012a9..d2599ea 100644 --- a/athena_verify/core.py +++ b/athena_verify/core.py @@ -7,7 +7,9 @@ from __future__ import annotations +import asyncio import os +import re import time from collections.abc import AsyncIterator from typing import Any @@ -15,6 +17,7 @@ import structlog from athena_verify.calibration import ( + apply_grounding_rescue, classify_support, compute_overall_trust, compute_trust_score, @@ -28,12 +31,205 @@ SupportingSpan, VerificationResult, ) -from athena_verify.nli import batch_compute_entailment, batch_compute_entailment_async -from athena_verify.overlap import best_overlap_score +from athena_verify.nli import batch_compute_nli +from athena_verify.overlap import best_overlap_score, containment_score, numeric_consistency from athena_verify.parser import sentence_buffer, split_sentences logger = structlog.get_logger() +# Span-level entailment threshold: a context unit must clear this to be +# reported as a supporting span for a sentence. +_SPAN_ENTAILMENT_THRESHOLD = 0.5 + +# Return shape of _ground_sentences: (entailment scores, contradiction scores, +# per-sentence span-unit scores, span unit texts, span unit (chunk_idx, start, +# end) locations). 
+_GroundResult = tuple[ + list[float], list[float], list[list[float]], list[str], list[tuple[int, int, int]] +] + +# Leading tokens that signal a sentence depends on its predecessor for meaning +# (anaphora / discourse continuation). When an answer sentence starts with one +# of these, NLI scored on the sentence in isolation collapses to ~0 even when +# the claim is fully grounded, because the referent ("it", "this cap") is gone. +# We prepend the previous sentence to restore the antecedent before scoring. +_ANAPHORA_TOKENS = frozenset( + { + "it", + "its", + "it's", + "this", + "that", + "these", + "those", + "they", + "them", + "their", + "theirs", + "he", + "she", + "his", + "her", + "hers", + "such", + "also", + "additionally", + "moreover", + "furthermore", + "however", + "therefore", + "thus", + "then", + "there", + "both", + "neither", + "either", + } +) + + +def _starts_with_anaphor(sentence: str) -> bool: + """True if a sentence opens with a pronoun/discourse marker needing context.""" + stripped = sentence.strip() + if not stripped: + return False + first = stripped.split(maxsplit=1)[0].lower().strip(",.;:\"'()") + return first in _ANAPHORA_TOKENS + + +def _word_tokens(text: str) -> list[str]: + """Lowercase content tokens (length > 2) for lightweight topic matching.""" + return [w for w in re.findall(r"[a-z0-9]+", text.lower()) if len(w) > 2] + + +def _build_context_units( + chunk_texts: list[str], +) -> tuple[list[str], list[tuple[int, int, int]], list[bool]]: + """Expand context chunks into NLI premise candidates. + + Returns parallel lists of: + - unit text, + - (chunk_idx, char_start, char_end) location into the original chunk, + - is_span_unit flag (True for sentence-level units usable as precise + supporting spans, False for whole-chunk fallback premises). 
+ + Each chunk contributes its individual sentences (focused premises that + avoid the long-premise "neutral" bias) plus, when it has more than one + sentence, the full chunk text (so facts spread across several sentences + are still entailed). NLI takes the max over all candidates, so adding the + whole-chunk premise can only raise a faithful sentence's score. + """ + units: list[str] = [] + locations: list[tuple[int, int, int]] = [] + is_span_unit: list[bool] = [] + for chunk_idx, chunk in enumerate(chunk_texts): + sub = split_sentences(chunk) or [chunk] + for unit in sub: + char_start = chunk.find(unit) + if char_start == -1: + char_start = 0 + units.append(unit) + locations.append((chunk_idx, char_start, char_start + len(unit))) + is_span_unit.append(True) + if len(sub) > 1: + units.append(chunk) + locations.append((chunk_idx, 0, len(chunk))) + is_span_unit.append(False) + return units, locations, is_span_unit + + +def _ground_sentences( + sentences: list[str], + chunk_texts: list[str], + nli_model: str, +) -> _GroundResult: + """Score how well each answer sentence is grounded in the context. + + Splits context into focused premise candidates, applies anaphora windowing + to each hypothesis, and returns: + - entail_scores: best entailment per sentence over all premise candidates, + - contra_scores: contradiction per sentence, read from its most on-topic unit, + - span_scores: per-sentence entailment over the span-eligible units only, + - span_units / span_locations: the span-eligible units these align to. 
+ """ + units, locations, is_span_unit = _build_context_units(chunk_texts) + span_units = [u for u, keep in zip(units, is_span_unit, strict=True) if keep] + span_locations = [ + loc for loc, keep in zip(locations, is_span_unit, strict=True) if keep + ] + + if not units or not sentences: + empty_spans: list[list[float]] = [[] for _ in sentences] + zeros = [0.0] * len(sentences) + return zeros, list(zeros), empty_spans, span_units, span_locations + + # Anaphora windowing: prepend the previous sentence when the current one + # opens with a referent, so the NLI hypothesis carries its antecedent. + hypotheses: list[str] = [] + for i, sentence in enumerate(sentences): + if i > 0 and _starts_with_anaphor(sentence): + hypotheses.append(f"{sentences[i - 1]} {sentence}") + else: + hypotheses.append(sentence) + + nli_pairs = [(unit, hyp) for hyp in hypotheses for unit in units] + flat = batch_compute_nli(nli_pairs, model_name=nli_model) + + # Token sets per unit, for picking the on-topic unit for the contradiction + # signal (see below). + unit_tokens = [set(_word_tokens(u)) for u in units] + + entail_scores: list[float] = [] + contra_scores: list[float] = [] + span_scores: list[list[float]] = [] + n_units = len(units) + for i in range(len(sentences)): + rows = flat[i * n_units : (i + 1) * n_units] + entails = [e for e, _ in rows] + entail_scores.append(max(entails) if entails else 0.0) + # Contradiction is read from the unit most lexically on-topic with the + # claim, not the global max. An unrelated context unit frequently + # "contradicts" a claim it has nothing to do with (negations, sibling + # clauses), which would veto a faithful sentence; the on-topic unit + # still fires for genuine contradictions (number swaps, reversals) + # because those reuse the same vocabulary. 
+ hyp_tokens = set(_word_tokens(hypotheses[i])) + if rows and hyp_tokens: + relevance = [len(hyp_tokens & ut) for ut in unit_tokens] + topic = max(range(len(rows)), key=lambda k: (relevance[k], entails[k])) + contra_scores.append(rows[topic][1]) + else: + contra_scores.append(0.0) + span_scores.append( + [e for e, keep in zip(entails, is_span_unit, strict=True) if keep] + ) + return entail_scores, contra_scores, span_scores, span_units, span_locations + + +def _trust_and_status( + *, + entailment: float, + contradiction: float, + overlap: float, + sentence: str, + context_text: str, + llm: float | None, + weights: dict[str, float] | None, +) -> tuple[float, str]: + """Combine signals into a trust score, apply the grounding rescue, classify. + + Shared by every verify entry point so they score identically. + """ + trust = compute_trust_score(entailment, overlap, llm, weights) + trust = apply_grounding_rescue( + trust, + entailment=entailment, + contradiction=contradiction, + containment=containment_score(sentence, context_text), + numeric_ok=numeric_consistency(sentence, context_text), + ) + return trust, classify_support(trust) + def verify_step( claim: str, @@ -134,42 +330,15 @@ def verify( ) # --- NLI scoring --- - # NLI models work best with short, focused premises. When a context chunk - # contains multiple sentences of info, the model classifies entailed - # hypotheses as "neutral" because the premise has information BEYOND the - # hypothesis. Fix: split chunks into individual sentences for NLI scoring. 
- _SPAN_ENTAILMENT_THRESHOLD = 0.5 - - context_units: list[str] = [] - # (chunk_idx, char_start, char_end) into the original chunk text - unit_locations: list[tuple[int, int, int]] = [] - for chunk_idx, chunk in enumerate(chunk_texts): - sub = split_sentences(chunk) - if not sub: - sub = [chunk] - for unit in sub: - char_start = chunk.find(unit) - if char_start == -1: - char_start = 0 - context_units.append(unit) - unit_locations.append((chunk_idx, char_start, char_start + len(unit))) - - nli_pairs = [(unit, sentence) for sentence in sentences for unit in context_units] - nli_scores_flat = batch_compute_entailment(nli_pairs, model_name=nli_model) - nli_scores: list[float] = [] - nli_best_chunks: list[str | None] = [] - per_sentence_unit_scores: list[list[float]] = [] - for i in range(len(sentences)): - start = i * len(context_units) - unit_scores = nli_scores_flat[start : start + len(context_units)] - per_sentence_unit_scores.append(unit_scores) - if unit_scores: - best_idx = unit_scores.index(max(unit_scores)) - nli_scores.append(unit_scores[best_idx]) - nli_best_chunks.append(context_units[best_idx]) - else: - nli_scores.append(0.0) - nli_best_chunks.append(None) + # NLI works best on short, focused premises with hypotheses that carry + # their own referents. _ground_sentences handles both: it scores each + # sentence against individual context sentences plus the whole chunk + # (max wins), and prepends the prior sentence when a hypothesis opens with + # an anaphor. See athena_verify.core helpers for details. 
+ entail_scores, contra_scores, per_sentence_unit_scores, span_units, span_locations = ( + _ground_sentences(sentences, chunk_texts, nli_model) + ) + context_text = " ".join(chunk_texts) # --- Lexical overlap scoring --- overlap_results = [best_overlap_score(s, chunk_texts) for s in sentences] @@ -196,25 +365,35 @@ def verify( judge_start = time.time() judge_results = batch_judge_sentences(sentences, combined_context, question, llm_client) llm_scores = [score for score, _ in judge_results] - llm_judge_avg_ms = (time.time() - judge_start) * 1000 / len(sentences) if sentences else 2000.0 + llm_judge_avg_ms = ( + (time.time() - judge_start) * 1000 / len(sentences) if sentences else 2000.0 + ) # --- Build per-sentence results --- sentence_scores: list[SentenceScore] = [] for i, sentence in enumerate(sentences): - nli = nli_scores[i] if i < len(nli_scores) else 0.0 + nli = entail_scores[i] if i < len(entail_scores) else 0.0 + contra = contra_scores[i] if i < len(contra_scores) else 0.0 overlap, best_chunk = overlap_results[i] llm = llm_scores[i] if i < len(llm_scores) else None - trust = compute_trust_score(nli, overlap, llm, weights) - status = classify_support(trust) + trust, status = _trust_and_status( + entailment=nli, + contradiction=contra, + overlap=overlap, + sentence=sentence, + context_text=context_text, + llm=llm, + weights=weights, + ) unit_scores_i = per_sentence_unit_scores[i] if i < len(per_sentence_unit_scores) else [] supporting_spans = [ SupportingSpan( - chunk_idx=unit_locations[j][0], - start=unit_locations[j][1], - end=unit_locations[j][2], - text=context_units[j], + chunk_idx=span_locations[j][0], + start=span_locations[j][1], + end=span_locations[j][2], + text=span_units[j], ) for j, score in enumerate(unit_scores_i) if score >= _SPAN_ENTAILMENT_THRESHOLD @@ -335,8 +514,18 @@ async def verify_async( ) # --- NLI scoring (async) --- - nli_pairs = [(" ".join(chunk_texts), sentence) for sentence in sentences] - nli_scores = await 
batch_compute_entailment_async(nli_pairs, model_name=nli_model) + # Offload the same grounding logic used by verify() to a thread so we get + # per-unit + whole-chunk premises and anaphora windowing here too, instead + # of the old concatenate-all-chunks premise that silently truncated at the + # model's token limit. + ( + entail_scores, + contra_scores, + per_sentence_unit_scores, + span_units, + span_locations, + ) = await asyncio.to_thread(_ground_sentences, sentences, chunk_texts, nli_model) + context_text = " ".join(chunk_texts) # --- Lexical overlap scoring --- overlap_results = [best_overlap_score(s, chunk_texts) for s in sentences] @@ -363,17 +552,39 @@ async def verify_async( judge_start = time.time() judge_results = batch_judge_sentences(sentences, combined_context, question, llm_client) llm_scores = [score for score, _ in judge_results] - llm_judge_avg_ms = (time.time() - judge_start) * 1000 / len(sentences) if sentences else 2000.0 + llm_judge_avg_ms = ( + (time.time() - judge_start) * 1000 / len(sentences) if sentences else 2000.0 + ) # --- Build per-sentence results --- sentence_scores: list[SentenceScore] = [] for i, sentence in enumerate(sentences): - nli = nli_scores[i] if i < len(nli_scores) else 0.0 + nli = entail_scores[i] if i < len(entail_scores) else 0.0 + contra = contra_scores[i] if i < len(contra_scores) else 0.0 overlap, best_chunk = overlap_results[i] llm = llm_scores[i] if i < len(llm_scores) else None - trust = compute_trust_score(nli, overlap, llm, weights) - status = classify_support(trust) + trust, status = _trust_and_status( + entailment=nli, + contradiction=contra, + overlap=overlap, + sentence=sentence, + context_text=context_text, + llm=llm, + weights=weights, + ) + + unit_scores_i = per_sentence_unit_scores[i] if i < len(per_sentence_unit_scores) else [] + supporting_spans = [ + SupportingSpan( + chunk_idx=span_locations[j][0], + start=span_locations[j][1], + end=span_locations[j][2], + text=span_units[j], + ) + for j, score in 
enumerate(unit_scores_i) + if score >= _SPAN_ENTAILMENT_THRESHOLD + ] sentence_scores.append( SentenceScore( @@ -385,6 +596,7 @@ async def verify_async( trust_score=trust, support_status=status, best_matching_context=best_chunk, + supporting_spans=supporting_spans, ) ) @@ -502,28 +714,13 @@ def verify_batch( all_chunks: list[list[Chunk]] = [] all_sentences: list[list[str]] = [] - all_nli_pairs: list[tuple[str, str]] = [] - pair_offsets: list[int] = [] for q_idx in range(len(questions_list)): chunks = [Chunk.from_input(c) for c in contexts_list[q_idx]] - chunk_texts = [c.content for c in chunks] - combined_context = " ".join(chunk_texts) sentences = split_sentences(answers_list[q_idx]) - all_chunks.append(chunks) all_sentences.append(sentences) - offset = len(all_nli_pairs) - pair_offsets.append(offset) - - for sentence in sentences: - all_nli_pairs.append((combined_context, sentence)) - - nli_scores_all = batch_compute_entailment( - all_nli_pairs, model_name=nli_model, batch_size=batch_size - ) - for q_idx in range(len(questions_list)): chunks = all_chunks[q_idx] chunk_texts = [c.content for c in chunks] @@ -544,24 +741,46 @@ def verify_batch( ) continue - offset = pair_offsets[q_idx] + entail_scores, contra_scores, per_sentence_unit_scores, span_units, span_locations = ( + _ground_sentences(sentences, chunk_texts, nli_model) + ) + context_text = " ".join(chunk_texts) sentence_scores: list[SentenceScore] = [] llm_scores: list[float | None] = [None] * len(sentences) if use_llm_judge and llm_client is not None: - combined_context = " ".join(chunk_texts) judge_results = batch_judge_sentences( - sentences, combined_context, questions_list[q_idx], llm_client + sentences, context_text, questions_list[q_idx], llm_client ) llm_scores = [score for score, _ in judge_results] for i, sentence in enumerate(sentences): - nli = nli_scores_all[offset + i] if (offset + i) < len(nli_scores_all) else 0.0 + nli = entail_scores[i] if i < len(entail_scores) else 0.0 + contra = 
contra_scores[i] if i < len(contra_scores) else 0.0 overlap, best_chunk = best_overlap_score(sentence, chunk_texts) llm = llm_scores[i] if i < len(llm_scores) else None - trust = compute_trust_score(nli, overlap, llm, weights) - status = classify_support(trust) + trust, status = _trust_and_status( + entailment=nli, + contradiction=contra, + overlap=overlap, + sentence=sentence, + context_text=context_text, + llm=llm, + weights=weights, + ) + + unit_scores_i = per_sentence_unit_scores[i] if i < len(per_sentence_unit_scores) else [] + supporting_spans = [ + SupportingSpan( + chunk_idx=span_locations[j][0], + start=span_locations[j][1], + end=span_locations[j][2], + text=span_units[j], + ) + for j, score in enumerate(unit_scores_i) + if score >= _SPAN_ENTAILMENT_THRESHOLD + ] sentence_scores.append( SentenceScore( @@ -573,6 +792,7 @@ def verify_batch( trust_score=trust, support_status=status, best_matching_context=best_chunk, + supporting_spans=supporting_spans, ) ) @@ -666,27 +886,27 @@ async def verify_batch_async( all_chunks: list[list[Chunk]] = [] all_sentences: list[list[str]] = [] - all_nli_pairs: list[tuple[str, str]] = [] - pair_offsets: list[int] = [] for q_idx in range(len(questions_list)): chunks = [Chunk.from_input(c) for c in contexts_list[q_idx]] - chunk_texts = [c.content for c in chunks] - combined_context = " ".join(chunk_texts) sentences = split_sentences(answers_list[q_idx]) - all_chunks.append(chunks) all_sentences.append(sentences) - offset = len(all_nli_pairs) - pair_offsets.append(offset) - - for sentence in sentences: - all_nli_pairs.append((combined_context, sentence)) + # Ground every question with the shared per-unit + windowing logic, offloaded + # to a single worker thread so we don't block the event loop. 
+ def _ground_all() -> list[_GroundResult]: + out: list[_GroundResult] = [] + for q_idx in range(len(questions_list)): + sents = all_sentences[q_idx] + if not sents: + out.append(([], [], [], [], [])) + continue + texts = [c.content for c in all_chunks[q_idx]] + out.append(_ground_sentences(sents, texts, nli_model)) + return out - nli_scores_all = await batch_compute_entailment_async( - all_nli_pairs, model_name=nli_model, batch_size=batch_size - ) + grounding = await asyncio.to_thread(_ground_all) for q_idx in range(len(questions_list)): try: @@ -709,24 +929,52 @@ async def verify_batch_async( ) continue - offset = pair_offsets[q_idx] + ( + entail_scores, + contra_scores, + per_sentence_unit_scores, + span_units, + span_locations, + ) = grounding[q_idx] + context_text = " ".join(chunk_texts) sentence_scores: list[SentenceScore] = [] llm_scores: list[float | None] = [None] * len(sentences) if use_llm_judge and llm_client is not None: - combined_context = " ".join(chunk_texts) judge_results = batch_judge_sentences( - sentences, combined_context, questions_list[q_idx], llm_client + sentences, context_text, questions_list[q_idx], llm_client ) llm_scores = [score for score, _ in judge_results] for i, sentence in enumerate(sentences): - nli = nli_scores_all[offset + i] if (offset + i) < len(nli_scores_all) else 0.0 + nli = entail_scores[i] if i < len(entail_scores) else 0.0 + contra = contra_scores[i] if i < len(contra_scores) else 0.0 overlap, best_chunk = best_overlap_score(sentence, chunk_texts) llm = llm_scores[i] if i < len(llm_scores) else None - trust = compute_trust_score(nli, overlap, llm, weights) - status = classify_support(trust) + trust, status = _trust_and_status( + entailment=nli, + contradiction=contra, + overlap=overlap, + sentence=sentence, + context_text=context_text, + llm=llm, + weights=weights, + ) + + unit_scores_i = ( + per_sentence_unit_scores[i] if i < len(per_sentence_unit_scores) else [] + ) + supporting_spans = [ + SupportingSpan( + 
chunk_idx=span_locations[j][0], + start=span_locations[j][1], + end=span_locations[j][2], + text=span_units[j], + ) + for j, score in enumerate(unit_scores_i) + if score >= _SPAN_ENTAILMENT_THRESHOLD + ] sentence_scores.append( SentenceScore( @@ -738,6 +986,7 @@ async def verify_batch_async( trust_score=trust, support_status=status, best_matching_context=best_chunk, + supporting_spans=supporting_spans, ) ) @@ -897,19 +1146,49 @@ async def verify_stream( chunks = [Chunk.from_input(c) for c in context] chunk_texts = [c.content for c in chunks] - combined_context = " ".join(chunk_texts) sentence_scores: list[SentenceScore] = [] idx = 0 async for sentence in sentence_buffer(answer_stream): - nli_pairs = [(combined_context, sentence)] - nli_scores = await batch_compute_entailment_async(nli_pairs, model_name=nli_model) - nli = nli_scores[0] if nli_scores else 0.0 + # Ground each completed sentence with the same per-unit + whole-chunk + # premises as verify(). When the sentence opens with an anaphor, include + # the previous one so the referent is present, then keep the current + # sentence's score (the last entry). 
+ prev_text = sentence_scores[-1].text if sentence_scores else None + if prev_text and _starts_with_anaphor(sentence): + ground_input = [prev_text, sentence] + else: + ground_input = [sentence] + + entail_scores, contra_scores, span_scores, span_units, span_locations = ( + await asyncio.to_thread(_ground_sentences, ground_input, chunk_texts, nli_model) + ) + nli = entail_scores[-1] if entail_scores else 0.0 + contra = contra_scores[-1] if contra_scores else 0.0 + unit_scores_i = span_scores[-1] if span_scores else [] overlap, best_chunk = best_overlap_score(sentence, chunk_texts) - trust = compute_trust_score(nli, overlap, None, weights) - status = classify_support(trust) + trust, status = _trust_and_status( + entailment=nli, + contradiction=contra, + overlap=overlap, + sentence=sentence, + context_text=" ".join(chunk_texts), + llm=None, + weights=weights, + ) + + supporting_spans = [ + SupportingSpan( + chunk_idx=span_locations[j][0], + start=span_locations[j][1], + end=span_locations[j][2], + text=span_units[j], + ) + for j, score in enumerate(unit_scores_i) + if score >= _SPAN_ENTAILMENT_THRESHOLD + ] score = SentenceScore( text=sentence, @@ -919,6 +1198,7 @@ async def verify_stream( trust_score=trust, support_status=status, best_matching_context=best_chunk, + supporting_spans=supporting_spans, ) sentence_scores.append(score) idx += 1 diff --git a/athena_verify/integrations/crewai.py b/athena_verify/integrations/crewai.py index 8540298..34311c1 100644 --- a/athena_verify/integrations/crewai.py +++ b/athena_verify/integrations/crewai.py @@ -5,8 +5,6 @@ from __future__ import annotations -from typing import Any - from athena_verify.core import verify_step try: @@ -14,11 +12,11 @@ _CREWAI_AVAILABLE = True except ImportError: - BaseTool = object # type: ignore[misc,assignment] + BaseTool = object _CREWAI_AVAILABLE = False -class AthenaVerifyTool(BaseTool): +class AthenaVerifyTool(BaseTool): # type: ignore[misc] """CrewAI tool for verifying factual claims against 
evidence. Verify whether a claim is supported by the given evidence. diff --git a/athena_verify/integrations/langgraph.py b/athena_verify/integrations/langgraph.py index f32adc1..e722592 100644 --- a/athena_verify/integrations/langgraph.py +++ b/athena_verify/integrations/langgraph.py @@ -5,7 +5,8 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from athena_verify.core import verify_step from athena_verify.models import StepResult diff --git a/athena_verify/nli.py b/athena_verify/nli.py index c360bd9..ef9d033 100644 --- a/athena_verify/nli.py +++ b/athena_verify/nli.py @@ -59,17 +59,73 @@ def get_nli_model(model_name: str = "cross-encoder/nli-deberta-v3-base") -> Any: return CrossEncoder(resolved) -def _softmax_entailment(logits: Any) -> float: - """Convert 3-class NLI logits to entailment probability using softmax. +@lru_cache(maxsize=32) +def entailment_index(model_name: str) -> int | None: + """Resolve the entailment class index from the model's label map. + + Different NLI checkpoints order their classes differently — e.g. the + cross-encoder/nli-* family uses ``0=contradiction, 1=entailment, + 2=neutral`` while many MoritzLaurer/DeBERTa checkpoints use + ``0=entailment``. Hardcoding the index silently scores the wrong class + on non-default models, which reads as a flood of false positives. + + Returns the index of the class whose label contains "entail", or + ``None`` when there is no label map (e.g. single-logit Vectara HHEM) or + no label mentions "entail". 
+ """ + model = get_nli_model(model_name) + config = getattr(getattr(model, "model", None), "config", None) or getattr( + model, "config", None + ) + id2label = getattr(config, "id2label", None) + if not isinstance(id2label, dict): + return None + for idx, label in id2label.items(): + if "entail" in str(label).lower(): + return int(idx) + return None + + +@lru_cache(maxsize=32) +def contradiction_index(model_name: str) -> int | None: + """Resolve the contradiction class index from the model's label map. - Standard NLI label ordering: 0=contradiction, 1=entailment, 2=neutral. - We return the probability of class 1 (entailment). + Mirrors :func:`entailment_index`. Used to tell a real contradiction + ("the cap is $5M" vs context "$2M") apart from a merely neutral / not- + directly-stated paraphrase, so the two can be handled differently. """ + model = get_nli_model(model_name) + config = getattr(getattr(model, "model", None), "config", None) or getattr( + model, "config", None + ) + id2label = getattr(config, "id2label", None) + if not isinstance(id2label, dict): + return None + for idx, label in id2label.items(): + if "contradict" in str(label).lower(): + return int(idx) + return None + + +def _softmax(logits: Any) -> list[float]: + """Numerically stable softmax over a logit row.""" row = list(logits) max_val = max(row) exp_vals = [math.exp(v - max_val) for v in row] total = sum(exp_vals) - return exp_vals[1] / total + return [v / total for v in exp_vals] + + +def _softmax_entailment(logits: Any, entail_idx: int) -> float: + """Convert NLI logits to entailment probability via softmax. + + Args: + logits: Per-class logits for one premise/hypothesis pair. + entail_idx: Index of the entailment class for this model. + """ + probs = _softmax(logits) + idx = entail_idx if 0 <= entail_idx < len(probs) else 1 + return probs[idx] def compute_entailment_score( @@ -88,32 +144,40 @@ def compute_entailment_score( Probability of entailment (0.0-1.0). 
""" model = get_nli_model(model_name) + entail_idx = entailment_index(model_name) scores = model.predict([[premise, hypothesis]]) - if hasattr(scores[0], "__len__") and len(scores[0]) >= 3: - return _softmax_entailment(scores[0]) - return float(scores[0]) if not hasattr(scores[0], "__len__") else float(scores[0][0]) + row = scores[0] + if hasattr(row, "__len__") and len(row) >= 3: + return _softmax_entailment(row, entail_idx if entail_idx is not None else 1) + # Single-logit consistency model (e.g. HHEM): score is already a probability. + return float(row) if not hasattr(row, "__len__") else float(row[0]) -def batch_compute_entailment( +def batch_compute_nli( pairs: list[tuple[str, str]], model_name: str = "cross-encoder/nli-deberta-v3-base", batch_size: int = 32, -) -> list[float]: - """Batch compute entailment scores for multiple premise-hypothesis pairs. +) -> list[tuple[float, float]]: + """Batch compute (entailment, contradiction) probabilities per pair. + + Returns a list of ``(entailment_prob, contradiction_prob)`` tuples. For + single-logit consistency models (e.g. HHEM) the contradiction probability + is taken as ``1 - entailment``. Args: pairs: List of (premise, hypothesis) tuples. model_name: Cross-encoder model to use (or alias like "lightweight"). batch_size: Number of pairs to process at once. - - Returns: - List of entailment probabilities. 
""" if not pairs: return [] model = get_nli_model(model_name) - results: list[float] = [] + e_idx = entailment_index(model_name) + c_idx = contradiction_index(model_name) + entail_idx = e_idx if e_idx is not None else 1 + contra_idx = c_idx if c_idx is not None else 0 + results: list[tuple[float, float]] = [] for start in range(0, len(pairs), batch_size): batch = pairs[start : start + batch_size] @@ -121,13 +185,38 @@ def batch_compute_entailment( for score_row in scores: if hasattr(score_row, "__len__") and len(score_row) >= 3: - results.append(_softmax_entailment(score_row)) + probs = _softmax(score_row) + entail = probs[entail_idx] if 0 <= entail_idx < len(probs) else probs[1] + contra = probs[contra_idx] if 0 <= contra_idx < len(probs) else probs[0] + results.append((entail, contra)) else: - results.append(float(score_row)) + entail = float(score_row) + results.append((entail, 1.0 - entail)) return results +def batch_compute_entailment( + pairs: list[tuple[str, str]], + model_name: str = "cross-encoder/nli-deberta-v3-base", + batch_size: int = 32, +) -> list[float]: + """Batch compute entailment scores for multiple premise-hypothesis pairs. + + Thin wrapper over :func:`batch_compute_nli` that returns only the + entailment probabilities. + + Args: + pairs: List of (premise, hypothesis) tuples. + model_name: Cross-encoder model to use (or alias like "lightweight"). + batch_size: Number of pairs to process at once. + + Returns: + List of entailment probabilities. 
+ """ + return [entail for entail, _ in batch_compute_nli(pairs, model_name, batch_size)] + + async def batch_compute_entailment_async( pairs: list[tuple[str, str]], model_name: str = "cross-encoder/nli-deberta-v3-base", diff --git a/athena_verify/overlap.py b/athena_verify/overlap.py index 3bc4f59..17a2f5c 100644 --- a/athena_verify/overlap.py +++ b/athena_verify/overlap.py @@ -9,6 +9,60 @@ from __future__ import annotations +import re + +_WORD_RE = re.compile(r"[a-z0-9]+") +# A number: digits with optional thousands separators / decimal part. +_NUM_RE = re.compile(r"\d[\d,]*(?:\.\d+)?") + +# Function words carry no grounding signal; excluding them keeps containment +# from being inflated by shared "the/of/is" tokens. +_STOPWORDS = frozenset( + { + "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", + "of", "to", "in", "on", "for", "and", "or", "but", "with", "at", "by", + "as", "that", "this", "these", "those", "it", "its", "from", "into", + "than", "then", "also", "such", "which", "their", "they", "them", + "has", "have", "had", "will", "shall", "may", "can", "any", "all", + "not", "no", "only", "other", "more", "most", "some", "each", "both", + } +) + + +def _normalize_number(token: str) -> str: + """Strip thousands separators so '1,200' and '1200' compare equal.""" + return token.replace(",", "") + + +def containment_score(sentence: str, context_text: str) -> float: + """Fraction of a sentence's content words that appear in the context. + + Unlike symmetric token F1 (which is penalised by long context), this is a + precision-style measure of how much of the *claim* is lexically grounded. + It is the signal used to rescue faithful paraphrases that standalone NLI + scores as neutral. 
+ """ + ctx_tokens = set(_WORD_RE.findall(context_text.lower())) + words = [ + w for w in _WORD_RE.findall(sentence.lower()) if len(w) > 2 and w not in _STOPWORDS + ] + if not words: + return 0.0 + return sum(1 for w in words if w in ctx_tokens) / len(words) + + +def numeric_consistency(sentence: str, context_text: str) -> bool: + """True if every number in the sentence also appears in the context. + + Comma-insensitive. Returns True when the sentence contains no numbers. + This is the guard that keeps number-substitution hallucinations + ("the cap is $5M" against a $2M context) from being rescued by lexical + containment, since the swapped figure will be absent from the context. + """ + ctx_nums = {_normalize_number(n) for n in _NUM_RE.findall(context_text)} + sent_nums = [_normalize_number(n) for n in _NUM_RE.findall(sentence)] + return all(n in ctx_nums for n in sent_nums) + def token_f1(text1: str, text2: str) -> float: """Compute token-level F1 overlap between two texts. diff --git a/athena_verify/parser.py b/athena_verify/parser.py index c97ae02..493bb61 100644 --- a/athena_verify/parser.py +++ b/athena_verify/parser.py @@ -74,11 +74,33 @@ def split_sentences(text: str) -> list[str]: return _split_sentences_regex(text) +# Common abbreviations that end in a period but do not end a sentence. Kept +# lowercase and without the trailing period for matching. Covers titles, legal +# and academic citation forms, and Latin/measurement shorthands — the domains +# (legal, medical, technical) athena targets, where a wrong split fragments a +# claim and shows up as a false positive. 
+_ABBREVIATIONS = frozenset( + { + "dr", "mr", "mrs", "ms", "prof", "rev", "hon", "sr", "jr", "st", + "vs", "etc", "al", "cf", "eg", "ie", "ca", "approx", + "inc", "ltd", "co", "corp", "llc", "plc", + "no", "nos", "fig", "figs", "sec", "secs", "art", "para", "pp", "vol", + "ch", "ed", "eds", "rep", "dept", "est", "min", "max", + "jan", "feb", "mar", "apr", "jun", "jul", "aug", "sep", "sept", + "oct", "nov", "dec", + # multi-dot forms, matched after stripping internal periods + "us", "uk", "un", "eu", "am", "pm", "phd", "md", "ba", "ma", "bs", + } +) + + def _split_sentences_regex(text: str) -> list[str]: - """Split text into sentences using regex (fallback). + """Split text into sentences using regex (fallback for when NLTK is absent). - Used when NLTK is not available. Handles common English sentence - boundaries but may split incorrectly on abbreviations like "Dr. Smith". + Abbreviation-aware: a candidate boundary is rejected when the token before + the period is a known abbreviation (``Dr.``, ``Inc.``), a single-letter + initial, or a dotted acronym (``U.S.``), so claims in legal/medical text + aren't fragmented. Args: text: The answer text to split. @@ -86,18 +108,28 @@ def _split_sentences_regex(text: str) -> list[str]: Returns: List of non-empty sentence strings. """ - # Normalize whitespace text = text.strip() + if not text: + return [] - # Split on sentence-ending punctuation followed by space or end-of-string. - # Handles: period, exclamation, question mark. - sentences = re.split(r"(?<=[.!?])\s+(?=[A-Z])", text) - - # Filter empty strings and strip whitespace - result = [] - for s in sentences: - s = s.strip() - if s: - result.append(s) + result: list[str] = [] + start = 0 + # A boundary is sentence-ending punctuation, an optional closing quote/paren, + # then whitespace, followed by something that looks like a new sentence. 
+ for m in re.finditer(r"[.!?]+[\"')\]]?\s+(?=[A-Z0-9\"'(])", text): + preceding = text[start : m.start()] + last_token = preceding.split()[-1] if preceding.split() else "" + # Normalize: drop internal/trailing periods so "U.S" -> "us", "Dr" -> "dr". + normalized = last_token.replace(".", "").strip(",;:\"'()").lower() + if normalized in _ABBREVIATIONS or len(normalized) == 1: + continue + sentence = text[start : m.end()].strip() + if sentence: + result.append(sentence) + start = m.end() + + tail = text[start:].strip() + if tail: + result.append(tail) return result diff --git a/benchmarks/RESULTS.md b/benchmarks/RESULTS.md index 515b779..baa80eb 100644 --- a/benchmarks/RESULTS.md +++ b/benchmarks/RESULTS.md @@ -7,7 +7,7 @@ All results are **real, reproducible, and measured on this codebase**. No projec - **Machine**: Apple M1 Max, 64 GB RAM, macOS - **Python**: 3.13 - **Seed**: 42 (deterministic) -- **Date**: 2026-04-19 +- **Date**: 2026-06-27 ## Real Dataset Acquisition @@ -72,55 +72,58 @@ python benchmarks/run_faithbench.py --synthetic \ Six hallucination categories across legal, medical, technical, and general domains. 
-### NLI-Only Mode (nli-deberta-v3-base, ~17ms p50) +### NLI-Only Mode (nli-deberta-v3-base, ~23ms p50) | Category | Precision | Recall | F1 | |----------|-----------|--------|----| -| Fabricated claims | 100.0% | 98.7% | **99.3%** | -| Out-of-context | 100.0% | 93.3% | **96.6%** | -| Number substitutions | 79.3% | 95.8% | **86.8%** | -| Subtle contradictions | 100.0% | 100.0% | **100.0%** | -| Partial support | 75.9% | 100.0% | **86.3%** | -| **Overall** | **86.6%** | **96.7%** | **91.3%** | - -- **False positive rate on faithful sentences**: 17% (15/89 sentences incorrectly flagged) -- **Latency p50**: ~17ms per verification call -- **Latency p95**: ~26ms per verification call +| Fabricated claims | 100.0% | 96.0% | **97.9%** | +| Out-of-context | 100.0% | 96.7% | **98.3%** | +| Number substitutions | 82.1% | 95.8% | **88.5%** | +| Subtle contradictions | 100.0% | 96.7% | **98.3%** | +| Partial support | 95.2% | 90.9% | **93.0%** | +| **Overall** | **94.5%** | **95.6%** | **95.0%** | + +- **False positive rate on faithful sentences**: 4.6% (4/87 sentences incorrectly flagged) +- **Latency p50**: ~22.5ms per verification call +- **Latency p95**: ~34.5ms per verification call - **Cost**: $0 (local model, no API calls) -### NLI-Only Mode (nli-deberta-v3-large, ~37ms p50) +### NLI-Only Mode (nli-deberta-v3-large, ~53ms p50) | Category | Precision | Recall | F1 | |----------|-----------|--------|----| | Fabricated claims | 100.0% | 98.7% | **99.3%** | -| Out-of-context | 100.0% | 93.3% | **96.6%** | -| Number substitutions | 79.3% | 95.8% | **86.8%** | -| Subtle contradictions | 100.0% | 100.0% | **100.0%** | -| Partial support | 75.9% | 100.0% | **86.3%** | -| **Overall** | **86.3%** | **97.8%** | **91.7%** | +| Out-of-context | 100.0% | 93.3% | **96.5%** | +| Number substitutions | 82.1% | 95.8% | **88.5%** | +| Subtle contradictions | 100.0% | 93.3% | **96.5%** | +| Partial support | 90.9% | 90.9% | **90.9%** | +| **Overall** | **94.5%** | **95.6%** | **95.0%** | 
-- **Latency p50**: ~37ms per verification call -- **Latency p95**: ~53ms per verification call +- **False positive rate on faithful sentences**: 3.4% (3/87 sentences incorrectly flagged) +- **Latency p50**: ~53.4ms per verification call +- **Latency p95**: ~89.2ms per verification call ### How It Works -Context chunks are split into individual sentences before NLI scoring. Each answer sentence is scored against every context sentence, and the maximum entailment score is used. This avoids the "neutral trap" where NLI models classify a hypothesis as neutral when the premise contains information beyond the hypothesis. +Context chunks are split into individual sentences before NLI scoring. Each answer sentence is scored against every context sentence **and the full chunk**, and the maximum entailment score is used. This avoids the "neutral trap" where NLI models classify a hypothesis as neutral when the premise contains information beyond the hypothesis, while still catching facts spread across several context sentences. Answer sentences that open with an anaphor ("This cap…", "It also…") are joined with the previous sentence before scoring so the referent is preserved. + +NLI still scores many faithful paraphrases as neutral (entailment ≈ 0). A **guarded rescue** recovers them without admitting hallucinations: a not-entailed sentence is lifted to *partially supported* only when (a) the most on-topic context unit does **not** contradict it — read from the 3-class NLI distribution, picked by lexical relevance so an unrelated unit can't veto a faithful claim — and (b) every number in the sentence appears in the context, and (c) most of its content words appear in the context. This is what cut the faithful false-positive rate from 17% to 4.6% (base) / 3.4% (large) while holding hallucination recall at ~96%. 
### The Right Tool for the Right Job | Use case | Recommended mode | Why | |----------|-----------------|-----| -| General RAG QA | NLI-only (base) | Catches 91%+ of hallucinations in 17ms | -| High-stakes docs | NLI-only (large) | Slightly better recall at 37ms | -| Real-time chat | NLI-only (base) | 17ms latency is production-ready | +| General RAG QA | NLI-only (base) | 95.0% F1 in ~23ms | +| High-stakes docs | NLI-only (large) | Lower false-positive rate at ~53ms | +| Real-time chat | NLI-only (base) | ~23ms latency is production-ready | | Maximum accuracy | NLI + LLM-judge | LLM catches paraphrases NLI misses | ## Latency Comparison | Mode | p50 | p95 | Notes | |------|-----|-----|-------| -| NLI only (base) | ~17ms | ~26ms | Fastest, 91.3% F1 | -| NLI only (large) | ~37ms | ~53ms | Slightly better, 91.7% F1 | +| NLI only (base) | ~23ms | ~35ms | Fastest, 95.0% F1, 4.6% FP on faithful | +| NLI only (large) | ~53ms | ~89ms | Lower FP (3.4%), 95.0% F1 | | LLM judge (local) | ~7.4s | ~10s | Per sentence, local gemma-4-31b-it | | GPT-4 judge (API) | ~2s | ~5s | Per sentence, network round-trip | diff --git a/examples/agent_circuit_breaker.py b/examples/agent_circuit_breaker.py index a1454a7..e3eefb9 100644 --- a/examples/agent_circuit_breaker.py +++ b/examples/agent_circuit_breaker.py @@ -1,46 +1,92 @@ -"""Agent circuit breaker demo using verify_step(). +"""Agent circuit breaker — stop a hallucination before it cascades. -Demonstrates a 3-step agent where intermediate claims are verified. -Steps 1-2 pass verification; step 3 contains a fabricated claim that -halts the chain before it can propagate. +A multi-step agent passes each step's output to the next. When step 3 +hallucinates a figure that isn't in the source, an unguarded agent would +carry it into its final BUY recommendation. `verify_step()` is a circuit +breaker: it halts the chain the moment a step stops being grounded.
+ +Run with: python examples/agent_circuit_breaker.py """ -from athena_verify import verify_step +from __future__ import annotations + +import contextlib +import io +import logging +import os +import sys + +# Keep the demo output clean — silence the ML stack's load reports / progress +# bars before anything imports it. +os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error") +os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1") +os.environ.setdefault("TQDM_DISABLE", "1") + +import structlog # noqa: E402 + +structlog.configure(wrapper_class=structlog.make_filtering_bound_logger(logging.WARNING)) + +from athena_verify import verify_step # noqa: E402 (configure logging first) + +# ANSI colors (no dependency); disabled when output isn't a TTY. +_TTY = sys.stdout.isatty() + +def c(text: str, code: str) -> str: + return f"\033[{code}m{text}\033[0m" if _TTY else text + + +DIM, BOLD = "2", "1" +GREEN, RED, YELLOW = "32", "31", "33" + +# What the agent is allowed to rely on — retrieved from a 10-Q filing. EVIDENCE = [ - "The Eiffel Tower is located in Paris, France.", - "It was constructed between 1887 and 1889.", - "The tower stands 330 metres tall.", + "Acme Corp reported Q3 revenue of $2.4 billion, up 12% year over year.", + "Operating income was $530 million for the quarter.", + "Net profit margin for Q3 was 22%, in line with the prior quarter.", + "The company reaffirmed full-year guidance and declared a $0.15 dividend.", ] +# Each reasoning step the agent produces, fed forward to the next. 
STEPS = [ - ("The Eiffel Tower is in Paris, France.", True), # supported - ("The tower is 330 metres tall.", True), # supported - ("The Eiffel Tower was built in 1650.", False), # fabricated — halts + "Acme's Q3 revenue was $2.4 billion, up 12% year over year.", + "Operating income for the quarter came in at $530 million.", + "Net margin expanded sharply to 35%, a major profitability breakout.", + "Given the margin breakout, raise the price target and recommend BUY.", ] -def run_agent() -> None: - """Run a 3-step agent with circuit breaker logic.""" - print("Agent Circuit Breaker Demo") - print("=" * 60) - print(f"Evidence: {len(EVIDENCE)} documents\n") +def main() -> None: + print(c("\n Financial research agent — 4 reasoning steps", BOLD)) + print(c(f" Grounding on {len(EVIDENCE)} passages from Acme's 10-Q", DIM)) + + # Load the NLI model up front (and quietly) so the steps below stream + # without a pause or stray library output. + print(c(" loading grounding model…\n", DIM), flush=True) + with contextlib.redirect_stderr(io.StringIO()): + verify_step(claim="warm up", evidence=["warm up"]) - for i, (claim, _expected_pass) in enumerate(STEPS, 1): - result = verify_step(claim=claim, evidence=EVIDENCE) - status = "✓ PASS" if result.passed else "✗ FAIL" - print(f"Step {i}: {status}") - print(f" Claim: '{claim}'") - print(f" Trust Score: {result.trust_score:.3f}") - print(f" Action: {result.action}") + for i, claim in enumerate(STEPS, 1): + step = verify_step(claim=claim, evidence=EVIDENCE, threshold=0.5) + badge = ( + c(" PASS ", f"{BOLD};{GREEN}") + if step.passed + else c(" HALT ", f"{BOLD};{RED}") + ) + print(f" {badge} step {i} {c(f'trust={step.trust_score:.2f}', DIM)}") + print(f" {claim}") - if result.action == "halt": - print("\n [CIRCUIT BREAKER] Halting agent — fabricated claim detected.") - break + if step.action == "halt": + print() + print(c(" ⛔ circuit breaker tripped — ungrounded claim blocked", f"{BOLD};{RED}")) + print(c(" the agent never reached step 
4, so the BUY call built", YELLOW)) + print(c(" on a hallucinated 35% margin was never made.", YELLOW)) + print(c("\n Source says: net profit margin for Q3 was 22%.", DIM)) + return print() - else: - print("All steps passed. Agent completed successfully.") + + print(c(" ✓ all steps grounded — recommendation cleared to proceed", GREEN)) if __name__ == "__main__": - run_agent() + main() diff --git a/examples/langchain_example.py b/examples/langchain_example.py index 4789c0b..23daf53 100644 --- a/examples/langchain_example.py +++ b/examples/langchain_example.py @@ -9,9 +9,9 @@ Run with: python examples/langchain_example.py """ -from langchain_core.documents import Document -from langchain.retrievers import BaseRetriever from langchain.chains import RetrievalQA +from langchain.retrievers import BaseRetriever +from langchain_core.documents import Document from langchain_core.llms.fake import FakeListLLM from athena_verify.integrations.langchain import VerifyingLLM diff --git a/pyproject.toml b/pyproject.toml index 55e82e2..618f33d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,15 @@ dependencies = [ "pydantic>=2.9.0", "structlog>=24.0.0", "nltk>=3.8.0", + # The NLI cross-encoder is the core of the library, so ship it by default — + # `pip install athena-verify` should run verify() with no extra steps. + "sentence-transformers>=3.0.0", ] [project.optional-dependencies] -nli = [ - "sentence-transformers>=3.0.0", -] +# Retained for backwards compatibility; sentence-transformers is now a core +# dependency, so `athena-verify[nli]` resolves to the base install. 
+nli = [] llm = [ "openai>=1.0.0", "anthropic>=0.30.0", @@ -95,6 +98,7 @@ module = [ "openai.*", "anthropic.*", "nltk.*", + "crewai.*", ] ignore_missing_imports = true diff --git a/tests/test_new_features.py b/tests/test_new_features.py index 0957ca3..ef3dd10 100644 --- a/tests/test_new_features.py +++ b/tests/test_new_features.py @@ -45,10 +45,12 @@ def predict(self, pairs): @pytest.fixture(autouse=True) def _mock_nli(): - with ( - patch("athena_verify.core.batch_compute_entailment", return_value=[0.85]), - patch("athena_verify.core.batch_compute_entailment_async", return_value=[0.85]), - ): + # _ground_sentences scores one pair per (context unit, sentence); return a + # constant (entailment, contradiction) for each so answers score uniformly. + def fake_nli(pairs, *args, **kwargs): + return [(0.85, 0.05)] * len(pairs) + + with patch("athena_verify.core.batch_compute_nli", side_effect=fake_nli): yield diff --git a/tests/test_nli.py b/tests/test_nli.py index 5226312..c509a47 100644 --- a/tests/test_nli.py +++ b/tests/test_nli.py @@ -9,6 +9,7 @@ import pytest +from athena_verify import nli as nli_module from athena_verify.nli import ( NLI_MODEL_ALIASES, batch_compute_entailment, @@ -69,8 +70,23 @@ def predict(self, pairs): @pytest.fixture def mock_model_cache(): - with patch("athena_verify.nli._nli_cache", {}) as cache: - yield cache + """Swap the cached model loader for a controllable dict of mock models. + + get_nli_model and entailment_index are both @lru_cache'd, so clear them + around the patch to keep tests isolated. 
+ """ + nli_module.get_nli_model.cache_clear() + nli_module.entailment_index.cache_clear() + models: dict[str, object] = {} + + def fake_get_model(model_name: str = "cross-encoder/nli-deberta-v3-base"): + return models.get(resolve_nli_model(model_name)) or models.get(model_name) + + with patch("athena_verify.nli.get_nli_model", side_effect=fake_get_model): + yield models + + nli_module.get_nli_model.cache_clear() + nli_module.entailment_index.cache_clear() @pytest.fixture diff --git a/tests/test_rescue.py b/tests/test_rescue.py new file mode 100644 index 0000000..f79b456 --- /dev/null +++ b/tests/test_rescue.py @@ -0,0 +1,83 @@ +"""Tests for the grounding-rescue path: containment, numeric gate, and the +contradiction-vetoed rescue that recovers faithful paraphrases NLI scores low. +""" + +from __future__ import annotations + +from athena_verify.calibration import ( + RESCUE_TRUST, + apply_grounding_rescue, + classify_support, +) +from athena_verify.overlap import containment_score, numeric_consistency + + +class TestContainment: + def test_full_containment(self): + score = containment_score( + "reference counting primary mechanism", + "Python memory management uses reference counting as the primary mechanism.", + ) + assert score == 1.0 + + def test_partial_containment(self): + score = containment_score( + "olive oil drizzled before baking", + "Ingredients: pizza dough, tomato sauce, mozzarella, olive oil.", + ) + assert 0.0 < score < 0.6 + + def test_stopwords_ignored(self): + # Only function words overlap -> no grounding signal. 
+ assert containment_score("the and of is", "the cat and the dog") == 0.0 + + def test_empty_sentence(self): + assert containment_score("the of", "anything here") == 0.0 + + +class TestNumericConsistency: + def test_no_numbers_is_ok(self): + assert numeric_consistency("the cap applies broadly", "context with no figures") + + def test_matching_number(self): + assert numeric_consistency("the cap is 2 million", "indemnification cap of 2 million") + + def test_comma_insensitive(self): + assert numeric_consistency("about 1200 SEK", "approximately SEK 1,200 per tonne") + + def test_substituted_number_fails(self): + assert not numeric_consistency("the cap is 5 million", "the cap is 2 million") + + +class TestGroundingRescue: + def _neutral_paraphrase(self, **over): + kwargs = dict( + entailment=0.10, contradiction=0.05, containment=0.9, numeric_ok=True + ) + kwargs.update(over) + return apply_grounding_rescue(0.2, **kwargs) + + def test_rescues_neutral_grounded_paraphrase(self): + trust = self._neutral_paraphrase() + assert trust >= RESCUE_TRUST + assert classify_support(trust) in ("SUPPORTED", "PARTIAL") + + def test_contradiction_blocks_rescue(self): + # A real contradiction (e.g. subtle reversal) must never be rescued. + trust = self._neutral_paraphrase(contradiction=0.9) + assert trust == 0.2 + + def test_numeric_mismatch_blocks_rescue(self): + # Number substitution: lexically grounded but a figure is wrong. + trust = self._neutral_paraphrase(numeric_ok=False) + assert trust == 0.2 + + def test_low_containment_not_rescued(self): + trust = self._neutral_paraphrase(containment=0.2) + assert trust == 0.2 + + def test_rescue_never_lowers_trust(self): + # Already-high trust is left untouched. 
+ assert apply_grounding_rescue( + 0.9, entailment=0.85, contradiction=0.0, containment=1.0, numeric_ok=True + ) == 0.9 diff --git a/tests/test_supporting_spans.py b/tests/test_supporting_spans.py index 2ca7732..a9a4b41 100644 --- a/tests/test_supporting_spans.py +++ b/tests/test_supporting_spans.py @@ -9,19 +9,20 @@ from athena_verify import verify from athena_verify.models import SupportingSpan - CHUNK_0 = "The sky is blue during the day." CHUNK_1 = "Photosynthesis occurs in plant cells." # 2 context units (one per chunk), 2 sentences in answer → 4 NLI pairs: # (unit0, sent0), (unit1, sent0), (unit0, sent1), (unit1, sent1) # Scores: sent0 supported by unit0 (chunk 0), sent1 supported by unit1 (chunk 1). -_NLI_SCORES = [0.9, 0.1, 0.1, 0.85] +# Each is (entailment, contradiction); contradiction is 0 so the rescue rule +# is irrelevant to span assignment here. +_NLI_SCORES = [(0.9, 0.0), (0.1, 0.0), (0.1, 0.0), (0.85, 0.0)] @pytest.fixture() def _mock_nli(): - with patch("athena_verify.core.batch_compute_entailment", return_value=_NLI_SCORES): + with patch("athena_verify.core.batch_compute_nli", return_value=_NLI_SCORES): yield @@ -90,7 +91,10 @@ def test_span_text_matches_slice(self): def test_no_spans_below_threshold(self): # All NLI scores are 0.1 — below the 0.5 threshold, so no spans. - with patch("athena_verify.core.batch_compute_entailment", return_value=[0.1, 0.1, 0.1, 0.1]): + with patch( + "athena_verify.core.batch_compute_nli", + return_value=[(0.1, 0.0), (0.1, 0.0), (0.1, 0.0), (0.1, 0.0)], + ): result = verify( question="What color is the sky?", answer="The sky appears blue. 
Photosynthesis happens in plants.", diff --git a/tests/test_verify.py b/tests/test_verify.py index c9a819d..63e9071 100644 --- a/tests/test_verify.py +++ b/tests/test_verify.py @@ -12,10 +12,12 @@ @pytest.fixture(autouse=True) def _mock_nli(): - with ( - patch("athena_verify.core.batch_compute_entailment", return_value=[0.85]), - patch("athena_verify.core.batch_compute_entailment_async", return_value=[0.85]), - ): + # _ground_sentences scores one pair per (context unit, sentence); return a + # constant (entailment, contradiction) for each so answers score uniformly. + def fake_nli(pairs, *args, **kwargs): + return [(0.85, 0.05)] * len(pairs) + + with patch("athena_verify.core.batch_compute_nli", side_effect=fake_nli): yield @@ -248,17 +250,9 @@ def test_latency_budget_llm_judge_never_called_with_budget_50(self): from unittest.mock import MagicMock, patch llm_client = MagicMock() - result = verify( - question="What?", - answer="Some answer.", - context=["Some context"], - use_llm_judge=True, - llm_client=llm_client, - latency_budget_ms=50, - ) with patch("athena_verify.core.batch_judge_sentences") as mock_judge: - result = verify( + verify( question="What?", answer="Some answer.", context=["Some context"],