From 13eb921e3949ad4d9e9b1ed48b9577976a570bff Mon Sep 17 00:00:00 2001 From: saarang Date: Mon, 30 Mar 2026 17:18:35 -0700 Subject: [PATCH] Add problem automation contract and validator --- .github/workflows/validate-problems.yml | 20 + README.md | 27 ++ docs/problem-authoring-contract.md | 149 ++++++ docs/problem-automation-roadmap.md | 78 +++ docs/problem-validation-contract.md | 107 ++++ scripts/validate_problem.py | 620 ++++++++++++++++++++++++ templates/problem-template.def.py | 72 +++ templates/problem-template.md | 30 ++ 8 files changed, 1103 insertions(+) create mode 100644 .github/workflows/validate-problems.yml create mode 100644 docs/problem-authoring-contract.md create mode 100644 docs/problem-automation-roadmap.md create mode 100644 docs/problem-validation-contract.md create mode 100644 scripts/validate_problem.py create mode 100644 templates/problem-template.def.py create mode 100644 templates/problem-template.md diff --git a/.github/workflows/validate-problems.yml b/.github/workflows/validate-problems.yml new file mode 100644 index 0000000..2baf56e --- /dev/null +++ b/.github/workflows/validate-problems.yml @@ -0,0 +1,20 @@ +name: Validate Problems + +on: + pull_request: + push: + branches: ["main"] + +jobs: + contract: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Structural validation + run: python scripts/validate_problem.py --runtime none --format text diff --git a/README.md b/README.md index c174524..93f847e 100644 --- a/README.md +++ b/README.md @@ -41,3 +41,30 @@ The `problem.md` file should contain a description of the problem written in Mar - `const`: boolean Once you add a problem, make sure to test both correct (slow/fast) and incorrect submissions. Let us know if you encounter any issues/bugs! 
+ +## Contract And Validation + +For new work, use the contract-first docs in this repo: + +- `docs/problem-authoring-contract.md` +- `docs/problem-validation-contract.md` +- `docs/problem-automation-roadmap.md` + +There is now a machine-readable validator: + +```bash +python scripts/validate_problem.py --runtime none +python scripts/validate_problem.py relu softmax --runtime first +python scripts/validate_problem.py gemm-relu-divide --runtime all --enforce-wrong-answer-rejection +``` + +Recommended workflow: + +1. pass structural validation in CI +2. validate locally on CUDA or Together H100 +3. validate through the Modal-backed product runtime before merge when the runtime path matters + +Templates for new problems live in: + +- `templates/problem-template.def.py` +- `templates/problem-template.md` diff --git a/docs/problem-authoring-contract.md b/docs/problem-authoring-contract.md new file mode 100644 index 0000000..7de4d57 --- /dev/null +++ b/docs/problem-authoring-contract.md @@ -0,0 +1,149 @@ +# Problem Authoring Contract + +This document defines the backward-compatible authoring format for `tensara/problems`. + +## Goals + +- Make problem authoring deterministic for agents. +- Preserve compatibility with existing published problems. +- Separate authoring truth from runtime truth. +- Keep `sync-problems.ts` compatible while stricter validation is rolled out. + +## Stable Problem Layout + +Each problem lives in: + +```text +problems/{slug}/ +├── def.py +└── problem.md +``` + +Both files remain required. + +## `problem.md` Contract + +### Required Frontmatter + +These fields are already expected by sync and remain required: + +```yaml +slug: "relu" +title: "ReLU" +difficulty: "EASY" +author: "sarthak" +``` + +### Recommended Frontmatter + +These are now the recommended fields for new problems. They are backward-compatible because unknown frontmatter is ignored by the current sync path. 
+ +```yaml +tags: ["activation-function"] +source: + kind: "kernelbench" + repo: "ScalingIntelligence/KernelBench" + path: "KernelBench/level2/63_Gemm_ReLU_Divide.py" +authoring: + mode: "exact-port" # exact-port | normalized +validation: + deterministic: true + sample_path: true + wrong_answer_rejection: true + runtime_targets: ["local-cuda", "modal-sample"] +``` + +### Content Rules + +- `slug` must match the directory name. +- `difficulty` must be `EASY`, `MEDIUM`, or `HARD`. +- Markdown body should describe the mathematical contract, not implementation trivia. +- If the problem is adapted from an external source, include attribution in the body or `source` block. + +## `def.py` Contract + +### Required Class Shape + +`def.py` must define one primary problem class that subclasses `Problem`. + +Canonical class naming: + +- directory slug: `gemm-relu-divide` +- class name: `gemm_relu_divide` + +### Required Methods + +New problems should implement: + +- `reference_solution(self, *args)` +- `generate_test_cases(self)` +- `generate_sample(self)` +- `verify_result(self, expected_output, actual_output)` + +Backward-compatible rule: + +- existing problems are allowed to keep current behavior if they already work +- new validation treats extra required args on `generate_test_cases`, `generate_sample`, or `verify_result` as contract errors + +### Parameters + +Preferred: + +- define `parameters = [...]` in `def.py` + +Fallback still supported: + +- override `get_function_signature(...)` +- include legacy `parameters` frontmatter in `problem.md` + +New problems should define parameters in `def.py`, because that is the most machine-readable source for agents and validators. + +### Test Case Shape + +`generate_test_cases()` should return a list of dicts. + +Each dict should contain: + +- `name`: stable string label +- `create_inputs`: zero-arg callable returning the reference inputs + +Additional keys such as dimensions, seed, or descriptive metadata are encouraged. 
+ +### Sample Shape + +Canonical form: + +- `generate_sample()` returns a single dict with the same shape as one test case + +Backward-compatible form still accepted by the validator: + +- a list containing exactly one dict + +### Verifier Rules + +`verify_result(...)` should: + +- accept the exact reference output +- reject intentionally perturbed outputs +- return `(bool, debug_info)` +- provide debug info that is useful for automated repair + +## Canonical Authoring Pattern + +1. Define the mathematical contract in `problem.md`. +2. Define parameters in `def.py`. +3. Make `reference_solution(...)` deterministic. +4. Use `Problem.get_seed(...)` for seeded test generation. +5. Add one small sample case. +6. Add several larger generated cases. +7. Make verifier failures explicit and debuggable. + +## Compatibility Policy + +This contract is intentionally additive: + +- existing frontmatter fields keep working +- old problems are not forced to adopt new optional metadata immediately +- validation distinguishes structural errors from migration warnings + +The long-term direction is to move all new problems toward this contract and then tighten sync around it. diff --git a/docs/problem-automation-roadmap.md b/docs/problem-automation-roadmap.md new file mode 100644 index 0000000..4dd22c6 --- /dev/null +++ b/docs/problem-automation-roadmap.md @@ -0,0 +1,78 @@ +# Problem Automation Roadmap + +This roadmap defines how `tensara/problems` becomes agent-friendly without breaking current workflows. + +## Immediate Direction + +The first priority is not contests. It is deterministic ingestion. + +That means: + +1. agents can author to a stable contract +2. validators can reject weak or broken problems automatically +3. 
accepted problems are safe to sync and publish + +## Phase 1: Contract First + +Ship: + +- a stable authoring contract +- a stable validation contract +- a machine-readable validator +- structural CI on every PR +- templates for new problems + +This PR implements Phase 1. + +## Phase 2: Stronger Runtime Validation + +Add: + +- routine H100 validation for cheap local CUDA truth +- a Modal-backed product-runtime check as the final acceptance gate +- persisted validation artifacts that agents can inspect + +## Phase 3: Verifier Strength + +Add: + +- adversarial wrong-answer checks +- mutation tests for common failure modes +- problem-family-specific negative cases + +This is how testcase quality becomes measurable instead of subjective. + +## Phase 4: Sync Hardening + +Tighten `sync-problems.ts` so sync fails on structural contract violations instead of only warning. + +Examples: + +- missing parameters +- bad method signatures +- slug/frontmatter mismatch + +## Phase 5: Automated Growth + +Once the contract and validators are stable: + +- agents can open daily problem PRs +- CI can auto-classify failures +- maintainers can review mostly by exception +- accepted problems can auto-sync to production + +## Phase 6: Contest Automation + +Only after validation is trusted: + +- assemble contest sets automatically +- require hidden-test quality and difficulty spread +- publish and schedule through the same validated pipeline + +## Non-Goals For This Phase + +- rewriting all existing problems +- forcing immediate frontmatter migration +- making Modal validation mandatory in this repo before the surrounding auth/runtime glue is ready + +The immediate goal is a reliable contract and validator surface that later Modal automation can plug into. 
diff --git a/docs/problem-validation-contract.md b/docs/problem-validation-contract.md new file mode 100644 index 0000000..99b5e8a --- /dev/null +++ b/docs/problem-validation-contract.md @@ -0,0 +1,107 @@ +# Problem Validation Contract + +This document defines the validation ladder for `tensara/problems`. + +## Validation Tiers + +### Tier 1: Structural Validation + +Runs in ordinary CI without GPUs. + +Checks: + +- required files exist +- frontmatter has required fields +- slug matches directory +- `def.py` has a `Problem` subclass +- required methods exist +- method signatures match the stable contract +- parameters are present in `def.py` or `get_function_signature(...)` is overridden + +This tier should run on every PR. + +### Tier 2: Local CUDA Validation + +Runs on a real GPU such as Together H100. + +Checks: + +- `generate_sample()` executes +- `generate_test_cases()` executes +- `reference_solution(...)` runs on CUDA +- `verify_result(...)` accepts correct outputs +- verifier rejects perturbed outputs when enabled +- `get_flops(...)` is positive when provided + +This tier is the fast author-side correctness gate. + +### Tier 3: Product Runtime Validation + +Runs through the same runtime path as the real Tensara product. + +Authoritative target: + +- Modal-backed sample/checker endpoints from `tensara/tensara` + +Why this is authoritative: + +- it exercises the same engine loading path used by the product +- it catches signature, allocation, and runner mismatches that local file execution can miss + +This is the final acceptance gate for automation. + +## Source of Truth + +Validation truth should be ordered as: + +1. structural CI +2. local CUDA runtime +3. Modal/product runtime + +Local H100 success is necessary but not sufficient. Modal runtime is the product-truth layer. 
+ +## Standard Validator Output + +Validators should emit machine-readable results: + +```json +{ + "ok": true, + "summary": { + "problems_checked": 1, + "errors": 0, + "warnings": 1 + }, + "results": [ + { + "slug": "relu", + "ok": true, + "diagnostics": [ + {"level": "warning", "code": "missing_source_metadata", "message": "Recommended metadata not present"} + ] + } + ] +} +``` + +That output shape is chosen so agents can repair failures automatically. + +## Required Checks For New Problems + +New problems should pass: + +- structural validation +- sample execution +- at least one generated test case +- wrong-answer rejection +- Modal sample or checker validation before merge + +## Migration Policy + +Backward compatibility matters, so the validator distinguishes: + +- `error`: merge blocker +- `warning`: migration or quality issue +- `info`: useful metadata only + +Existing published problems should initially be brought under structural validation first. Stronger runtime requirements can then tighten in phases. 
@dataclass
class Diagnostic:
    """A single validator finding.

    Every field is a plain string (or ``None``) so the record can be turned
    into JSON via ``dataclasses.asdict`` without custom encoders.
    """

    level: str  # severity label: "error", "warning", or "info"
    code: str  # stable machine-readable identifier, e.g. "missing_def"
    message: str  # human-readable explanation of the finding
    path: str | None = None  # file the finding refers to, when known


@dataclass
class ProblemResult:
    """Accumulated validation outcome for one problem directory."""

    slug: str  # name of the problem directory being validated
    ok: bool = True  # becomes False once any "error" diagnostic is recorded
    diagnostics: list[Diagnostic] = field(default_factory=list)
    runtime: dict[str, Any] = field(default_factory=dict)  # set by runtime validation

    def add(self, level: str, code: str, message: str, path: Path | None = None) -> None:
        """Record a diagnostic and mark the result failed if it is an error.

        Args:
            level: Severity label; only ``"error"`` flips ``ok`` to ``False``.
            code: Stable identifier for the kind of finding.
            message: Human-readable description.
            path: Optional file the finding refers to; stored as a string.
        """
        if level == "error":
            self.ok = False
        self.diagnostics.append(
            Diagnostic(
                level=level,
                code=code,
                message=message,
                path=str(path) if path else None,
            )
        )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the validator's command line.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior; passing a
            list makes the parser usable from tests and other scripts.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Validate Tensara problem definitions")
    parser.add_argument("targets", nargs="*", help="Problem slugs or paths under problems/")
    parser.add_argument(
        "--repo-root",
        # The script lives in <repo>/scripts/, so the repo root is one level up.
        default=Path(__file__).resolve().parents[1],
        type=Path,
        help="Path to the problems repo root",
    )
    parser.add_argument(
        "--problems-dir",
        type=Path,
        help="Path to the problems directory (defaults to REPO_ROOT/problems)",
    )
    parser.add_argument(
        "--runtime",
        choices=("none", "sample", "first", "all"),
        default="none",
        help="Runtime validation mode",
    )
    parser.add_argument(
        "--engine-path",
        type=Path,
        help="Path to tensara/engine for runtime imports",
    )
    parser.add_argument(
        "--enforce-wrong-answer-rejection",
        action="store_true",
        help="Treat verifier acceptance of perturbed outputs as an error",
    )
    parser.add_argument(
        "--format",
        choices=("text", "json"),
        default="text",
        help="Output format",
    )
    parser.add_argument(
        "--warnings-as-errors",
        action="store_true",
        help="Treat warnings as errors in the final exit code",
    )
    return parser.parse_args(argv)
def parse_frontmatter(markdown_text: str) -> tuple[dict[str, str], str]:
    """Extract top-level ``key: value`` frontmatter from a Markdown document.

    This is a deliberately small parser, not a YAML implementation: only
    unindented ``key: value`` lines are kept, nested (indented) lines and
    ``#`` comment lines are skipped, and surrounding quotes are stripped from
    values. A key with a nested block (``source:``) is recorded with an empty
    string value so callers can still test for its presence.

    Args:
        markdown_text: Full text of ``problem.md``.

    Returns:
        ``(frontmatter, body)``. When no complete ``---`` fenced block opens
        the document, returns ``({}, markdown_text)`` unchanged.
    """
    # A UTF-8 byte-order mark or trailing whitespace on the opening fence must
    # not hide otherwise valid frontmatter; splitlines() also copes with CRLF.
    lines = markdown_text.removeprefix("\ufeff").splitlines()
    if not lines or lines[0].rstrip() != "---":
        return {}, markdown_text

    end_index = next(
        (idx for idx in range(1, len(lines)) if lines[idx].strip() == "---"),
        None,
    )
    if end_index is None:
        # Opening fence without a closing one: treat as "no frontmatter".
        return {}, markdown_text

    content = "\n".join(lines[end_index + 1 :]).strip()
    parsed: dict[str, str] = {}

    for raw_line in lines[1:end_index]:
        line = raw_line.rstrip()
        if not line or line.lstrip().startswith("#"):
            continue
        # Indented lines belong to nested mappings/lists, which are ignored.
        if line.startswith((" ", "\t")) or ":" not in line:
            continue
        key, value = line.split(":", 1)
        parsed[key.strip()] = value.strip().strip('"').strip("'")

    return parsed, content


def normalize_slug(value: str) -> str:
    """Return *value* with underscores replaced by hyphens."""
    return value.replace("_", "-")


def discover_problem_dirs(problems_dir: Path, targets: list[str]) -> list[Path]:
    """Resolve command-line targets to problem directories.

    Args:
        problems_dir: Directory that contains one sub-directory per problem.
        targets: Slugs or filesystem paths. A path that exists (relative to
            the current directory) wins over a slug of the same name; a file
            target resolves to its containing directory.

    Returns:
        With no targets, every non-hidden sub-directory of ``problems_dir``
        (excluding ``__pycache__``), sorted. Otherwise the resolved target
        directories, de-duplicated, in first-seen order.

    Raises:
        FileNotFoundError: If a target is neither an existing path nor a
            name under ``problems_dir``.
    """
    if not targets:
        return sorted(
            path
            for path in problems_dir.iterdir()
            if path.is_dir() and not path.name.startswith(".") and path.name != "__pycache__"
        )

    # dict keeps insertion order, giving an ordered de-duplication.
    unique: dict[Path, None] = {}
    for target in targets:
        chosen = Path(target)
        if not chosen.exists():
            chosen = problems_dir / target
            if not chosen.exists():
                raise FileNotFoundError(f"Could not resolve target: {target}")
        if not chosen.is_dir():
            chosen = chosen.parent
        # Resolve before de-duplicating so "relu", "./relu" and "x/../relu"
        # collapse to one entry, and so "." yields a real directory name
        # (Path(".").name is "", which would become an empty slug).
        unique.setdefault(chosen.resolve(), None)
    return list(unique)


def required_positional_after_self(fn: ast.FunctionDef) -> int:
    """Count positional parameters of *fn* that have no default, ignoring ``self``.

    Defaults in ``ast.arguments.defaults`` apply to the trailing positional
    parameters, so the number of required ones is the total minus the number
    of defaults.
    """
    positional = [*fn.args.posonlyargs, *fn.args.args]
    if positional and positional[0].arg == "self":
        positional = positional[1:]
    return max(0, len(positional) - len(fn.args.defaults))
def total_positional_after_self(fn: ast.FunctionDef) -> int:
    """Count all positional parameters of *fn*, ignoring a leading ``self``."""
    positional = [*fn.args.posonlyargs, *fn.args.args]
    if positional and positional[0].arg == "self":
        positional = positional[1:]
    return len(positional)


def find_problem_class(module: ast.Module) -> ast.ClassDef | None:
    """Return the first top-level class whose base list names ``Problem``.

    Both ``class X(Problem)`` and ``class X(pkg.Problem)`` are recognized.
    Only direct bases written in the class statement are inspected.
    """
    for node in module.body:
        if not isinstance(node, ast.ClassDef):
            continue
        for base in node.bases:
            if isinstance(base, ast.Name) and base.id == "Problem":
                return node
            if isinstance(base, ast.Attribute) and base.attr == "Problem":
                return node
    return None


def has_parameters_assignment(problem_class: ast.ClassDef) -> bool:
    """Report whether the class body assigns a class attribute named ``parameters``.

    Covers plain assignments (``parameters = [...]``) and annotated ones
    (``parameters: list = [...]``). Assignments inside methods are not seen.
    """
    for node in problem_class.body:
        if isinstance(node, ast.Assign):
            if any(isinstance(t, ast.Name) and t.id == "parameters" for t in node.targets):
                return True
        elif isinstance(node, ast.AnnAssign):
            if isinstance(node.target, ast.Name) and node.target.id == "parameters":
                return True
    return False


def get_method_map(problem_class: ast.ClassDef) -> dict[str, ast.FunctionDef]:
    """Map method name to its ``ast.FunctionDef`` for methods defined directly in the class."""
    return {
        node.name: node
        for node in problem_class.body
        if isinstance(node, ast.FunctionDef)
    }


def analyze_structure(problem_dir: Path) -> ProblemResult:
    """Run GPU-free structural validation on one problem directory.

    Checks that ``def.py`` and ``problem.md`` exist, that the frontmatter has
    the required fields and a valid difficulty, that the slug matches the
    directory, and that ``def.py`` declares a ``Problem`` subclass with the
    required methods, acceptable signatures, and a parameter source.

    Args:
        problem_dir: Directory of the problem; its name is used as the slug.

    Returns:
        A ``ProblemResult`` carrying every diagnostic found. Analysis stops
        early when a prerequisite for later checks is missing (a file is
        absent or unreadable, ``def.py`` does not parse, or no problem class
        is found).
    """
    slug = problem_dir.name
    result = ProblemResult(slug=slug)

    def_path = problem_dir / "def.py"
    md_path = problem_dir / "problem.md"

    if not def_path.exists():
        result.add("error", "missing_def", "Missing def.py", def_path)
        return result
    if not md_path.exists():
        result.add("error", "missing_markdown", "Missing problem.md", md_path)
        return result

    # Decode explicitly: the platform default encoding differs between
    # machines, and "utf-8-sig" also drops a leading byte-order mark.
    try:
        markdown_text = md_path.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeDecodeError) as exc:
        result.add("error", "unreadable_file", f"Could not read file as UTF-8: {exc}", md_path)
        return result

    frontmatter, markdown_body = parse_frontmatter(markdown_text)
    if not frontmatter:
        result.add("error", "missing_frontmatter", "problem.md is missing YAML frontmatter", md_path)
    for field_name in REQUIRED_FRONTMATTER:
        if not frontmatter.get(field_name):
            result.add(
                "error",
                "missing_frontmatter_field",
                f"Missing required frontmatter field: {field_name}",
                md_path,
            )

    if markdown_body == "":
        result.add("warning", "empty_markdown_body", "problem.md body is empty", md_path)

    declared_slug = frontmatter.get("slug")
    if declared_slug and declared_slug != slug:
        if normalize_slug(declared_slug) == slug:
            # Same slug apart from "_" vs "-": a migration nudge, not a blocker.
            result.add(
                "warning",
                "legacy_slug_style",
                f"Frontmatter slug '{frontmatter['slug']}' should migrate to '{slug}'",
                md_path,
            )
        else:
            result.add(
                "error",
                "slug_mismatch",
                f"Frontmatter slug '{frontmatter['slug']}' does not match directory '{slug}'",
                md_path,
            )

    difficulty = frontmatter.get("difficulty")
    if difficulty and difficulty not in VALID_DIFFICULTIES:
        result.add(
            "error",
            "invalid_difficulty",
            f"Difficulty must be one of {sorted(VALID_DIFFICULTIES)}",
            md_path,
        )

    try:
        python_source = def_path.read_text(encoding="utf-8-sig")
    except (OSError, UnicodeDecodeError) as exc:
        result.add("error", "unreadable_file", f"Could not read file as UTF-8: {exc}", def_path)
        return result

    try:
        module = ast.parse(python_source, filename=str(def_path))
    except SyntaxError as exc:
        result.add("error", "python_syntax_error", str(exc), def_path)
        return result

    problem_class = find_problem_class(module)
    if problem_class is None:
        result.add("error", "missing_problem_class", "No class inheriting from Problem found", def_path)
        return result

    expected_class_name = slug.replace("-", "_")
    if problem_class.name != expected_class_name:
        result.add(
            "warning",
            "class_name_mismatch",
            f"Expected class name '{expected_class_name}', found '{problem_class.name}'",
            def_path,
        )

    methods = get_method_map(problem_class)
    for method_name in ("reference_solution", "generate_test_cases", "generate_sample", "verify_result"):
        if method_name not in methods:
            result.add("error", "missing_method", f"Missing required method: {method_name}", def_path)

    if "generate_test_cases" in methods and required_positional_after_self(methods["generate_test_cases"]) != 0:
        result.add(
            "error",
            "bad_generate_test_cases_signature",
            "generate_test_cases() must not require arguments beyond self",
            def_path,
        )

    if "generate_sample" in methods and required_positional_after_self(methods["generate_sample"]) != 0:
        result.add(
            "error",
            "bad_generate_sample_signature",
            "generate_sample() must not require arguments beyond self",
            def_path,
        )

    if "verify_result" in methods:
        verify_fn = methods["verify_result"]
        # The method must be callable as verify_result(expected, actual):
        # at most two required positionals and room for at least two.
        # (The previous "required != 2 or total < 2" made the second clause
        # unreachable and rejected signatures with a defaulted second arg.)
        if required_positional_after_self(verify_fn) > 2 or total_positional_after_self(verify_fn) < 2:
            result.add(
                "error",
                "bad_verify_result_signature",
                "verify_result() must accept expected_output and actual_output after self",
                def_path,
            )

    has_parameters = has_parameters_assignment(problem_class)
    has_signature_override = "get_function_signature" in methods
    if not has_parameters and not has_signature_override:
        result.add(
            "error",
            "missing_parameters",
            "Problem must define parameters or override get_function_signature()",
            def_path,
        )

    # Recommended-only metadata is reported at "info" level so it never
    # affects the exit code.
    if "tags" not in frontmatter:
        result.add(
            "info",
            "missing_tags",
            "problem.md is missing tags; current sync tolerates this, but agents benefit from tags",
            md_path,
        )

    if "source" not in frontmatter:
        result.add(
            "info",
            "missing_source_metadata",
            "Recommended source metadata is missing from problem.md frontmatter",
            md_path,
        )

    return result
def resolve_engine_path(repo_root: Path, explicit_engine_path: Path | None) -> Path | None:
    """Locate the ``tensara/engine`` directory used for runtime imports.

    Lookup order:
      1. ``explicit_engine_path`` when given (returned without checking it).
      2. ``<repo_root>/../tensara/engine`` if it exists.
      3. ``./engine`` relative to the current working directory if it exists.

    Returns:
        The first match, or ``None`` when nothing was found.
    """
    if explicit_engine_path:
        return explicit_engine_path

    for candidate in (
        repo_root.parent / "tensara" / "engine",
        Path.cwd() / "engine",
    ):
        if candidate.exists():
            return candidate
    return None


def load_problem_instance(problem_dir: Path, engine_path: Path):
    """Import ``<problem_dir>/def.py`` and instantiate its problem class.

    ``engine_path`` is put at the front of ``sys.path`` (once) so that
    ``def.py`` can import the engine's modules. The class is looked up by the
    canonical name (directory slug with ``-`` replaced by ``_``). Because
    structural validation only warns on a class-name mismatch, a missing
    canonical name falls back to the single class defined in ``def.py`` that
    inherits from a class named ``Problem``.

    Args:
        problem_dir: Directory containing ``def.py``.
        engine_path: Directory to expose on ``sys.path`` for engine imports.

    Returns:
        An instance of the problem class, constructed with no arguments.

    Raises:
        RuntimeError: If no import spec/loader could be created for ``def.py``.
        AttributeError: If the canonical class is absent and the fallback does
            not find exactly one ``Problem`` subclass.
        Exception: Anything raised while executing ``def.py`` or constructing
            the instance propagates to the caller.
    """
    engine_entry = str(engine_path)
    if engine_entry not in sys.path:
        sys.path.insert(0, engine_entry)

    module_name = problem_dir.name.replace("-", "_")
    def_path = problem_dir / "def.py"
    spec = importlib.util.spec_from_file_location(module_name, def_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not import {def_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    problem_class = getattr(module, module_name, None)
    if problem_class is None:
        # Fallback: classes defined in this module (not imported into it)
        # that have an ancestor named "Problem". __mro__[1:] skips the class
        # itself so a locally defined base called Problem is not selected.
        candidates = [
            obj
            for obj in vars(module).values()
            if isinstance(obj, type)
            and obj.__module__ == module.__name__
            and any(base.__name__ == "Problem" for base in obj.__mro__[1:])
        ]
        if len(candidates) != 1:
            found = ", ".join(cls.__name__ for cls in candidates) or "none"
            raise AttributeError(
                f"{def_path} does not define class '{module_name}' and the fallback "
                f"needs exactly one Problem subclass (found: {found})"
            )
        problem_class = candidates[0]
    return problem_class()
def clone_output(value, torch_module):
    """Return an independent copy of a reference output.

    Tensors are cloned; tuples and lists are copied recursively; ``int``,
    ``float`` and ``bool`` values are returned as they are.

    Args:
        value: Output of ``reference_solution``.
        torch_module: The imported ``torch`` module (passed in so this module
            does not need torch at import time).

    Raises:
        TypeError: For any other output type.
    """
    if isinstance(value, torch_module.Tensor):
        return value.clone()
    if isinstance(value, tuple):
        return tuple(clone_output(item, torch_module) for item in value)
    if isinstance(value, list):
        return [clone_output(item, torch_module) for item in value]
    if isinstance(value, (int, float, bool)):
        return value
    raise TypeError(f"Unsupported output type for cloning: {type(value)!r}")


def perturb_output(value, torch_module):
    """Return a copy of *value* that is guaranteed to differ from it.

    Used to check that a verifier rejects wrong answers. For tensors exactly
    one element (the first in row-major order) is changed; for non-empty
    tuples/lists only the first item is perturbed; scalars are incremented
    (booleans are negated).

    Args:
        value: Output of ``reference_solution``.
        torch_module: The imported ``torch`` module.

    Raises:
        ValueError: If a tensor output has no elements.
        TypeError: For unsupported output types (including empty tuples/lists).
    """
    if isinstance(value, torch_module.Tensor):
        # Force a contiguous copy: a default clone() keeps the source strides,
        # and flattening a non-contiguous tensor yields another copy, so the
        # write below would never reach the tensor that is returned.
        mutated = value.clone(memory_format=torch_module.contiguous_format)
        if mutated.numel() == 0:
            raise ValueError("Cannot perturb an empty tensor output")
        flat = mutated.view(-1)
        first = flat[0]
        if mutated.dtype == torch_module.bool:
            # True + 1 casts back to True, so flip instead of adding.
            flat[0] = ~first
        elif mutated.dtype.is_floating_point:
            if bool(torch_module.isfinite(first)):
                # Scale the step with the value: adding 1.0 is lost to
                # rounding for large half-precision numbers.
                flat[0] = first + max(1.0, abs(first.item()))
            else:
                # inf/nan absorb any addition; replace with a finite value.
                flat[0] = 0.0
        else:
            flat[0] = first + 1
        return mutated
    if isinstance(value, tuple) and value:
        first_item = perturb_output(value[0], torch_module)
        return (first_item, *value[1:])
    if isinstance(value, list) and value:
        mutated_list = list(value)
        mutated_list[0] = perturb_output(mutated_list[0], torch_module)
        return mutated_list
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return not value
    if isinstance(value, (int, float)):
        return value + 1
    raise TypeError(f"Unsupported output type for perturbation: {type(value)!r}")


def normalize_case_collection(value, *, allow_single_dict: bool) -> list[dict[str, Any]]:
    """Coerce a ``generate_*`` return value into a list of case dicts.

    Args:
        value: Whatever the problem returned.
        allow_single_dict: When true, a bare dict is wrapped in a list.

    Returns:
        The list of dicts, or ``[]`` when the shape is not acceptable (a bare
        dict while not allowed, a list with non-dict items, or anything else).
    """
    if isinstance(value, dict):
        return [value] if allow_single_dict else []
    if isinstance(value, list) and all(isinstance(item, dict) for item in value):
        return value
    return []


def validate_runtime(
    repo_root: Path,
    problem_dir: Path,
    result: ProblemResult,
    runtime_mode: str,
    engine_path: Path | None,
    enforce_wrong_answer_rejection: bool,
) -> None:
    """Execute a problem on CUDA and record findings on *result*.

    For each selected case the reference solution is run, the verifier must
    accept a clone of the reference output, and it should reject a perturbed
    copy. ``get_flops`` is checked for a positive value.

    Args:
        repo_root: Repository root (accepted for interface symmetry; not read
            in this function).
        problem_dir: Directory of the problem to load.
        result: Result object that receives diagnostics and, on full success,
            a ``runtime`` summary.
        runtime_mode: ``"none"`` (do nothing), ``"sample"`` (sample only),
            ``"first"`` (sample plus first test case) or anything else for
            sample plus all test cases.
        engine_path: Location of the engine package, or ``None`` if unknown.
        enforce_wrong_answer_rejection: Report an accepted perturbed output
            as an error instead of a warning.

    Returns:
        ``None``. Validation stops at the first blocking error.
    """
    if runtime_mode == "none":
        return

    if engine_path is None:
        result.add(
            "error",
            "missing_engine_path",
            "Runtime validation needs tensara/engine. Pass --engine-path or place a sibling tensara clone next to this repo.",
        )
        return

    # torch is imported lazily so structural validation works without it;
    # report a missing install as a diagnostic rather than a traceback.
    try:
        import torch
    except ImportError as exc:
        result.add(
            "error",
            "torch_unavailable",
            f"Runtime validation requires PyTorch, which could not be imported: {exc}",
        )
        return

    if not torch.cuda.is_available():
        result.add("error", "cuda_unavailable", "Runtime validation requires a CUDA-enabled torch environment")
        return

    def_path = problem_dir / "def.py"

    try:
        problem = load_problem_instance(problem_dir, engine_path)
    except Exception as exc:  # noqa: BLE001 - problem code may raise anything
        result.add("error", "import_failed", f"Failed to import problem: {exc}", def_path)
        return

    if getattr(problem, "name", None) != problem_dir.name:
        result.add(
            "warning",
            "name_mismatch",
            f"Problem instance name '{getattr(problem, 'name', None)}' does not match slug '{problem_dir.name}'",
            def_path,
        )

    try:
        sample_cases = normalize_case_collection(problem.generate_sample(), allow_single_dict=True)
    except Exception as exc:  # noqa: BLE001
        result.add("error", "sample_generation_failed", f"generate_sample() failed: {exc}", def_path)
        return

    if len(sample_cases) != 1:
        result.add(
            "error",
            "bad_sample_shape",
            "generate_sample() must return one dict or a one-element list of dicts",
            def_path,
        )
        return

    try:
        test_cases = normalize_case_collection(problem.generate_test_cases(), allow_single_dict=False)
    except Exception as exc:  # noqa: BLE001
        result.add("error", "test_case_generation_failed", f"generate_test_cases() failed: {exc}", def_path)
        return

    if not test_cases:
        result.add("error", "empty_test_cases", "generate_test_cases() returned no usable test cases", def_path)
        return

    # The sample always runs; the mode decides how many generated cases follow.
    selected_cases = [("sample", sample_cases[0])]
    if runtime_mode == "sample":
        pass
    elif runtime_mode == "first":
        selected_cases.append(("test#1", test_cases[0]))
    else:
        selected_cases.extend((f"test#{index}", case) for index, case in enumerate(test_cases, start=1))

    executed_cases: list[str] = []
    for case_name, case in selected_cases:
        if "name" not in case or not isinstance(case["name"], str):
            result.add("error", "case_missing_name", f"{case_name} is missing a string 'name' field", def_path)
            return
        if "create_inputs" not in case or not callable(case["create_inputs"]):
            result.add("error", "case_missing_factory", f"{case_name} is missing callable create_inputs", def_path)
            return

        try:
            raw_inputs = case["create_inputs"]()
            # Accept a tuple, a list, or a single bare input.
            if isinstance(raw_inputs, tuple):
                inputs = raw_inputs
            elif isinstance(raw_inputs, list):
                inputs = tuple(raw_inputs)
            else:
                inputs = (raw_inputs,)
            expected = problem.reference_solution(*inputs)
            correct = clone_output(expected, torch)
            correct_ok, correct_info = problem.verify_result(expected, correct)
        except Exception as exc:  # noqa: BLE001
            result.add("error", "runtime_case_failed", f"{case_name} failed during runtime validation: {exc}", def_path)
            return

        if not correct_ok:
            result.add(
                "error",
                "verifier_rejected_reference",
                f"{case_name} verifier rejected the reference output: {correct_info}",
                def_path,
            )
            return

        # Wrong-answer check is best-effort: a failure to run it is only a
        # warning and does not stop validation of this or later cases.
        try:
            wrong = perturb_output(expected, torch)
            wrong_ok, wrong_info = problem.verify_result(expected, wrong)
            if wrong_ok:
                level = "error" if enforce_wrong_answer_rejection else "warning"
                result.add(
                    level,
                    "verifier_accepted_perturbed_output",
                    f"{case_name} verifier accepted an intentionally perturbed output: {wrong_info}",
                    def_path,
                )
        except Exception as exc:  # noqa: BLE001
            result.add(
                "warning",
                "wrong_answer_check_failed",
                f"{case_name} wrong-answer rejection check could not run: {exc}",
                def_path,
            )

        try:
            flops = problem.get_flops(case)
            if flops is not None and flops <= 0:
                result.add("error", "non_positive_flops", f"{case_name} reported non-positive FLOPs: {flops}", def_path)
                return
        except Exception as exc:  # noqa: BLE001
            result.add("warning", "flops_check_failed", f"{case_name} FLOPs check failed: {exc}", def_path)

        executed_cases.append(case["name"])

    result.runtime = {
        "mode": runtime_mode,
        "engine_path": str(engine_path),
        "executed_cases": executed_cases,
    }
def summarize(results: list[ProblemResult], warnings_as_errors: bool) -> tuple[dict[str, int], bool]:
    """Aggregate diagnostics across all results.

    Args:
        results: Per-problem results.
        warnings_as_errors: When true, any warning makes the run fail.

    Returns:
        ``(summary, ok)`` where ``summary`` counts problems, errors, warnings
        and infos, and ``ok`` is the overall pass/fail verdict.
    """
    levels = [diag.level for result in results for diag in result.diagnostics]
    summary = {
        "problems_checked": len(results),
        "errors": levels.count("error"),
        "warnings": levels.count("warning"),
        "infos": levels.count("info"),
    }
    ok = summary["errors"] == 0 and (not warnings_as_errors or summary["warnings"] == 0)
    return summary, ok


def print_text(results: list[ProblemResult], summary: dict[str, int]) -> None:
    """Print a human-readable report to stdout.

    Info-level diagnostics are counted in the summary line but not listed.
    """
    for result in results:
        status = "ok" if result.ok else "failed"
        print(f"[{status}] {result.slug}")
        for diag in result.diagnostics:
            if diag.level == "info":
                continue
            path_suffix = f" ({diag.path})" if diag.path else ""
            print(f" - {diag.level.upper()} {diag.code}: {diag.message}{path_suffix}")
        if result.runtime:
            print(f" - runtime: mode={result.runtime['mode']} cases={', '.join(result.runtime['executed_cases'])}")
    print()
    print(
        "Summary: "
        f"{summary['problems_checked']} problems, "
        f"{summary['errors']} errors, "
        f"{summary['warnings']} warnings, "
        f"{summary['infos']} infos"
    )


def main() -> int:
    """Command-line entry point.

    Returns:
        ``0`` when validation passed, ``1`` when it failed, and ``2`` when the
        requested targets or the problems directory could not be found (the
        reason is printed to stderr instead of surfacing as a traceback).
    """
    args = parse_args()
    repo_root = args.repo_root.resolve()
    problems_dir = (args.problems_dir or repo_root / "problems").resolve()

    try:
        problem_dirs = discover_problem_dirs(problems_dir, args.targets)
    except (FileNotFoundError, NotADirectoryError) as exc:
        print(f"error: {exc}", file=sys.stderr)
        return 2

    # The engine location does not depend on the problem; resolve it once.
    engine_path = resolve_engine_path(
        repo_root,
        args.engine_path.resolve() if args.engine_path else None,
    )

    results: list[ProblemResult] = []
    for problem_dir in problem_dirs:
        result = analyze_structure(problem_dir)
        validate_runtime(
            repo_root=repo_root,
            problem_dir=problem_dir,
            result=result,
            runtime_mode=args.runtime,
            engine_path=engine_path,
            enforce_wrong_answer_rejection=args.enforce_wrong_answer_rejection,
        )
        results.append(result)

    summary, ok = summarize(results, warnings_as_errors=args.warnings_as_errors)
    payload = {
        "ok": ok,
        "summary": summary,
        "results": [asdict(result) for result in results],
    }

    if args.format == "json":
        print(json.dumps(payload, indent=2))
    else:
        print_text(results, summary)

    return 0 if ok else 1
class template_problem(Problem):
    """Starting point for a new problem: copy, rename, and replace the maths.

    The class name is the directory slug with ``-`` replaced by ``_`` and the
    ``name`` passed to ``Problem.__init__`` is the slug itself.
    """

    is_exact = False

    # Kernel signature, in argument order.
    parameters = [
        {"name": "input", "type": "float", "pointer": True, "const": True},
        {"name": "output", "type": "float", "pointer": True, "const": False},
        {"name": "n", "type": "size_t", "pointer": False, "const": False},
    ]

    def __init__(self):
        super().__init__(name="template-problem")

    def reference_solution(self, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return the expected output for *input_tensor* (identity placeholder)."""
        # Gradients and autocast are disabled so the reference runs in the
        # input's own dtype.
        with torch.no_grad(), torch.autocast("cuda", enabled=False, dtype=input_tensor.dtype):
            return input_tensor

    def generate_test_cases(self) -> List[Dict[str, Any]]:
        """Build the generated test cases, one per input size."""
        # NOTE(review): param_dtype/get_seed come from the Problem base class,
        # which is not visible here; presumably they map the declared "float"
        # type to a torch dtype and derive a stable seed from the string.
        dtype = self.param_dtype("input")
        test_cases = []
        for size in (256, 1024, 4096):
            seed = Problem.get_seed(f"{self.name}_{size}")
            test_cases.append(
                {
                    "name": f"n={size}",
                    "n": size,
                    "seed": seed,
                    # Loop variables are bound as lambda defaults so each case
                    # keeps its own values instead of the last iteration's.
                    "create_inputs": lambda size=size, seed=seed, dtype=dtype: (
                        torch.randn(
                            (size,),
                            device="cuda",
                            dtype=dtype,
                            generator=torch.Generator(device="cuda").manual_seed(seed),
                        ),
                    ),
                }
            )
        return test_cases

    def generate_sample(self) -> Dict[str, Any]:
        """Return one small, fixed case with the same shape as a test case."""
        dtype = self.param_dtype("input")
        return {
            "name": "n=8",
            "n": 8,
            "create_inputs": lambda dtype=dtype: (
                torch.tensor([1, -2, 3, -4, 5, -6, 7, -8], device="cuda", dtype=dtype),
            ),
        }

    def verify_result(
        self, expected_output: torch.Tensor, actual_output: torch.Tensor
    ) -> Tuple[bool, Dict[str, Any]]:
        """Compare *actual_output* against *expected_output*.

        Returns:
            ``(True, {})`` on a match, otherwise ``(False, debug_info)`` with
            either the mismatching shapes or the max/mean absolute difference.
        """
        # torch.allclose broadcasts compatible shapes and raises on
        # incompatible ones; report a shape mismatch explicitly instead.
        if actual_output.shape != expected_output.shape:
            return False, {
                "message": "Output shape mismatch",
                "expected_shape": list(expected_output.shape),
                "actual_shape": list(actual_output.shape),
            }

        is_close = torch.allclose(actual_output, expected_output, rtol=1e-5, atol=1e-6)
        if is_close:
            return True, {}

        diff = actual_output - expected_output
        return False, {
            "max_difference": torch.max(torch.abs(diff)).item(),
            "mean_difference": torch.mean(torch.abs(diff)).item(),
        }

    def get_flops(self, test_case: Dict[str, Any]) -> int:
        """Return the FLOP count for *test_case* (placeholder: its ``n``)."""
        return test_case["n"]

    def get_extra_params(self, test_case: Dict[str, Any]) -> List[Any]:
        """Return the non-tensor arguments for *test_case* (here just ``n``)."""
        return [test_case["n"]]