diff --git a/benches/README.md b/benches/README.md index cb4096a8e..d9853d6fd 100644 --- a/benches/README.md +++ b/benches/README.md @@ -73,6 +73,33 @@ python benches/llamacpp_bench.py tput \ Use `--json-out <path>` to save machine-readable results. +## Reasoning-pattern benchmark + +`reasoning_bench.py` compares Direct, Best-of-N, Tree-of-Thought, and +Graph-of-Thought through one benchmark-specific inferlet. Reference answers +remain in the Python harness and are never sent to the model. + +Build the inferlet: + +```bash +cargo build --manifest-path inferlets/reasoning-benchmark/Cargo.toml \ + --target wasm32-wasip2 --release +``` + +Run the bundled smoke problems: + +```bash +uv --project sdk/python-server run python benches/reasoning_bench.py \ + --driver cuda_native \ + --model Qwen/Qwen3-0.6B \ + --pattern all \ + --json-out .tmp/reasoning-smoke.json +``` + +For GSM8K, pass a JSONL file containing the official `question` and `answer` +fields. The harness extracts the numeric reference after GSM8K's `####` +delimiter. + ## Fairness defaults - vLLM prefix caching is disabled. diff --git a/benches/reasoning_bench.py b/benches/reasoning_bench.py new file mode 100644 index 000000000..dc112de5a --- /dev/null +++ b/benches/reasoning_bench.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python3 +"""Benchmark Pie reasoning patterns on GSM8K-format JSONL data. + +The reference answer stays in this harness and is never sent to the inferlet. +Official GSM8K records (`question`, `answer` containing `#### N`) and simple +records (`id`, `question`, `answer`) are both accepted. 
def load_problems(path: Path, limit: int | None) -> list[Problem]:
    """Load benchmark problems from a JSONL file.

    Each non-blank line must be a JSON object with a ``question`` and an
    ``answer`` field; ``id`` is optional and defaults to
    ``"<file stem>-<line number>"``. The answer is reduced to its normalized
    numeric form by ``reference_answer``.

    Args:
        path: JSONL file to read (UTF-8).
        limit: Maximum number of problems to load, or ``None`` for all.

    Returns:
        The loaded problems, in file order.

    Raises:
        ValueError: If ``limit`` is not positive, a line is not a valid JSON
            object, a required field is missing or empty, the answer has no
            numeric value, or the file yields no problems. Per-line errors
            are prefixed with ``path:line``.
    """
    # The limit is enforced after appending, so a non-positive limit would
    # otherwise still load one problem; reject it up front instead.
    if limit is not None and limit <= 0:
        raise ValueError(f"limit must be a positive integer, got {limit}")

    problems: list[Problem] = []
    with path.open(encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, 1):
            if not line.strip():
                continue
            where = f"{path}:{line_number}"
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"{where}: invalid JSON: {exc}") from exc
            if not isinstance(record, dict):
                raise ValueError(f"{where}: expected a JSON object")
            missing = [key for key in ("question", "answer") if key not in record]
            if missing:
                raise ValueError(f"{where}: missing field(s): {', '.join(missing)}")
            question = str(record["question"]).strip()
            if not question:
                raise ValueError(f"{where}: empty question")
            try:
                answer = reference_answer(record["answer"])
            except ValueError as exc:
                # Add the location so a bad record in a large dataset is findable.
                raise ValueError(f"{where}: {exc}") from exc
            problems.append(
                Problem(
                    id=str(record.get("id", f"{path.stem}-{line_number}")),
                    question=question,
                    answer=answer,
                )
            )
            if limit is not None and len(problems) >= limit:
                break
    if not problems:
        raise ValueError(f"no problems loaded from {path}")
    return problems
def result_for(
    problem: Problem,
    pattern: str,
    repetition: int,
    latency_s: float,
    output: dict[str, Any] | None,
    engine_stats_delta: dict[str, int],
    error: str | None,
) -> RunResult:
    """Score one inferlet run against the problem's reference answer.

    Args:
        problem: The problem that was run; ``problem.answer`` is the
            normalized reference.
        pattern: Reasoning pattern name used for the run.
        repetition: Zero-based repetition index.
        latency_s: Wall-clock latency of the run in seconds.
        output: Decoded inferlet output, or ``None`` when the run failed.
        engine_stats_delta: Engine counter deltas observed around the run.
        error: Error description from the runner, or ``None`` on success.

    Returns:
        A ``RunResult`` whose ``correct`` flag compares the normalized
        ``final_answer`` with the reference, and whose candidate counters are
        derived from the ``candidates`` list of the output.
    """
    # This function runs outside the runners' try/except, so a malformed
    # payload must be recorded as an error, not raised, or one bad run would
    # abort the whole experiment.
    if output is not None and not isinstance(output, dict):
        if error is None:
            error = f"unexpected inferlet output type: {type(output).__name__}"
        output = None

    payload = output or {}
    predicted = normalize_number(payload.get("final_answer"))

    raw_candidates = payload.get("candidates")
    if not isinstance(raw_candidates, list):
        # Covers a missing key as well as an explicit JSON null.
        raw_candidates = []
    candidate_answers = [
        normalize_number(candidate.get("answer"))
        for candidate in raw_candidates
        if isinstance(candidate, dict)
    ]
    correct_candidates = sum(answer == problem.answer for answer in candidate_answers)

    return RunResult(
        problem_id=problem.id,
        pattern=pattern,
        repetition=repetition,
        correct=predicted == problem.answer,
        oracle_correct=correct_candidates > 0,
        correct_candidates=correct_candidates,
        expected_answer=problem.answer,
        predicted_answer=predicted,
        latency_s=latency_s,
        output=output,
        engine_stats_delta=engine_stats_delta,
        error=error,
    )
def parse_pie_run_output(stdout: str) -> dict[str, Any]:
    """Extract the inferlet's JSON object from captured ``pie run`` stdout.

    The whole of ``stdout`` is tried first, so a pretty-printed (multi-line)
    JSON object is accepted. If that fails, the lines are scanned from last
    to first and the first line that decodes to a JSON object is returned,
    which tolerates log lines before or after the result.

    Args:
        stdout: Captured standard output of the ``pie run`` command.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If no JSON object can be found in ``stdout``.
    """
    # Whole-document parse first: scanning a pretty-printed object line by
    # line would fail, or return a nested one-line object such as `{}`.
    try:
        document = json.loads(stdout)
    except json.JSONDecodeError:
        pass
    else:
        if isinstance(document, dict):
            return document

    for line in reversed(stdout.splitlines()):
        line = line.strip()
        if not line:
            continue
        try:
            output = json.loads(line)
        except json.JSONDecodeError:
            continue
        if isinstance(output, dict):
            return output
    raise ValueError("pie run did not emit a JSON object on stdout")
def _generated_tokens(output: Any) -> int | None:
    """Return ``output["stats"]["generated_tokens"]`` as an int, or ``None`` if absent or malformed."""
    if not isinstance(output, dict):
        return None
    stats = output.get("stats")
    if not isinstance(stats, dict):
        return None
    try:
        return int(stats["generated_tokens"])
    except (KeyError, TypeError, ValueError):
        return None


def summarize(results: list[RunResult]) -> dict[str, Any]:
    """Aggregate per-run results into one summary entry per pattern.

    Accuracy figures are computed over all runs of a pattern, including
    failed ones. Latency and generated-token means use only runs without an
    error, and are ``None`` when no such run provides a value.

    Args:
        results: Run results to aggregate.

    Returns:
        A mapping from pattern name (sorted) to its summary statistics; empty
        when ``results`` is empty.
    """
    by_pattern: dict[str, list[RunResult]] = {}
    for result in results:
        by_pattern.setdefault(result.pattern, []).append(result)

    summary: dict[str, Any] = {}
    for pattern in sorted(by_pattern):
        rows = by_pattern[pattern]  # never empty: a pattern only exists via a run
        total = len(rows)
        valid = [result for result in rows if result.error is None]
        latencies = [result.latency_s for result in valid]
        # A run that succeeded but reported no usable token counter is left
        # out of the token mean instead of raising: the summary is computed
        # before results are written, so an exception here would discard
        # the whole experiment.
        generated = [
            tokens
            for tokens in (_generated_tokens(result.output) for result in valid)
            if tokens is not None
        ]
        summary[pattern] = {
            "runs": total,
            "errors": sum(result.error is not None for result in rows),
            "accuracy": sum(result.correct for result in rows) / total,
            "oracle_candidate_accuracy": (
                sum(result.oracle_correct for result in rows) / total
            ),
            "mean_correct_candidates": statistics.fmean(
                result.correct_candidates for result in rows
            ),
            "mean_latency_s": statistics.fmean(latencies) if latencies else None,
            "mean_generated_tokens": statistics.fmean(generated) if generated else None,
        }
    return summary
Required when --execution-mode=cli.", + ) + p.add_argument("--dataset", default=str(ROOT / "benches" / "reasoning_smoke.jsonl")) + p.add_argument("--pattern", choices=("all", *PATTERNS), default="all") + p.add_argument("--model", default="Qwen/Qwen3-0.6B") + p.add_argument( + "--driver", + choices=("dev", "dummy", "cuda_native", "portable", "vllm", "sglang"), + default="dev", + ) + p.add_argument("--device", default="cuda:0") + p.add_argument("--max-problems", type=int, default=None) + p.add_argument("--repetitions", type=int, default=1) + p.add_argument("--num-candidates", type=int, default=4) + p.add_argument("--beam-width", type=int, default=2) + p.add_argument("--max-tokens", type=int, default=256) + p.add_argument("--score-tokens", type=int, default=16) + p.add_argument("--temperature", type=float, default=0.7) + p.add_argument("--top-p", type=float, default=0.95) + thinking = p.add_mutually_exclusive_group() + thinking.add_argument( + "--thinking", + dest="thinking", + action="store_true", + help="Allow model thinking blocks.", + ) + thinking.add_argument( + "--no-thinking", + dest="thinking", + action="store_false", + help="Use the model/template no-thinking path. 
def main() -> None:
    """Parse the command line, run the benchmark, print the summary, and write the JSON artifact."""
    cli_args = parser().parse_args()
    run_results = asyncio.run(run(cli_args))
    rendered_summary = json.dumps(summarize(run_results), indent=2)
    print("\n" + rendered_summary)
    write_results(Path(cli_args.json_out), cli_args, run_results)
If 9 pencils are removed, how many pencils remain?","answer":"39"} diff --git a/benches/test_reasoning_bench.py b/benches/test_reasoning_bench.py new file mode 100644 index 000000000..8404dcb76 --- /dev/null +++ b/benches/test_reasoning_bench.py @@ -0,0 +1,75 @@ +"""Unit tests for the reasoning benchmark's dataset and answer evaluator.""" +import json +import tempfile +import unittest +from pathlib import Path + +from argparse import Namespace + +from benches.reasoning_bench import ( + Problem, + load_problems, + normalize_number, + payload_for, + reference_answer, +) + + +class ReasoningBenchTests(unittest.TestCase): + def test_normalize_number(self): + self.assertEqual(normalize_number("$1,234.00"), "1234") + self.assertEqual(normalize_number("Final Answer: -2.5"), "-2.5") + self.assertIsNone(normalize_number("no numeric answer")) + + def test_gsm8k_reference_answer(self): + self.assertEqual( + reference_answer("Reasoning with intermediate numbers.\n#### 1,250"), + "1250", + ) + + def test_payload_disables_thinking_by_default(self): + args = Namespace( + num_candidates=4, + beam_width=2, + max_tokens=256, + score_tokens=16, + temperature=0.7, + top_p=0.95, + thinking=False, + ) + payload = payload_for(Problem("p1", "How many?", "42"), "direct", args) + self.assertIs(payload["thinking"], False) + + def test_payload_can_enable_thinking(self): + args = Namespace( + num_candidates=4, + beam_width=2, + max_tokens=256, + score_tokens=16, + temperature=0.7, + top_p=0.95, + thinking=True, + ) + payload = payload_for(Problem("p1", "How many?", "42"), "direct", args) + self.assertIs(payload["thinking"], True) + + def test_loads_official_gsm8k_shape(self): + with tempfile.TemporaryDirectory() as directory: + path = Path(directory) / "test.jsonl" + path.write_text( + json.dumps( + { + "question": "How many?", + "answer": "Compute carefully.\n#### 42", + } + ) + + "\n", + encoding="utf-8", + ) + problems = load_problems(path, None) + self.assertEqual(len(problems), 1) + 
# Input parameters accepted by the inferlet; every field except `question` is optional.
# NOTE(review): the inferlet's `Input` struct also deserializes an optional boolean
# `thinking` field, and benches/reasoning_bench.py always sends it, but it is not
# declared in this table -- confirm whether the manifest should list it.
[parameters]
pattern = {type = "string", optional = true, description = "direct | best_of_n | tree_of_thought | graph_of_thought"}
question = {type = "string", description = "Math word problem to solve"}
num_candidates = {type = "int", optional = true, description = "Candidate count for branching patterns (default: 4)"}
beam_width = {type = "int", optional = true, description = "Number of candidates refined by tree_of_thought (default: 2)"}
max_tokens = {type = "int", optional = true, description = "Maximum answer tokens per generation (default: 256)"}
score_tokens = {type = "int", optional = true, description = "Maximum tokens per model-scoring call (default: 16)"}
temperature = {type = "float", optional = true, description = "Candidate sampling temperature (default: 0.7)"}
top_p = {type = "float", optional = true, description = "Candidate top-p value (default: 0.95)"}
b/inferlets/reasoning-benchmark/src/lib.rs new file mode 100644 index 000000000..11b90ad52 --- /dev/null +++ b/inferlets/reasoning-benchmark/src/lib.rs @@ -0,0 +1,621 @@ +//! Unified inferlet for comparing reasoning workflows on math word problems. +//! +//! The inferlet never receives the reference answer. It returns candidates, +//! selections, and execution counters; the external benchmark harness owns +//! deterministic answer evaluation. + +use std::collections::HashMap; +use std::time::Instant; + +use futures::future; +use inferlet::{Context, Result, chat, model::Model, runtime, sample::Sampler}; +use serde::{Deserialize, Serialize}; + +const SYSTEM_PROMPT: &str = "\ +You solve grade-school mathematical word problems carefully. Show concise \ +reasoning and put the final numeric answer on the last line exactly as \ +\"Final Answer: \". Do not use that phrase anywhere else."; + +const SCORE_SYSTEM: &str = "\ +You evaluate candidate solutions to mathematical word problems. Check the \ +problem interpretation and arithmetic. Respond with only one integer from 1 \ +to 10, where 10 means fully correct."; + +const AGGREGATE_SYSTEM: &str = "\ +You synthesize candidate solutions to mathematical word problems. Compare \ +their reasoning, resolve disagreements by recalculating, and return one \ +concise solution. 
Put the final numeric answer on the last line exactly as \ +\"Final Answer: \"."; + +#[derive(Deserialize)] +struct Input { + #[serde(default = "default_pattern")] + pattern: String, + question: String, + #[serde(default = "default_num_candidates")] + num_candidates: usize, + #[serde(default = "default_beam_width")] + beam_width: usize, + #[serde(default = "default_max_tokens")] + max_tokens: usize, + #[serde(default = "default_score_tokens")] + score_tokens: usize, + #[serde(default = "default_temperature")] + temperature: f32, + #[serde(default = "default_top_p")] + top_p: f32, + #[serde(default)] + thinking: bool, +} + +fn default_pattern() -> String { + "direct".into() +} +fn default_num_candidates() -> usize { + 4 +} +fn default_beam_width() -> usize { + 2 +} +fn default_max_tokens() -> usize { + 256 +} +fn default_score_tokens() -> usize { + 16 +} +fn default_temperature() -> f32 { + 0.7 +} +fn default_top_p() -> f32 { + 0.95 +} + +#[derive(Clone, Serialize)] +struct Candidate { + id: String, + stage: &'static str, + response: String, + answer: Option, + score: Option, + generated_tokens: usize, + generator_steps: usize, +} + +#[derive(Default, Serialize)] +struct ExecutionStats { + elapsed_ms: u128, + generated_tokens: usize, + generator_steps: usize, + context_forks: usize, + generation_calls: usize, + scoring_calls: usize, +} + +impl ExecutionStats { + fn add_score(&mut self, generated: &Generated) { + self.generated_tokens += generated.tokens; + self.generator_steps += generated.steps; + self.scoring_calls += 1; + } +} + +#[derive(Serialize)] +struct Output { + pattern: String, + final_response: String, + final_answer: Option, + selected_candidate_id: Option, + candidates: Vec, + stats: ExecutionStats, +} + +struct Generated { + text: String, + tokens: usize, + steps: usize, +} + +#[inferlet::main] +async fn main(input: Input) -> Result { + validate(&input)?; + let started = Instant::now(); + let model_name = runtime::models() + .first() + .cloned() + 
.ok_or("No models available")?; + let model = Model::load(&model_name)?; + + let mut answer_root = Context::new(&model)?; + answer_root.system(SYSTEM_PROMPT); + answer_root.user(&input.question); + answer_root.flush().await?; + + let mut score_root = Context::new(&model)?; + score_root.system(SCORE_SYSTEM); + score_root.flush().await?; + + let mut aggregate_root = Context::new(&model)?; + aggregate_root.system(AGGREGATE_SYSTEM); + aggregate_root.flush().await?; + + let mut output = match input.pattern.to_ascii_lowercase().as_str() { + "direct" => run_direct(&input, &model, &answer_root).await?, + "best_of_n" | "best-of-n" | "self_consistency" => { + run_best_of_n(&input, &model, &answer_root).await? + } + "tree_of_thought" | "tree-of-thought" | "tot" => { + run_tree_of_thought(&input, &model, &answer_root, &score_root).await? + } + "graph_of_thought" | "graph-of-thought" | "got" => { + run_graph_of_thought(&input, &model, &answer_root, &aggregate_root).await? + } + other => { + return Err(format!( + "unknown pattern '{other}': expected direct, best_of_n, \ + tree_of_thought, or graph_of_thought" + )); + } + }; + output.stats.elapsed_ms = started.elapsed().as_millis(); + Ok(output) +} + +fn validate(input: &Input) -> Result<()> { + if input.question.trim().is_empty() { + return Err("question must not be empty".into()); + } + if !(1..=16).contains(&input.num_candidates) { + return Err("num_candidates must be in [1, 16]".into()); + } + if !(1..=input.num_candidates).contains(&input.beam_width) { + return Err("beam_width must be in [1, num_candidates]".into()); + } + if input.max_tokens == 0 || input.score_tokens == 0 { + return Err("token budgets must be at least 1".into()); + } + if !(input.temperature.is_finite() && (0.0..=2.0).contains(&input.temperature)) { + return Err("temperature must be in [0.0, 2.0]".into()); + } + if !(input.top_p.is_finite() && input.top_p > 0.0 && input.top_p <= 1.0) { + return Err("top_p must be in (0.0, 1.0]".into()); + } + Ok(()) +} + 
+async fn run_direct(input: &Input, model: &Model, root: &Context) -> Result { + let generated = generate_answer( + root.fork()?, + model, + input.max_tokens, + input.temperature, + input.top_p, + input.thinking, + None, + ) + .await?; + let candidate = candidate("direct-0", "answer", generated, None); + let mut stats = ExecutionStats { + context_forks: 1, + ..Default::default() + }; + add_candidate_stats(&mut stats, &candidate); + + Ok(Output { + pattern: "direct".into(), + final_response: candidate.response.clone(), + final_answer: candidate.answer.clone(), + selected_candidate_id: Some(candidate.id.clone()), + candidates: vec![candidate], + stats, + }) +} + +async fn run_best_of_n(input: &Input, model: &Model, root: &Context) -> Result { + let generated = generate_candidates(input, model, root, "sample", None).await?; + let candidates: Vec = generated + .into_iter() + .enumerate() + .map(|(idx, generated)| candidate(&format!("sample-{idx}"), "sample", generated, None)) + .collect(); + let selected = majority_vote_index(&candidates).unwrap_or(0); + let chosen = &candidates[selected]; + let mut stats = ExecutionStats { + context_forks: candidates.len(), + ..Default::default() + }; + for candidate in &candidates { + add_candidate_stats(&mut stats, candidate); + } + + Ok(Output { + pattern: "best_of_n".into(), + final_response: chosen.response.clone(), + final_answer: chosen.answer.clone(), + selected_candidate_id: Some(chosen.id.clone()), + candidates, + stats, + }) +} + +async fn run_tree_of_thought( + input: &Input, + model: &Model, + answer_root: &Context, + score_root: &Context, +) -> Result { + let initial = generate_candidates(input, model, answer_root, "initial", None).await?; + let mut candidates: Vec = initial + .into_iter() + .enumerate() + .map(|(idx, generated)| candidate(&format!("initial-{idx}"), "initial", generated, None)) + .collect(); + let mut stats = ExecutionStats { + context_forks: candidates.len(), + ..Default::default() + }; + for candidate 
in &candidates { + add_candidate_stats(&mut stats, candidate); + } + + score_candidates( + &input.question, + input.thinking, + input.score_tokens, + model, + score_root, + &mut candidates, + &mut stats, + ) + .await?; + + let mut ranked: Vec = (0..candidates.len()).collect(); + ranked.sort_by_key(|&idx| std::cmp::Reverse(candidates[idx].score.unwrap_or(0))); + ranked.truncate(input.beam_width); + + let refine_futures = ranked.iter().map(|&idx| { + let previous = candidates[idx].response.clone(); + let prompt = format!( + "Previous candidate solution:\n{previous}\n\n\ + Critique its interpretation and arithmetic, then produce a corrected \ + complete solution to the original problem." + ); + let ctx = answer_root.fork(); + async move { + generate_answer( + ctx?, + model, + input.max_tokens, + input.temperature, + input.top_p, + input.thinking, + Some(&prompt), + ) + .await + } + }); + let refined = future::join_all(refine_futures) + .await + .into_iter() + .collect::>>()?; + stats.context_forks += refined.len(); + + let first_refined = candidates.len(); + for (idx, generated) in refined.into_iter().enumerate() { + let candidate = candidate(&format!("refined-{idx}"), "refined", generated, None); + add_candidate_stats(&mut stats, &candidate); + candidates.push(candidate); + } + score_candidates( + &input.question, + input.thinking, + input.score_tokens, + model, + score_root, + &mut candidates[first_refined..], + &mut stats, + ) + .await?; + + let selected = (first_refined..candidates.len()) + .max_by_key(|&idx| candidates[idx].score.unwrap_or(0)) + .unwrap_or(0); + let chosen = &candidates[selected]; + + Ok(Output { + pattern: "tree_of_thought".into(), + final_response: chosen.response.clone(), + final_answer: chosen.answer.clone(), + selected_candidate_id: Some(chosen.id.clone()), + candidates, + stats, + }) +} + +async fn run_graph_of_thought( + input: &Input, + model: &Model, + answer_root: &Context, + aggregate_root: &Context, +) -> Result { + let generated = 
generate_candidates(input, model, answer_root, "proposal", None).await?; + let mut candidates: Vec = generated + .into_iter() + .enumerate() + .map(|(idx, generated)| candidate(&format!("proposal-{idx}"), "proposal", generated, None)) + .collect(); + let mut stats = ExecutionStats { + context_forks: candidates.len() + 1, + ..Default::default() + }; + for candidate in &candidates { + add_candidate_stats(&mut stats, candidate); + } + + let proposals = candidates + .iter() + .enumerate() + .map(|(idx, candidate)| format!("Candidate {}:\n{}", idx + 1, candidate.response)) + .collect::>() + .join("\n\n"); + let prompt = format!( + "Original problem:\n{}\n\nCandidate solutions:\n{}\n\n\ + Synthesize the most reliable final solution.", + input.question, proposals + ); + let generated = generate_answer( + aggregate_root.fork()?, + model, + input.max_tokens, + 0.0, + 1.0, + input.thinking, + Some(&prompt), + ) + .await?; + let aggregate = candidate("aggregate-0", "aggregate", generated, None); + add_candidate_stats(&mut stats, &aggregate); + let final_response = aggregate.response.clone(); + let final_answer = aggregate.answer.clone(); + let selected_candidate_id = Some(aggregate.id.clone()); + candidates.push(aggregate); + + Ok(Output { + pattern: "graph_of_thought".into(), + final_response, + final_answer, + selected_candidate_id, + candidates, + stats, + }) +} + +async fn generate_candidates( + input: &Input, + model: &Model, + root: &Context, + _stage: &'static str, + prompt: Option<&str>, +) -> Result> { + let contexts = (0..input.num_candidates) + .map(|_| root.fork()) + .collect::>>()?; + let futures = contexts.into_iter().map(|ctx| { + generate_answer( + ctx, + model, + input.max_tokens, + input.temperature, + input.top_p, + input.thinking, + prompt, + ) + }); + future::join_all(futures) + .await + .into_iter() + .collect::>>() +} + +async fn score_candidates( + question: &str, + thinking: bool, + max_tokens: usize, + model: &Model, + root: &Context, + candidates: 
&mut [Candidate], + stats: &mut ExecutionStats, +) -> Result<()> { + let futures = + candidates.iter().map(|candidate| { + let prompt = format!( + "Problem:\n{question}\n\nCandidate solution:\n{}\n\nScore:", + candidate.response + ); + let ctx = root.fork(); + async move { + generate_answer(ctx?, model, max_tokens, 0.0, 1.0, thinking, Some(&prompt)).await + } + }); + let scores = future::join_all(futures) + .await + .into_iter() + .collect::>>()?; + stats.context_forks += scores.len(); + for (candidate, generated) in candidates.iter_mut().zip(scores) { + candidate.score = parse_score(&generated.text); + stats.add_score(&generated); + } + Ok(()) +} + +async fn generate_answer( + mut ctx: Context, + model: &Model, + max_tokens: usize, + temperature: f32, + top_p: f32, + thinking: bool, + prompt: Option<&str>, +) -> Result { + if let Some(prompt) = prompt { + ctx.user(prompt); + } + ctx.cue(); + if !thinking { + ctx.append(&model.tokenizer().encode("\n\n\n\n")); + } + let sampler = if temperature <= 0.0 { + Sampler::Argmax + } else { + Sampler::TopP { + temperature, + p: top_p, + } + }; + let stops = chat::stop_tokens(model); + let mut generator = ctx.generate(sampler).max_tokens(max_tokens).stop(&stops); + let mut decoder = chat::Decoder::new(model); + let mut text = String::new(); + let mut tokens = 0; + let mut steps = 0; + + while let Some(step) = generator.next()? { + let output = step.execute().await?; + tokens += output.tokens.len(); + steps += 1; + match decoder.feed(&output.tokens)? 
{ + chat::Event::Delta(delta) => text.push_str(&delta), + chat::Event::Done(done) => { + text = done; + break; + } + chat::Event::Idle | chat::Event::Interrupt(_) => {} + } + } + Ok(Generated { + text, + tokens, + steps, + }) +} + +fn candidate(id: &str, stage: &'static str, generated: Generated, score: Option) -> Candidate { + Candidate { + id: id.into(), + stage, + answer: extract_final_answer(&generated.text), + response: generated.text, + score, + generated_tokens: generated.tokens, + generator_steps: generated.steps, + } +} + +fn add_candidate_stats(stats: &mut ExecutionStats, candidate: &Candidate) { + stats.generated_tokens += candidate.generated_tokens; + stats.generator_steps += candidate.generator_steps; + stats.generation_calls += 1; +} + +fn majority_vote_index(candidates: &[Candidate]) -> Option { + let mut counts: HashMap<&str, usize> = HashMap::new(); + for answer in candidates + .iter() + .filter_map(|candidate| candidate.answer.as_deref()) + { + *counts.entry(answer).or_default() += 1; + } + candidates + .iter() + .enumerate() + .filter_map(|(idx, candidate)| { + let answer = candidate.answer.as_deref()?; + Some((idx, counts.get(answer).copied().unwrap_or(0))) + }) + .max_by_key(|(idx, count)| (*count, std::cmp::Reverse(*idx))) + .map(|(idx, _)| idx) +} + +fn extract_final_answer(text: &str) -> Option { + let lower = text.to_ascii_lowercase(); + let marker = "final answer:"; + let pos = lower.rfind(marker)?; + let tail = &text[pos + marker.len()..]; + extract_last_number(tail) +} + +fn extract_last_number(text: &str) -> Option { + let mut tokens = Vec::new(); + let mut current = String::new(); + for ch in text.chars() { + if ch.is_ascii_digit() || matches!(ch, '-' | '+' | '.' 
/// Canonicalizes one numeric token taken from a model response.
///
/// Surrounding `+`, `.`, `,` and `/` characters are trimmed and thousands
/// separators are removed. For a plain decimal (optional leading `-`, digits,
/// exactly one `.`) all trailing fractional zeros are dropped, so `14`,
/// `14.0` and `14.00` yield the same string and count as the same answer in
/// the majority vote. A token written with a bare leading decimal point
/// (`.5`) becomes `0.5` rather than losing the point.
///
/// Returns `None` when the token contains no digit.
fn normalize_number(raw: &str) -> Option<String> {
    let is_noise = |c: char| matches!(c, '+' | '.' | ',' | '/');
    let mut value = raw.trim_matches(is_noise).replace(',', "");
    if value.is_empty() || !value.chars().any(|c| c.is_ascii_digit()) {
        return None;
    }

    // `trim_matches` also strips a leading decimal point, which would turn
    // ".5" into "5". Restore it when the token was exactly "<point><digits>"
    // after any leading sign or separator noise.
    let had_leading_point = raw
        .trim_start_matches(|c: char| matches!(c, '+' | ',' | '/'))
        .strip_prefix('.')
        .is_some_and(|rest| rest.starts_with(|c: char| c.is_ascii_digit()));
    if had_leading_point && value.chars().all(|c| c.is_ascii_digit()) {
        value.insert_str(0, "0.");
    }

    let unsigned = value.strip_prefix('-').unwrap_or(&value);
    let plain_decimal = unsigned.matches('.').count() == 1
        && unsigned.chars().all(|c| c.is_ascii_digit() || c == '.');
    if plain_decimal {
        // Only fractional zeros can be trimmed here: trimming stops at the
        // decimal point, which is then removed if nothing follows it.
        let keep = value.trim_end_matches('0').trim_end_matches('.').len();
        value.truncate(keep);
    } else if value.ends_with(".0") {
        // Anything that is not a plain decimal keeps the original handling.
        value.truncate(value.len() - 2);
    }
    Some(value)
}
async def test_reasoning_benchmark(client, args):
    """Run every reasoning pattern once and check the shape of the returned JSON."""
    # Settings shared by all patterns; kept small so the interface test stays cheap.
    shared_input = {
        "question": "Mia has 7 apples and buys 5 more. How many apples?",
        "num_candidates": 2,
        "beam_width": 1,
        "max_tokens": 48,
        "score_tokens": 8,
    }
    pattern_names = ("direct", "best_of_n", "tree_of_thought", "graph_of_thought")

    for pattern in pattern_names:
        raw_output = await run_inferlet(
            client,
            "reasoning-benchmark",
            {"pattern": pattern, **shared_input},
            timeout=args.timeout,
        )
        result = json.loads(raw_output)
        candidates = result["candidates"]
        stats = result["stats"]

        assert result["pattern"] == pattern
        assert isinstance(candidates, list)
        assert candidates
        assert stats["generation_calls"] >= 1
        assert stats["context_forks"] >= 1