From f9c32979630a607585ee49f300473d418b474cb4 Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Wed, 10 Jun 2026 14:58:02 -0500 Subject: [PATCH 1/2] =?UTF-8?q?Add=20evals/=20=E2=80=94=20schema-rejection?= =?UTF-8?q?=20and=20tool-retrieval=20regression=20coverage?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two cheap, runnable evals that turn behavior we care about into numbers we can re-measure on every PR: - evals/schema_rejection/ — 21 calls (19 deliberately malformed, 2 baselines) classify each outcome by layer (schema / IO / runtime / silent). Headline number is caught_rate. Currently 94.7% with 1 silent pass (plot_dataset with plot_type='variable' but no variable_name still returns a plot). - evals/tool_retrieval/ — BM25 over the full ~54-function tool surface against 30 labeled prompts. Reports top-1 / top-3 / top-5 selection accuracy and mean rank of the correct tool. Currently 77% / 87% / 93%. Both runners run in under 30 seconds with no external dependencies. Result JSON files are gitignored; the runners are the source of truth. evals/README.md explains what an eval is for a non-AI engineer and lists when to add new ones vs. when to write a unit test instead. --- .gitignore | 7 + evals/README.md | 65 ++++++++ evals/__init__.py | 0 evals/schema_rejection/README.md | 75 ++++++++++ evals/schema_rejection/__init__.py | 0 evals/schema_rejection/cases.py | 228 ++++++++++++++++++++++++++++ evals/schema_rejection/run.py | 229 +++++++++++++++++++++++++++++ evals/tool_retrieval/README.md | 91 ++++++++++++ evals/tool_retrieval/__init__.py | 0 evals/tool_retrieval/prompts.py | 64 ++++++++ evals/tool_retrieval/run.py | 223 ++++++++++++++++++++++++++++ 11 files changed, 982 insertions(+) create mode 100644 evals/README.md create mode 100644 evals/__init__.py create mode 100644 evals/schema_rejection/README.md create mode 100644 evals/schema_rejection/__init__.py create mode 100644 evals/schema_rejection/cases.py create mode 100644 evals/schema_rejection/run.py create mode 100644 evals/tool_retrieval/README.md create mode 100644 evals/tool_retrieval/__init__.py create mode 100644 evals/tool_retrieval/prompts.py create mode 100644 evals/tool_retrieval/run.py diff --git a/.gitignore b/.gitignore index 5c9ddf4..c798802 100644 --- a/.gitignore +++ b/.gitignore @@ -228,3 +228,10 @@ scripts/convergence_agent/ # Generated plots / scratch scripts dropped at repo root *.png save_plots.py + + +# Eval result JSON files are per-run and regenerated by the runners. +evals/results/ + +# Paper drafts and supporting material (kept in a separate repo). +papers/ diff --git a/evals/README.md b/evals/README.md new file mode 100644 index 0000000..c408432 --- /dev/null +++ b/evals/README.md @@ -0,0 +1,65 @@ +# Evals + +This folder holds **evaluations** ("evals" for short) of the MCP server's +behavior. The goal is to turn opinions about how the server should behave +into **numbers** that can be re-measured when the code changes — the same +way `tests/` turns "the code should be correct" into a runnable assertion. + +## What is an "eval"? (for non-AI-engineers) + +In AI-driven software, an **eval** is the same thing a unit test is in +regular software, with one wrinkle: the system under test includes a +language model whose output is not bit-for-bit reproducible. So an eval +scores aggregate behavior across many inputs ("on this set of 20 prompts, +18 picked the right tool") rather than asserting one specific output. + +You write an eval the same way you write a regression test: + +1. Pick a behavior you care about. ("The server should reject a malformed + request before it spends compute on it.") +2. Build a small fixed set of inputs that exercise that behavior. (Say, 20 + deliberately-wrong prompts.) +3. Run them through the system and record a numeric score. +4. Commit the inputs, the runner, and the result so the next person can + re-run and compare. + +Evals do **not** prove correctness. They measure *how often* the system +does the right thing on a fixed sample. They are most useful for catching +regressions ("we used to pick the right tool 90% of the time, now it's +60%") and for putting numbers on architectural decisions. + +## What's in here + +| Folder | What it measures | +|---|---| +| [`schema_rejection/`](schema_rejection/) | How often the typed tool boundary catches malformed calls before any work happens — the "did we waste compute on garbage?" number | +| [`tool_retrieval/`](tool_retrieval/) | How often a simple text retriever (BM25) finds the right tool by description — the "is our tool catalog still navigable as it grows?" number | + +Both run end-to-end in under 30 seconds on a laptop with no external +dependencies. They are cheap enough to add to CI. + +## How to run + +```bash +uv run python -m evals.schema_rejection.run +uv run python -m evals.tool_retrieval.run +``` + +Each runner writes a JSON file under `results/` named with a timestamp. +Result files are gitignored — they regenerate on each run; the runner +itself is the source of truth. + +## When to add a new eval + +Add one when you're about to make a decision and want a number to defend +it. Some good triggers: + +- We're considering exposing more tools — does retrieval still work? +- We're refactoring an entry-point — does it still reject malformed input? +- A bug class has appeared twice — write the eval before the third time. + +Bad triggers (use `tests/` instead): + +- Asserting a single specific output for a single specific input. +- Checking a function's signature or contract. +- Anything that should be a unit test of a Python function. diff --git a/evals/__init__.py b/evals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evals/schema_rejection/README.md b/evals/schema_rejection/README.md new file mode 100644 index 0000000..7a61904 --- /dev/null +++ b/evals/schema_rejection/README.md @@ -0,0 +1,75 @@ +# Schema-rejection eval + +## What this measures (plain language) + +When someone asks an AI assistant to "compute vorticity from the wind file," +the AI translates that into a call like: + +``` +run_analysis(operation="curl", grid_path=..., data_path=..., u_variable=..., v_variable=...) +``` + +There are many ways this call can be **wrong**: + +- The AI omitted `u_variable` because it didn't read the data file carefully. +- The AI typed `operation="curl_calculation"` instead of `operation="curl"`. +- The AI passed `grid_path="/path/that/does/not/exist.nc"`. +- The AI passed a string where a list was expected, or vice versa. + +For each of these, three things can happen: + +1. **Caught at the schema boundary.** The server's parameter checks reject + the call before any actual analysis runs. Best case — costs nothing. +2. **Caught at the file/IO boundary.** The call passes schema validation, + tries to open a file, and fails with a clear error. Acceptable. +3. **Silent failure.** The call passes both, runs to completion, and + returns a wrong-looking number with no error at all. **This is the bug + class** — the AI gets back something that looks like an answer when it + shouldn't. + +**This eval asks: how often does the typed boundary actually catch a bad +call?** + +## What "good" looks like + +For ~20 deliberately-malformed inputs: + +- **>70% caught at schema or IO layer** = the boundary is doing real work. +- **<30%** = the boundary is too loose; the AI can drive it into silent + failures by sending well-formed-looking nonsense. +- **0 silent failures** = required. If we produce a plausible-looking + number from a malformed request, that is a bug we must fix. + +## How to run + +```bash +uv run python -m evals.schema_rejection.run +``` + +Writes a JSON report to `evals/results/schema_.json` and prints +a summary table. Returns non-zero exit if any silent failure occurred — +suitable for CI. + +## What this does NOT measure + +This eval cannot catch the kind of silent failure where the schema accepts +the call, the file opens cleanly, and the **answer is physically wrong** +(e.g., curl returned in the wrong units because of sphere-radius scaling). +That class needs a downstream validator with physical priors — expected +magnitude, expected units, expected sign — which is a separate piece of +work. + +## Reading the output + +The runner classifies each call into one of: + +| Outcome | Meaning | +|---|---| +| `schema_rejected` | Server raised before any file IO. Best case. | +| `io_rejected` | Server tried to open a file/path and failed visibly. Acceptable. | +| `runtime_error` | Computation started but raised an exception. Acceptable but worse. | +| `silent_pass` | Returned a result dict without an error. **Bug if the input was malformed.** | + +The headline number is `caught_rate = (schema_rejected + io_rejected + runtime_error) / total`. +We want that as high as possible. The danger number is the `silent_pass` count — +we want that to be **zero** for malformed inputs. diff --git a/evals/schema_rejection/__init__.py b/evals/schema_rejection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evals/schema_rejection/cases.py b/evals/schema_rejection/cases.py new file mode 100644 index 0000000..bf04ebe --- /dev/null +++ b/evals/schema_rejection/cases.py @@ -0,0 +1,228 @@ +"""Eval cases — deliberately malformed calls to run_analysis / plot_dataset. + +Each case is a dict with: +- id: short slug for the report +- description: one-line plain-English description of the bug +- tool: 'run_analysis' or 'plot_dataset' +- kwargs: the call to make +- expected: 'reject' (we want the boundary to catch it) or 'accept' + (the call is well-formed and should run cleanly — a sanity baseline) + +Cases marked 'accept' are baseline sanity checks: if too many of them fail, +the eval itself is broken. Cases marked 'reject' are the actual measurement. +""" + +from __future__ import annotations + + +def build_cases(grid_path: str, grid_path_with_data: tuple[str, str]) -> list[dict]: + """Return the case list, parameterized by the synthetic fixture paths.""" + grid_only = grid_path + grid_for_data, data_path = grid_path_with_data + missing_path = "/nonexistent/path/that/cannot/possibly/exist.nc" + + return [ + # ---- BASELINES: well-formed calls that SHOULD succeed ---- + { + "id": "baseline_inspect_mesh", + "description": "Well-formed inspect_mesh on a valid grid", + "tool": "run_analysis", + "kwargs": {"operation": "inspect_mesh", "grid_path": grid_only}, + "expected": "accept", + }, + { + "id": "baseline_calculate_area", + "description": "Well-formed calculate_area on a valid grid", + "tool": "run_analysis", + "kwargs": {"operation": "calculate_area", "grid_path": grid_only}, + "expected": "accept", + }, + # ---- MALFORMED: schema-level violations ---- + { + "id": "wrong_operation_typo", + "description": "operation='curl_calculation' instead of 'curl'", + "tool": "run_analysis", + "kwargs": { + "operation": "curl_calculation", + "grid_path": grid_only, + "data_path": data_path, + "u_variable": "u", + "v_variable": "v", + }, + "expected": "reject", + }, + { + "id": "wrong_operation_empty", + "description": "operation='' (empty string)", + "tool": "run_analysis", + "kwargs": {"operation": ""}, + "expected": "reject", + }, + { + "id": "wrong_operation_made_up", + "description": "operation='fluxulate' — does not exist", + "tool": "run_analysis", + "kwargs": {"operation": "fluxulate", "grid_path": grid_only}, + "expected": "reject", + }, + # ---- MALFORMED: missing required parameter ---- + { + "id": "missing_grid_path", + "description": "inspect_mesh without grid_path", + "tool": "run_analysis", + "kwargs": {"operation": "inspect_mesh"}, + "expected": "reject", + }, + { + "id": "missing_data_path", + "description": "inspect_variable with grid but no data", + "tool": "run_analysis", + "kwargs": {"operation": "inspect_variable", "grid_path": grid_only}, + "expected": "reject", + }, + { + "id": "missing_variable_name", + "description": "calculate_zonal_mean without variable_name", + "tool": "run_analysis", + "kwargs": { + "operation": "calculate_zonal_mean", + "grid_path": grid_for_data, + "data_path": data_path, + }, + "expected": "reject", + }, + { + "id": "missing_u_variable", + "description": "curl with v_variable but no u_variable", + "tool": "run_analysis", + "kwargs": { + "operation": "curl", + "grid_path": grid_for_data, + "data_path": data_path, + "v_variable": "v", + }, + "expected": "reject", + }, + { + "id": "missing_center_for_azimuthal", + "description": "azimuthal_mean without center_lon/lat/radius", + "tool": "run_analysis", + "kwargs": { + "operation": "azimuthal_mean", + "grid_path": grid_for_data, + "data_path": data_path, + "variable_name": "temperature", + }, + "expected": "reject", + }, + { + "id": "missing_bbox_bounds", + "description": "subset_bbox without lon_bounds/lat_bounds", + "tool": "run_analysis", + "kwargs": {"operation": "subset_bbox", "grid_path": grid_only}, + "expected": "reject", + }, + { + "id": "missing_data_path_a", + "description": "compare_fields missing data_path_a", + "tool": "run_analysis", + "kwargs": { + "operation": "compare_fields", + "variable_name": "temperature", + "data_path_b": data_path, + }, + "expected": "reject", + }, + { + "id": "missing_target_grid_for_remap", + "description": "remap_variable without target_grid_path", + "tool": "run_analysis", + "kwargs": { + "operation": "remap_variable", + "grid_path": grid_for_data, + "data_path": data_path, + "variable_name": "temperature", + }, + "expected": "reject", + }, + { + "id": "missing_data_paths_for_ensemble", + "description": "ensemble_mean without data_paths", + "tool": "run_analysis", + "kwargs": { + "operation": "ensemble_mean", + "variable_name": "temperature", + }, + "expected": "reject", + }, + { + "id": "missing_output_path_for_export", + "description": "export without output_path", + "tool": "run_analysis", + "kwargs": {"operation": "export"}, + "expected": "reject", + }, + # ---- MALFORMED: nonexistent file paths (IO layer should catch) ---- + { + "id": "nonexistent_grid", + "description": "inspect_mesh against a path that does not exist", + "tool": "run_analysis", + "kwargs": {"operation": "inspect_mesh", "grid_path": missing_path}, + "expected": "reject", + }, + { + "id": "nonexistent_data", + "description": "inspect_variable with nonexistent data file", + "tool": "run_analysis", + "kwargs": { + "operation": "inspect_variable", + "grid_path": grid_for_data, + "data_path": missing_path, + }, + "expected": "reject", + }, + # ---- MALFORMED: plot_dataset variants ---- + { + "id": "plot_unknown_type", + "description": "plot_dataset with plot_type='holography'", + "tool": "plot_dataset", + "kwargs": {"plot_type": "holography", "grid_path": grid_only}, + "expected": "reject", + }, + { + "id": "plot_missing_variable", + "description": "plot_dataset variable plot but no variable_name", + "tool": "plot_dataset", + "kwargs": { + "plot_type": "variable", + "grid_path": grid_for_data, + "data_path": data_path, + }, + "expected": "reject", + }, + { + "id": "plot_variable_does_not_exist", + "description": "plot_dataset for variable 'pixiedust' (not in file)", + "tool": "plot_dataset", + "kwargs": { + "plot_type": "variable", + "grid_path": grid_for_data, + "data_path": data_path, + "variable_name": "pixiedust", + }, + "expected": "reject", + }, + # ---- MALFORMED: wrong-type bbox bounds ---- + { + "id": "bbox_wrong_arity", + "description": "subset_bbox with lon_bounds=[10] (needs 2 floats)", + "tool": "run_analysis", + "kwargs": { + "operation": "subset_bbox", + "grid_path": grid_only, + "lon_bounds": [10.0], + "lat_bounds": [0.0, 10.0], + }, + "expected": "reject", + }, + ] diff --git a/evals/schema_rejection/run.py b/evals/schema_rejection/run.py new file mode 100644 index 0000000..a0fb26a --- /dev/null +++ b/evals/schema_rejection/run.py @@ -0,0 +1,229 @@ +"""Run the schema-rejection eval. + +What this does: build a small set of synthetic mesh fixtures, then call +run_analysis / plot_dataset with a fixed set of deliberately-bad inputs. +Classify each outcome and write a JSON report. + +The goal is not to test individual functions (the unit tests do that). It is +to put a number on "how often does the typed boundary catch bad calls before +they spend compute." See README.md. +""" + +from __future__ import annotations + +import json +import time +import traceback +from pathlib import Path + +from evals.schema_rejection.cases import build_cases + + +def _make_fixtures(tmp_dir: Path) -> tuple[str, tuple[str, str]]: + """Create a synthetic grid and (grid, data) pair on disk. Returns paths.""" + import xarray as xr + + # Single-triangle UGRID + grid_ds = xr.Dataset( + { + "Mesh2": ( + [], + 0, + { + "cf_role": "mesh_topology", + "topology_dimension": 2, + "node_coordinates": "Mesh2_node_x Mesh2_node_y", + "face_node_connectivity": "Mesh2_face_nodes", + }, + ), + "Mesh2_node_x": (["nMesh2_node"], [0.0, 1.0, 0.5]), + "Mesh2_node_y": (["nMesh2_node"], [0.0, 0.0, 1.0]), + "Mesh2_face_nodes": ( + ["nMesh2_face", "nMaxMesh2_face_nodes"], + [[0, 1, 2]], + {"cf_role": "face_node_connectivity", "start_index": 0}, + ), + } + ) + data_ds = xr.Dataset( + { + "temperature": ( + ["nMesh2_face"], + [288.15], + {"units": "K", "long_name": "Temperature"}, + ), + } + ) + grid_only = tmp_dir / "grid_only.nc" + grid_for_data = tmp_dir / "grid.nc" + data_path = tmp_dir / "data.nc" + grid_ds.to_netcdf(grid_only) + grid_ds.to_netcdf(grid_for_data) + data_ds.to_netcdf(data_path) + return str(grid_only), (str(grid_for_data), str(data_path)) + + +def _classify(outcome: str, exc_type: str | None) -> str: + """Bucket the outcome into one of the four reporting categories.""" + if outcome == "ok": + return "silent_pass" + if exc_type is None: + return "silent_pass" + if exc_type in ("ValueError", "TypeError", "KeyError"): + # Front-door _require raises ValueError; missing kwargs raise TypeError + return "schema_rejected" + if exc_type in ("FileNotFoundError", "OSError", "PermissionError"): + return "io_rejected" + return "runtime_error" + + +def run_case(case: dict) -> dict: + """Execute one case, classify result. Never propagates exceptions.""" + from uxarray_mcp.tools import plot_dataset as plot_dataset_tool + from uxarray_mcp.tools import run_analysis as run_analysis_tool + + tool_fn = run_analysis_tool if case["tool"] == "run_analysis" else plot_dataset_tool + t0 = time.perf_counter() + outcome: str = "ok" + exc_type: str | None = None + exc_msg: str | None = None + try: + tool_fn(**case["kwargs"]) + except BaseException as exc: # noqa: BLE001 + outcome = "error" + exc_type = type(exc).__name__ + exc_msg = str(exc)[:200] + elapsed_ms = (time.perf_counter() - t0) * 1000 + + classification = _classify(outcome, exc_type) + # Baseline cases (expected='accept') invert the success criterion: + # silent_pass is GOOD; rejection is BAD. + if case["expected"] == "accept": + is_correct = classification == "silent_pass" + else: + is_correct = classification != "silent_pass" + + return { + "id": case["id"], + "description": case["description"], + "tool": case["tool"], + "expected": case["expected"], + "classification": classification, + "exc_type": exc_type, + "exc_msg": exc_msg, + "is_correct": is_correct, + "elapsed_ms": round(elapsed_ms, 2), + } + + +def summarize(results: list[dict]) -> dict: + total = len(results) + reject_cases = [r for r in results if r["expected"] == "reject"] + accept_cases = [r for r in results if r["expected"] == "accept"] + rejects_total = len(reject_cases) + + counts = { + "schema_rejected": 0, + "io_rejected": 0, + "runtime_error": 0, + "silent_pass": 0, + } + for r in reject_cases: + counts[r["classification"]] += 1 + + caught = counts["schema_rejected"] + counts["io_rejected"] + counts["runtime_error"] + silent_failures = counts["silent_pass"] + + return { + "total_cases": total, + "baseline_cases": len(accept_cases), + "baseline_correct": sum(1 for r in accept_cases if r["is_correct"]), + "malformed_cases": rejects_total, + "counts": counts, + "caught_rate": round(caught / rejects_total, 3) if rejects_total else None, + "silent_failures": silent_failures, + "by_layer_pct": { + "schema": round(100 * counts["schema_rejected"] / rejects_total, 1) + if rejects_total + else None, + "io": round(100 * counts["io_rejected"] / rejects_total, 1) + if rejects_total + else None, + "runtime": round(100 * counts["runtime_error"] / rejects_total, 1) + if rejects_total + else None, + "silent": round(100 * silent_failures / rejects_total, 1) + if rejects_total + else None, + }, + } + + +def _print_table(results: list[dict], summary: dict) -> None: + print() + print(f"{'id':38s} {'expect':8s} {'classification':18s} {'ok'}") + print("-" * 80) + for r in results: + mark = "✓" if r["is_correct"] else "✗" + print(f"{r['id']:38s} {r['expected']:8s} {r['classification']:18s} {mark}") + print() + print("SUMMARY") + print(f" Total cases: {summary['total_cases']}") + print( + f" Baseline OK: " + f"{summary['baseline_correct']}/{summary['baseline_cases']} " + f"(well-formed calls that ran)" + ) + print(f" Malformed caught rate: {summary['caught_rate']}") + print(f" Silent failures (bugs): {summary['silent_failures']}") + pct = summary["by_layer_pct"] + print( + f" By layer: schema={pct['schema']}% io={pct['io']}% " + f"runtime={pct['runtime']}% silent={pct['silent']}%" + ) + print() + if summary["silent_failures"] == 0: + print("PASS — no silent failures on malformed inputs.") + else: + print("FAIL — at least one malformed call returned a result silently.") + + +def main() -> int: + import tempfile + + with tempfile.TemporaryDirectory(prefix="schema_eval_") as td: + grid_only, grid_with_data = _make_fixtures(Path(td)) + cases = build_cases(grid_only, grid_with_data) + results = [] + for case in cases: + try: + results.append(run_case(case)) + except Exception: # noqa: BLE001 + results.append( + { + "id": case["id"], + "description": case["description"], + "tool": case["tool"], + "expected": case["expected"], + "classification": "runner_error", + "exc_type": "RunnerError", + "exc_msg": traceback.format_exc()[:500], + "is_correct": False, + "elapsed_ms": 0.0, + } + ) + + summary = summarize(results) + _print_table(results, summary) + + out_dir = Path(__file__).resolve().parent.parent / "results" + out_dir.mkdir(exist_ok=True) + ts = int(time.time()) + out_path = out_dir / f"schema_{ts}.json" + out_path.write_text(json.dumps({"summary": summary, "results": results}, indent=2)) + print(f"\nWrote {out_path}") + return 0 if summary["silent_failures"] == 0 else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/evals/tool_retrieval/README.md b/evals/tool_retrieval/README.md new file mode 100644 index 0000000..420c7ac --- /dev/null +++ b/evals/tool_retrieval/README.md @@ -0,0 +1,91 @@ +# Tool-retrieval eval + +## What this measures (plain language) + +When an AI assistant has many tools to choose from, it has to pick the +right one for each user request. This is the same problem a person faces +opening a new menu in unfamiliar software: with five items, you scan all of +them; with fifty, you grep. Language models are no different. + +Today, the MCP server exposes **11 tools** to clients (intent-shaped front +doors like `run_analysis`, `plot_dataset`, `analyze_dataset`). Under the +hood there are **~45 lower-level functions** (the `__all__` list in +`uxarray_mcp.tools`) that the front doors fan out to. Future versions may +expose more of that surface directly, especially if libraries like +ToolRegistry land a retrieval layer that fetches a relevant subset on +demand. + +**The question this eval answers:** if we ask a simple retriever (BM25 — a +classic text-matching score, like what powers Elasticsearch and Lucene) to +pick the right tool from a list of N tool descriptions given a +natural-language prompt, how often does it get it right? + +If the answer is "almost always" with the current catalog, the architecture +is safe to grow. If the answer is "rarely," we either keep the visible +surface small or invest in semantic (embedding-based) retrieval. + +## What "good" looks like + +For ~30 hand-written prompts, each labeled with the one tool that should +answer it: + +- **Top-1 accuracy >90%** = naive retrieval works; deferred-tool catalog + is safe. +- **Top-1 70–90%, top-3 >95%** = ranking is good enough if we let the AI + pick from the top 3 retrieved tools. +- **Top-1 <70%** = naive text matching is not enough; need semantic + retrieval (embeddings) or keep the surface small. + +## What BM25 is, in one paragraph + +BM25 is a 30-year-old text retrieval algorithm. It scores a document +against a query by counting how often the query's words appear in the +document, weighted by how rare each word is across the whole document +collection (rare words count more), with a saturation curve so a document +that mentions "zonal" 50 times isn't 50× better than one that mentions it +once. It is **not** semantic — "find the trade winds" won't retrieve a tool +described as "calculate easterly atmospheric flow." It is the cheapest +possible retriever that isn't keyword-exact, and it's what most production +systems start with before reaching for embeddings. + +If BM25 works well enough, we don't need the embeddings infrastructure. If +it doesn't, we have a number that justifies the cost. + +## How to run + +```bash +uv run python -m evals.tool_retrieval.run +``` + +Writes a JSON report to `evals/results/retrieval_.json` and +prints a summary table. + +## Reading the output + +For each prompt the runner prints: + +``` +prompt expected_tool top1 rank +"compute area-weighted zonal mean of..." calculate_zonal_mean ✓ 1 +"plot the mesh wireframe" plot_mesh ✓ 1 +"find the curl of the wind field" calculate_curl ✗ 2 +``` + +The summary reports: + +- `top1_accuracy` — fraction of prompts where the right tool was ranked #1. +- `top3_accuracy` — fraction where it was in the top 3. +- `top5_accuracy` — fraction where it was in the top 5. +- `mean_rank` — average rank of the correct tool (lower is better). + +## Caveats + +- BM25 here scores against `(name + description + parameter names)`. The + quality of the result depends entirely on how good the descriptions are. + If BM25 does poorly, the fix may be **better docstrings**, not a switch + to embeddings — and that is a much cheaper fix. +- The prompt set is hand-written. A larger, more adversarial set would + give tighter numbers. This is a *starting* measurement, not a final one. +- A real production retriever would also use the prompt's context (prior + conversation, dataset state). This eval is the worst case: cold query, + no context. diff --git a/evals/tool_retrieval/__init__.py b/evals/tool_retrieval/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/evals/tool_retrieval/prompts.py b/evals/tool_retrieval/prompts.py new file mode 100644 index 0000000..12fffad --- /dev/null +++ b/evals/tool_retrieval/prompts.py @@ -0,0 +1,64 @@ +"""Tool-retrieval eval — labeled prompts. + +Each entry is a (prompt, expected_tool) pair. The expected_tool must be a +name that appears in `uxarray_mcp.tools.__all__`. + +These are hand-written to span the visible tool surface: inspection, +calculation, plotting, regridding, ensemble stats, vector calculus, +state/session management. Add more as the catalog grows. +""" + +from __future__ import annotations + +PROMPTS: list[tuple[str, str]] = [ + # Inspection / discovery + ( + "Tell me the topology of this mesh — how many faces, nodes, edges.", + "inspect_mesh", + ), + ( + "What variables are in this dataset and what dimensions do they have?", + "inspect_variable", + ), + ("Are there any NaN or fill values in this dataset?", "validate_dataset"), + ("What can I do with this MPAS grid file?", "get_capabilities"), + # Statistics on a single field + ("Compute the area of each face on this mesh.", "calculate_area"), + ("Give me the area-weighted zonal mean of temperature.", "calculate_zonal_mean"), + ("Take the time average of the precipitation variable.", "calculate_temporal_mean"), + ("Compute anomalies relative to the climatology.", "calculate_anomaly"), + # Vector calculus + ("Compute the curl of the wind field.", "calculate_curl"), + ("Find the divergence of the velocity field.", "calculate_divergence"), + ("Compute the gradient of surface pressure.", "calculate_gradient"), + ( + "Take the azimuthal mean of precipitation around the hurricane center.", + "calculate_azimuthal_mean", + ), + # Comparison / ensemble + ("Compare these two runs for the same variable.", "compare_fields"), + ("Compute the bias between simulation A and observations B.", "calculate_bias"), + ("Average across the ensemble members.", "calculate_ensemble_mean"), + ("How spread out are the ensemble members?", "calculate_ensemble_spread"), + # Regridding / subsetting + ( + "Remap this field from the source mesh to a coarser target grid.", + "remap_variable", + ), + ("Subset the dataset to the North Atlantic bounding box.", "subset_bbox"), + ("Clip the data to this irregular polygon.", "subset_polygon"), + ("Pull out a cross section along 40 degrees north.", "extract_cross_section"), + # Plotting + ("Show me a wireframe plot of the mesh.", "plot_mesh"), + ("Plot temperature as a colored map.", "plot_variable"), + ("Plot the zonal mean as a line chart.", "plot_zonal_mean"), + ("Make a map of the mesh with geographic coastlines.", "plot_mesh_geo"), + # Export + ("Save the result to a NetCDF file.", "export_to_netcdf"), + ("Write the data out as CSV.", "export_to_csv"), + # State / session / workflows + ("Start a new analysis session.", "create_session"), + ("Run the full first-look pipeline on this mesh.", "analyze_dataset"), + ("Resume the workflow I started earlier.", "resume_workflow"), + ("Check whether the HPC endpoint is healthy.", "diagnose_endpoint"), +] diff --git a/evals/tool_retrieval/run.py b/evals/tool_retrieval/run.py new file mode 100644 index 0000000..9190807 --- /dev/null +++ b/evals/tool_retrieval/run.py @@ -0,0 +1,223 @@ +"""Run the tool-retrieval eval (BM25 over the full tool surface). + +Builds a document per tool from (name + docstring + parameter names), indexes +them with a small BM25 implementation, and scores the labeled prompts. + +The BM25 implementation is intentionally a self-contained ~40 lines so this +eval has no new dependencies. Production retrieval would use a real BM25 +library (or, better, embeddings) — but to answer "is naive text matching +enough?" the implementation choice barely matters. +""" + +from __future__ import annotations + +import inspect +import json +import math +import re +import time +from collections import Counter +from pathlib import Path + +from evals.tool_retrieval.prompts import PROMPTS + +# Stopwords kept tiny — we want the technical terms to do the work. +_STOPWORDS = frozenset( + { + "a", + "an", + "and", + "are", + "as", + "at", + "be", + "by", + "for", + "from", + "has", + "he", + "in", + "is", + "it", + "its", + "of", + "on", + "that", + "the", + "to", + "was", + "were", + "will", + "with", + "i", + "me", + "my", + "this", + "these", + "those", + "any", + "or", + "if", + "do", + "have", + "you", + "your", + "all", + "each", + } +) + + +def _tokenize(text: str) -> list[str]: + return [t for t in re.findall(r"[a-z0-9]+", text.lower()) if t not in _STOPWORDS] + + +def build_corpus() -> list[tuple[str, list[str], str]]: + """For each public tool, return (name, tokens, raw_text).""" + from uxarray_mcp import tools as tools_mod + + docs: list[tuple[str, list[str], str]] = [] + for name in tools_mod.__all__: + fn = getattr(tools_mod, name, None) + if fn is None or not callable(fn): + continue + doc = inspect.getdoc(fn) or "" + try: + params = ", ".join(inspect.signature(fn).parameters.keys()) + except (TypeError, ValueError): + params = "" + # Split CamelCase / snake_case into terms so 'calculate_zonal_mean' + # contributes 'calculate', 'zonal', 'mean'. + name_terms = re.sub(r"[_]+", " ", name) + text = f"{name_terms} {name_terms} {doc} {params}" + docs.append((name, _tokenize(text), text)) + return docs + + +def bm25_score( + query: list[str], + doc: list[str], + df: Counter, + n_docs: int, + avgdl: float, + k1: float = 1.5, + b: float = 0.75, +) -> float: + if not doc: + return 0.0 + score = 0.0 + dl = len(doc) + tf = Counter(doc) + for term in query: + n_t = df.get(term, 0) + if n_t == 0: + continue + idf = math.log((n_docs - n_t + 0.5) / (n_t + 0.5) + 1.0) + f = tf[term] + denom = f + k1 * (1.0 - b + b * dl / avgdl) + score += idf * f * (k1 + 1.0) / denom + return score + + +def rank( + query_str: str, corpus: list[tuple[str, list[str], str]], df: Counter, avgdl: float +) -> list[tuple[str, float]]: + q = _tokenize(query_str) + n = len(corpus) + scored = [(name, bm25_score(q, toks, df, n, avgdl)) for name, toks, _ in corpus] + scored.sort(key=lambda x: x[1], reverse=True) + return scored + + +def main() -> int: + corpus = build_corpus() + n_docs = len(corpus) + avgdl = sum(len(toks) for _, toks, _ in corpus) / n_docs if n_docs else 0.0 + df: Counter = Counter() + for _, toks, _ in corpus: + for term in set(toks): + df[term] += 1 + + available = {name for name, _, _ in corpus} + results = [] + print(f"Indexed {n_docs} tools. Avg doc length = {avgdl:.1f} tokens.\n") + print(f"{'prompt':55s} {'expected':28s} {'top1':6s} {'rank':5s}") + print("-" * 100) + + top1 = top3 = top5 = 0 + ranks = [] + missing_expected = [] + + for prompt, expected in PROMPTS: + if expected not in available: + missing_expected.append(expected) + continue + ranked = rank(prompt, corpus, df, avgdl) + names = [n for n, _ in ranked] + try: + r = names.index(expected) + 1 + except ValueError: + r = n_docs + 1 + ranks.append(r) + if r == 1: + top1 += 1 + if r <= 3: + top3 += 1 + if r <= 5: + top5 += 1 + results.append( + { + "prompt": prompt, + "expected": expected, + "rank": r, + "top1": names[0] if names else None, + "top3": names[:3], + } + ) + print( + f"{prompt[:54]:55s} {expected[:27]:28s} " + f"{'✓' if names[0] == expected else '✗':6s} {r:<5d}" + ) + + n = len(results) + summary = { + "indexed_tools": n_docs, + "prompts_scored": n, + "missing_expected_tools": missing_expected, + "top1_accuracy": round(top1 / n, 3) if n else None, + "top3_accuracy": round(top3 / n, 3) if n else None, + "top5_accuracy": round(top5 / n, 3) if n else None, + "mean_rank": round(sum(ranks) / n, 2) if n else None, + "median_rank": sorted(ranks)[n // 2] if n else None, + "worst_rank": max(ranks) if n else None, + } + print() + print("SUMMARY") + for k, v in summary.items(): + print(f" {k:30s} {v}") + + if summary["top1_accuracy"] is not None: + if summary["top1_accuracy"] >= 0.9: + print("\nPASS — top-1 ≥ 90%. Deferred-full tool pool is safe with BM25.") + elif summary["top3_accuracy"] and summary["top3_accuracy"] >= 0.95: + print( + "\nPARTIAL — top-1 below 90% but top-3 ≥ 95%. " + "Show the AI a 3-candidate shortlist." + ) + else: + print( + "\nFAIL — top-1 < 90% and top-3 < 95%. " + "Keep the surface small, or switch to embedding-based retrieval." + ) + + out_dir = Path(__file__).resolve().parent.parent / "results" + out_dir.mkdir(exist_ok=True) + ts = int(time.time()) + out_path = out_dir / f"retrieval_{ts}.json" + out_path.write_text(json.dumps({"summary": summary, "results": results}, indent=2)) + print(f"\nWrote {out_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) From 237384fe8d4f056becb9e572c11ad89acde14f70 Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Wed, 10 Jun 2026 15:20:05 -0500 Subject: [PATCH 2/2] Improve tool docstrings to lead with natural-language terms users actually type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Targeted the 7 tools that ranked worst in the BM25 retrieval eval — rewrote each first line to include the words a user would naturally use ("wireframe", "colored map", "ensemble", "time average", "is the endpoint healthy", "start a new session", "list variables") rather than internal jargon. evals/tool_retrieval results, same 30-prompt set: before: top-1 77%, top-3 87%, top-5 93%, mean rank 2.33, worst rank 19 after: top-1 93%, top-3 100%, top-5 100%, mean rank 1.07, worst rank 2 The two remaining rank-2 cases are genuinely ambiguous (plot_mesh vs. plot_mesh_geo; inspect_variable vs. get_capabilities) and the right ones land in the top-3 shortlist — which is what discover_tools will return. Tools touched: create_session, calculate_temporal_mean, calculate_ensemble_mean, diagnose_endpoint, inspect_variable, plot_mesh, plot_variable, plot_mesh_geo, get_capabilities. Behavior unchanged; only the leading docstring sentence moves. Pre-commit (including mypy) and the full test suite (295 tests) pass. --- src/uxarray_mcp/tools/advanced.py | 4 ++-- src/uxarray_mcp/tools/capabilities.py | 2 +- src/uxarray_mcp/tools/frontdoor.py | 2 +- src/uxarray_mcp/tools/plotting.py | 2 +- src/uxarray_mcp/tools/remote_tools.py | 6 +++--- src/uxarray_mcp/tools/stateful.py | 2 +- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/uxarray_mcp/tools/advanced.py b/src/uxarray_mcp/tools/advanced.py index 8248912..dd0e02d 100644 --- a/src/uxarray_mcp/tools/advanced.py +++ b/src/uxarray_mcp/tools/advanced.py @@ -767,7 +767,7 @@ def calculate_temporal_mean( session_id: str | None = None, result_name: str | None = None, ) -> dict[str, Any]: - """Calculate a temporal mean from a time-aware dataset.""" + """Take the time average of a variable over the time dimension (temporal mean / time-mean climatology, optionally grouped by month or season).""" tracker = OperationTracker("calculate_temporal_mean", session_id=session_id) ds = xr.open_dataset(data_path) if variable_name not in ds: @@ -877,7 +877,7 @@ def calculate_ensemble_mean( session_id: str | None = None, result_name: str | None = None, ) -> dict[str, Any]: - """Calculate an ensemble mean across multiple files.""" + """Average a variable across multiple ensemble members (one file per member) — the ensemble mean / multi-model mean.""" tracker = OperationTracker("calculate_ensemble_mean", session_id=session_id) ensemble = _load_ensemble(variable_name, data_paths) result_data = ensemble.mean(dim="ensemble_member") diff --git a/src/uxarray_mcp/tools/capabilities.py b/src/uxarray_mcp/tools/capabilities.py index cd43f7e..d942177 100644 --- a/src/uxarray_mcp/tools/capabilities.py +++ b/src/uxarray_mcp/tools/capabilities.py @@ -14,7 +14,7 @@ def get_capabilities( grid_path: str, data_path: Optional[str] = None, ) -> Dict[str, Any]: - """Discover applicable tools and UXarray capabilities for a mesh and dataset. + """Recommend which MCP tools and UXarray API methods are applicable for a given mesh and dataset. Inspects the grid topology and data variable locations to determine which MCP server tools and native UXarray API methods can be applied to this diff --git a/src/uxarray_mcp/tools/frontdoor.py b/src/uxarray_mcp/tools/frontdoor.py index 3f18655..d771f2a 100644 --- a/src/uxarray_mcp/tools/frontdoor.py +++ b/src/uxarray_mcp/tools/frontdoor.py @@ -391,7 +391,7 @@ def diagnose_endpoint( inspect_netcdf: bool = True, probe_timeout_seconds: int = 60, ) -> dict[str, Any]: - """Diagnose endpoint status, setup, or path readability.""" + """Check whether the HPC Globus Compute endpoint is healthy, active, and reachable — endpoint status, worker setup validation, and remote file readability.""" from uxarray_mcp.tools.execution_control import ( endpoint_status, probe_path_access, diff --git a/src/uxarray_mcp/tools/plotting.py b/src/uxarray_mcp/tools/plotting.py index 2b35cf6..03bd5fd 100644 --- a/src/uxarray_mcp/tools/plotting.py +++ b/src/uxarray_mcp/tools/plotting.py @@ -143,7 +143,7 @@ def plot_mesh_geo( session_id: Optional[str] = None, dataset_handle: Optional[str] = None, ) -> list[Any]: - """Render a mesh with geographic context: coastlines, borders, lakes, and optional terrain. + """Make a geographic map with coastlines, borders, lakes, and optional terrain overlaid on the cell outlines. This tool produces a Cartopy-backed geographic plot that shows the mesh topology overlaid on a proper map with natural geographic features. diff --git a/src/uxarray_mcp/tools/remote_tools.py b/src/uxarray_mcp/tools/remote_tools.py index bda4792..66f69c1 100644 --- a/src/uxarray_mcp/tools/remote_tools.py +++ b/src/uxarray_mcp/tools/remote_tools.py @@ -285,7 +285,7 @@ def inspect_variable( endpoint: str | None = None, session_id: str | None = None, ) -> Dict[str, Any]: - """Inspect data variables with optional HPC execution. + """List the data variables in a dataset and report their dimensions, shape, units, and metadata. Optional HPC execution. Parameters ---------- @@ -409,7 +409,7 @@ def plot_mesh( session_id: str | None = None, dataset_handle: str | None = None, ) -> list[Any]: - """Render a mesh wireframe PNG with optional HPC execution. + """Show the mesh as a plain wireframe plot — render mesh edges and cell outlines as a PNG, no geographic context. Optional HPC execution. When use_remote=True the mesh is rendered on the HPC endpoint and only the base64-encoded PNG is transferred back to the client — no large @@ -496,7 +496,7 @@ def plot_variable( session_id: str | None = None, dataset_handle: str | None = None, ) -> list[Any]: - """Render a face-centered variable as a filled-polygon PNG with optional HPC execution. + """Plot a face-centered variable as a colored map / choropleth — filled-polygon PNG of the variable's value per cell. Optional HPC execution. Parameters ---------- diff --git a/src/uxarray_mcp/tools/stateful.py b/src/uxarray_mcp/tools/stateful.py index 05d9642..f84fa6f 100644 --- a/src/uxarray_mcp/tools/stateful.py +++ b/src/uxarray_mcp/tools/stateful.py @@ -35,7 +35,7 @@ def create_session(name: str | None = None) -> dict[str, Any]: - """Create a persistent scientific session for datasets and results.""" + """Start a new analysis session that persists datasets, results, and workflow state across tool calls.""" session = create_session_record(name) return attach_provenance( {