From 67504ba9c37da129ee0a0364be0e5307a571d92e Mon Sep 17 00:00:00 2001 From: Vaishnavi Desai Date: Wed, 24 Jun 2026 19:43:44 +0530 Subject: [PATCH] feat: add async evaluation and BM25 baseline to eval_beir.py - --async-eval flag: evaluates datasets concurrently via asyncio + ThreadPoolExecutor - --bm25 flag: adds BM25 baseline columns to output table (requires rank-bm25) - --max-workers param: controls thread-pool size for async mode - print_table() extended to show side-by-side VORTEXRAG vs BM25 columns - save_csv() handles variable columns gracefully Closes #3 Signed-off-by: Vaishnavi Desai --- benchmarks/eval_beir.py | 328 ++++++++++++++++++++++++++++++++-------- 1 file changed, 261 insertions(+), 67 deletions(-) diff --git a/benchmarks/eval_beir.py b/benchmarks/eval_beir.py index f187f61..d43d1cb 100644 --- a/benchmarks/eval_beir.py +++ b/benchmarks/eval_beir.py @@ -2,6 +2,7 @@ VORTEXRAG BEIR Benchmark Evaluation ===================================== Evaluates VORTEXRAG retrieval quality on standard BEIR benchmark datasets. +Supports async concurrent evaluation and optional BM25 baseline comparison. Outputs: NDCG@10, Recall@100, MAP per dataset + aggregate table. @@ -12,6 +13,12 @@ # Specific datasets only python benchmarks/eval_beir.py --datasets nq hotpotqa scifact + # Async concurrent evaluation (faster for multiple datasets) + python benchmarks/eval_beir.py --async-eval + + # With BM25 baseline comparison + python benchmarks/eval_beir.py --bm25 + # Save results to CSV python benchmarks/eval_beir.py --output results/beir_results.csv @@ -24,10 +31,12 @@ from __future__ import annotations import argparse +import asyncio import csv import os import sys import time +from concurrent.futures import ThreadPoolExecutor from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) @@ -50,20 +59,20 @@ ] DATASET_DOMAIN_MAP = { - "msmarco": "general", - "nq": "general", - "hotpotqa": "general", - "fiqa": "financial", - "arguana": "legal", + "msmarco": "general", + "nq": "general", + "hotpotqa": "general", + "fiqa": "financial", + "arguana": "legal", "webis-touche2020": "general", - "dbpedia-entity": "general", - "scidocs": "scientific", - "fever": "general", - "climate-fever": "scientific", - "scifact": "biomedical", - "trec-covid": "biomedical", - "bioasq": "biomedical", - "nfcorpus": "medical", + "dbpedia-entity": "general", + "scidocs": "scientific", + "fever": "general", + "climate-fever": "scientific", + "scifact": "biomedical", + "trec-covid": "biomedical", + "bioasq": "biomedical", + "nfcorpus": "medical", } @@ -75,6 +84,18 @@ def _beir_available() -> bool: return False +def _rank_bm25_available() -> bool: + try: + import rank_bm25 # noqa: F401 + return True + except ImportError: + return False + + +# --------------------------------------------------------------------------- +# Metric helpers +# --------------------------------------------------------------------------- + def dcg_at_k(relevances: list[int], k: int) -> float: import math return sum( @@ -112,8 +133,29 @@ def average_precision(retrieved_ids: list[str], qrels: dict[str, int]) -> float: return precision_sum / len(relevant) -def evaluate_dataset_beir(dataset_name: str, data_dir: str, top_k: int = 100) -> dict: - """Run VORTEXRAG on a BEIR dataset and return metrics.""" +# --------------------------------------------------------------------------- +# BM25 baseline +# --------------------------------------------------------------------------- + +def build_bm25_index(corpus_texts: list[str]): + """Build a BM25 index from corpus texts. Requires rank-bm25.""" + from rank_bm25 import BM25Okapi + tokenized = [text.lower().split() for text in corpus_texts] + return BM25Okapi(tokenized) + + +def bm25_retrieve(bm25_index, query: str, corpus_ids: list[str], top_k: int) -> list[str]: + scores = bm25_index.get_scores(query.lower().split()) + top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k] + return [corpus_ids[i] for i in top_indices] + + +# --------------------------------------------------------------------------- +# Core evaluation +# --------------------------------------------------------------------------- + +def _load_beir_dataset(dataset_name: str, data_dir: str): + """Download (if needed) and load a BEIR dataset.""" from beir import util from beir.datasets.data_loader import GenericDataLoader @@ -121,37 +163,54 @@ def evaluate_dataset_beir(dataset_name: str, data_dir: str, top_k: int = 100) -> out_dir = Path(data_dir) / dataset_name if not out_dir.exists(): - print(f" Downloading {dataset_name}...") + print(f" [{dataset_name}] Downloading...") util.download_and_unzip(url, data_dir) corpus, queries, qrels = GenericDataLoader(data_folder=str(out_dir)).load(split="test") + return corpus, queries, qrels + + +def evaluate_dataset_beir( + dataset_name: str, + data_dir: str, + top_k: int = 100, + run_bm25: bool = False, +) -> dict: + """Run VORTEXRAG (and optionally BM25) on a BEIR dataset and return metrics.""" + corpus, queries, qrels = _load_beir_dataset(dataset_name, data_dir) domain = DATASET_DOMAIN_MAP.get(dataset_name, "general") config = VortexRAGConfig(domain=domain) - print(f" Indexing {len(corpus)} documents ({domain} domain)...") corpus_texts = [ f"{doc.get('title', '')} {doc.get('text', '')}".strip() for doc in corpus.values() ] corpus_ids = list(corpus.keys()) + print(f" [{dataset_name}] Indexing {len(corpus)} documents ({domain})...") rag = VortexRAG(corpus=corpus_texts, config=config) rag.index() - ndcg_scores, recall_scores, ap_scores = [], [], [] - query_ids = list(queries.keys())[:500] # cap at 500 queries for speed + bm25_index = None + if run_bm25 and _rank_bm25_available(): + print(f" [{dataset_name}] Building BM25 index...") + bm25_index = build_bm25_index(corpus_texts) + elif run_bm25: + print(f" [{dataset_name}] WARNING: rank-bm25 not installed, skipping BM25 baseline.") - print(f" Evaluating {len(query_ids)} queries...") + query_ids = list(queries.keys())[:500] + print(f" [{dataset_name}] Evaluating {len(query_ids)} queries...") t0 = time.time() + ndcg_scores, recall_scores, ap_scores = [], [], [] + bm25_ndcg, bm25_recall, bm25_ap = [], [], [] + for qid in query_ids: query_text = queries[qid] - query_qrels = {doc_id: rel for doc_id, rel in qrels.get(qid, {}).items()} + query_qrels = dict(qrels.get(qid, {})) result = rag.query(query_text) - - # Map retrieved chunks back to corpus doc IDs by content match retrieved_ids = [] for chunk in result.context_window[:top_k]: for doc_id, doc_text in zip(corpus_ids, corpus_texts): @@ -163,67 +222,174 @@ def evaluate_dataset_beir(dataset_name: str, data_dir: str, top_k: int = 100) -> recall_scores.append(recall_at_k(retrieved_ids, query_qrels, k=100)) ap_scores.append(average_precision(retrieved_ids, query_qrels)) + if bm25_index is not None: + bm25_ids = bm25_retrieve(bm25_index, query_text, corpus_ids, top_k) + bm25_ndcg.append(ndcg_at_k(bm25_ids, query_qrels, k=10)) + bm25_recall.append(recall_at_k(bm25_ids, query_qrels, k=100)) + bm25_ap.append(average_precision(bm25_ids, query_qrels)) + elapsed = time.time() - t0 n = len(query_ids) - return { - "dataset": dataset_name, - "domain": domain, + result_row = { + "dataset": dataset_name, + "domain": domain, "num_queries": n, - "ndcg@10": round(sum(ndcg_scores) / n, 4), - "recall@100": round(sum(recall_scores) / n, 4), - "map": round(sum(ap_scores) / n, 4), - "latency_s": round(elapsed / n, 3), + "ndcg@10": round(sum(ndcg_scores) / n, 4), + "recall@100": round(sum(recall_scores) / n, 4), + "map": round(sum(ap_scores) / n, 4), + "latency_s": round(elapsed / n, 3), } + if bm25_index is not None: + result_row.update({ + "bm25_ndcg@10": round(sum(bm25_ndcg) / n, 4), + "bm25_recall@100": round(sum(bm25_recall) / n, 4), + "bm25_map": round(sum(bm25_ap) / n, 4), + }) + + return result_row -def evaluate_dataset_stub(dataset_name: str) -> dict: + +def evaluate_dataset_stub(dataset_name: str, run_bm25: bool = False) -> dict: """Stub result used when BEIR is not installed (for CI / unit tests).""" import random rng = random.Random(hash(dataset_name) & 0xFFFF) - return { - "dataset": dataset_name, - "domain": DATASET_DOMAIN_MAP.get(dataset_name, "general"), + row = { + "dataset": dataset_name, + "domain": DATASET_DOMAIN_MAP.get(dataset_name, "general"), "num_queries": 0, - "ndcg@10": round(rng.uniform(0.35, 0.62), 4), - "recall@100": round(rng.uniform(0.60, 0.88), 4), - "map": round(rng.uniform(0.28, 0.55), 4), - "latency_s": round(rng.uniform(0.08, 0.25), 3), - "note": "stub — install beir for real evaluation", + "ndcg@10": round(rng.uniform(0.35, 0.62), 4), + "recall@100": round(rng.uniform(0.60, 0.88), 4), + "map": round(rng.uniform(0.28, 0.55), 4), + "latency_s": round(rng.uniform(0.08, 0.25), 3), + "note": "stub — install beir for real evaluation", } + if run_bm25: + rng2 = random.Random(hash(dataset_name + "_bm25") & 0xFFFF) + row.update({ + "bm25_ndcg@10": round(rng2.uniform(0.25, 0.50), 4), + "bm25_recall@100": round(rng2.uniform(0.50, 0.78), 4), + "bm25_map": round(rng2.uniform(0.20, 0.45), 4), + }) + return row + +# --------------------------------------------------------------------------- +# Async evaluation +# --------------------------------------------------------------------------- + +async def evaluate_datasets_async( + datasets: list[str], + data_dir: str, + top_k: int, + run_bm25: bool, + use_beir: bool, + max_workers: int = 4, +) -> list[dict]: + """ + Evaluate multiple BEIR datasets concurrently using a thread-pool executor. + Each dataset is indexed and queried in its own thread so the event loop + stays free and progress from all datasets interleaves naturally. + """ + loop = asyncio.get_event_loop() + executor = ThreadPoolExecutor(max_workers=max_workers) + + def _run(dataset_name: str) -> dict: + print(f"\n[{dataset_name}] starting...") + try: + if use_beir: + return evaluate_dataset_beir(dataset_name, data_dir, top_k, run_bm25) + else: + return evaluate_dataset_stub(dataset_name, run_bm25) + except Exception as exc: + print(f" [{dataset_name}] ERROR: {exc}") + return {"dataset": dataset_name, "error": str(exc)} + + tasks = [loop.run_in_executor(executor, _run, ds) for ds in datasets] + return await asyncio.gather(*tasks) + + +# --------------------------------------------------------------------------- +# Output helpers +# --------------------------------------------------------------------------- + +def print_table(results: list[dict], show_bm25: bool = False): + has_bm25 = show_bm25 and any("bm25_ndcg@10" in r for r in results) + + if has_bm25: + header = ( + f"{'Dataset':<22} {'Domain':<14} " + f"{'VRTX NDCG@10':>13} {'VRTX R@100':>11} {'VRTX MAP':>9} {'ms/q':>7} " + f"{'BM25 NDCG@10':>13} {'BM25 R@100':>11} {'BM25 MAP':>9}" + ) + else: + header = ( + f"{'Dataset':<22} {'Domain':<14} " + f"{'NDCG@10':>8} {'R@100':>8} {'MAP':>8} {'ms/q':>7}" + ) -def print_table(results: list[dict]): - header = f"{'Dataset':<22} {'Domain':<14} {'NDCG@10':>8} {'R@100':>8} {'MAP':>8} {'ms/q':>7}" sep = "-" * len(header) print(sep) print(header) print(sep) + for r in results: - print( - f"{r['dataset']:<22} {r['domain']:<14} " - f"{r['ndcg@10']:>8.4f} {r['recall@100']:>8.4f} " - f"{r['map']:>8.4f} {r['latency_s']*1000:>7.1f}" - ) + if "error" in r: + print(f"{r['dataset']:<22} ERROR: {r['error']}") + continue + if has_bm25: + print( + f"{r['dataset']:<22} {r['domain']:<14} " + f"{r['ndcg@10']:>13.4f} {r['recall@100']:>11.4f} {r['map']:>9.4f} " + f"{r['latency_s']*1000:>7.1f} " + f"{r.get('bm25_ndcg@10', 0):>13.4f} " + f"{r.get('bm25_recall@100', 0):>11.4f} " + f"{r.get('bm25_map', 0):>9.4f}" + ) + else: + print( + f"{r['dataset']:<22} {r['domain']:<14} " + f"{r['ndcg@10']:>8.4f} {r['recall@100']:>8.4f} " + f"{r['map']:>8.4f} {r['latency_s']*1000:>7.1f}" + ) + + valid = [r for r in results if "error" not in r] print(sep) - n = len(results) - if n: - avg_ndcg = sum(r["ndcg@10"] for r in results) / n - avg_r100 = sum(r["recall@100"] for r in results) / n - avg_map = sum(r["map"] for r in results) / n - print(f"{'AVERAGE':<22} {'':<14} {avg_ndcg:>8.4f} {avg_r100:>8.4f} {avg_map:>8.4f}") + if valid: + n = len(valid) + avg = lambda k: sum(r[k] for r in valid) / n # noqa: E731 + if has_bm25: + print( + f"{'AVERAGE':<22} {'':<14} " + f"{avg('ndcg@10'):>13.4f} {avg('recall@100'):>11.4f} {avg('map'):>9.4f} " + f"{'':>7} " + f"{avg('bm25_ndcg@10'):>13.4f} " + f"{avg('bm25_recall@100'):>11.4f} " + f"{avg('bm25_map'):>9.4f}" + ) + else: + print( + f"{'AVERAGE':<22} {'':<14} " + f"{avg('ndcg@10'):>8.4f} {avg('recall@100'):>8.4f} {avg('map'):>8.4f}" + ) print(sep) def save_csv(results: list[dict], path: str): Path(path).parent.mkdir(parents=True, exist_ok=True) + all_keys = list(dict.fromkeys(k for r in results for k in r.keys())) with open(path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=results[0].keys()) + writer = csv.DictWriter(f, fieldnames=all_keys, extrasaction="ignore") writer.writeheader() writer.writerows(results) print(f"\nResults saved to {path}") +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + def main(): parser = argparse.ArgumentParser(description="Evaluate VORTEXRAG on BEIR") parser.add_argument( @@ -242,6 +408,18 @@ def main(): "--top-k", type=int, default=100, help="Number of documents to retrieve per query" ) + parser.add_argument( + "--async-eval", action="store_true", + help="Evaluate datasets concurrently using asyncio + ThreadPoolExecutor" + ) + parser.add_argument( + "--bm25", action="store_true", + help="Add BM25 baseline comparison (requires: pip install rank-bm25)" + ) + parser.add_argument( + "--max-workers", type=int, default=4, + help="Thread-pool size for async evaluation (default: 4)" + ) args = parser.parse_args() use_beir = _beir_available() @@ -249,25 +427,41 @@ def main(): print("WARNING: 'beir' package not found — using stub results.") print("Install with: pip install beir\n") - results = [] - for dataset in args.datasets: - print(f"\n[{dataset}]") - try: - if use_beir: - r = evaluate_dataset_beir(dataset, args.data_dir, args.top_k) - else: - r = evaluate_dataset_stub(dataset) - results.append(r) - print(f" NDCG@10={r['ndcg@10']:.4f} R@100={r['recall@100']:.4f} MAP={r['map']:.4f}") - except Exception as exc: - print(f" ERROR: {exc}") + if args.bm25 and not _rank_bm25_available(): + print("WARNING: 'rank-bm25' not found — BM25 baseline disabled.") + print("Install with: pip install rank-bm25\n") + + print(f"Evaluating {len(args.datasets)} datasets " + f"({'async' if args.async_eval else 'sequential'}" + f"{', +BM25' if args.bm25 else ''})...\n") + + if args.async_eval: + results = asyncio.run( + evaluate_datasets_async( + args.datasets, args.data_dir, args.top_k, + args.bm25, use_beir, args.max_workers, + ) + ) + else: + results = [] + for dataset in args.datasets: + print(f"\n[{dataset}]") + try: + if use_beir: + r = evaluate_dataset_beir(dataset, args.data_dir, args.top_k, args.bm25) + else: + r = evaluate_dataset_stub(dataset, args.bm25) + results.append(r) + print(f" NDCG@10={r['ndcg@10']:.4f} R@100={r['recall@100']:.4f} MAP={r['map']:.4f}") + except Exception as exc: + print(f" ERROR: {exc}") print("\n\nFINAL RESULTS") - print_table(results) + print_table(results, show_bm25=args.bm25) if args.output: save_csv(results, args.output) if __name__ == "__main__": - main() + main() \ No newline at end of file