diff --git a/tester/operator_compare/README.md b/tester/operator_compare/README.md new file mode 100644 index 00000000..3cab0b68 --- /dev/null +++ b/tester/operator_compare/README.md @@ -0,0 +1,286 @@ +# Operator Compare + +`tester/operator_compare` 复用 PaddleAPITest config、`TensorConfig` 和 Paddle→Torch converter,对指定算子 config 进行 Paddle / Torch / custom 多实现数值对比,并生成结构化结果与报告。 + +## 1. 使用场景 + +用于在开发、调试或扩展算子时,对单个或少量 API config 做多实现数值对比、误差分析和报告归档。 + +## 2. 功能简介 + +### 2.1 执行流程 + +```text +PaddleAPITest config txt + │ + ▼ +APIConfig / TensorConfig 解析输入 + │ + ▼ +CompareSuite 展开 implementation × dtype × precision + │ + ▼ +runner 执行 paddle / torch / custom 实现 + │ + ▼ +metrics 计算 standard 与 pairwise 误差 + │ + ▼ +artifacts / report 写出 JSON、CSV、Markdown 和 figures +``` + +### 2.2 核心能力 + +- **复用 config**:直接读取 PaddleAPITest config 行或 config txt 文件。 +- **复用输入生成**:通过 `APIConfig` / `TensorConfig` 构造 Paddle 和 Torch 共用输入。 +- **复用 Torch reference**:通过 `Paddle2TorchConverter` 将 Paddle API 映射到 Torch 实现。 +- **多实现对比**:支持 `paddle`、`torch` 和注册的 custom implementation。 +- **矩阵展开**:支持按 implementation、dtype、precision 展开比较。 +- **统一输出**:生成 `results.json`、`summary.csv`、pairwise CSV、`report.md` 和可选图示。 + +### 2.3 目录结构 + +```text +tester/operator_compare/ +├── spec.py # CompareCase / ImplementationSpec / CompareSuite +├── config_loader.py # config txt -> CompareCase +├── implementations.py # paddle/torch/custom 实现声明、参数绑定、dtype matrix 展开 +├── runner.py # 多实现执行、standard 和 pairwise metrics 调度 +├── metrics.py # 误差指标和 tensor fingerprint +├── artifacts.py # env/results/summary/pairwise 写出 +├── report.py # Markdown 报告和图示 +└── profile.py # profile 状态和 sqlite kernel summary 解析 +``` + +工具入口: + +```text +tools/operator_compare.py +tools/operator_compare_profile_worker.py +``` + +## 3. 开发指南 + +### 3.1 优先方式:只增加 config + +适用于: + +- Paddle API 能直接执行。 +- Torch reference 已在 `tester/paddle_to_torch/mapping.json` / `rules.py` 中支持。 +- 输出可以由通用输出处理转换为 `torch.Tensor`。 + +示例: + +```bash +python tools/operator_compare.py \ + --case 'paddle.add(Tensor([2], "float32"), Tensor([2], "float32"), )' \ + --implementations paddle,torch \ + --standard 'torch|config|default' \ + --output-dir test_log_operator_compare/add_smoke +``` + +批量覆盖 shape / dtype 时,把多条 config 写入 txt,再通过 `--config-file` 运行。 + +### 3.2 Torch reference 不存在:补充 converter rule + +如果 `Paddle2TorchConverter.convert(api_name)` 不支持目标 API,需要补充 Paddle→Torch 映射。 + +简单映射优先在 `tester/paddle_to_torch/mapping.json` 中使用 `GenericRule` 字段: + +```json +"paddle.add": { + "torch_api": "torch.add", + "paddle_torch_args_map": { + "x": "input", + "y": "other" + }, + "torch_kwargs": { + "alpha": 1 + } +} +``` + +如果需要特殊计算逻辑,在 `tester/paddle_to_torch/rules.py` 增加专用 rule,并在 `mapping.json` 中配置: + +```json +"paddle.xxx": { + "Rule": "YourRule" +} +``` + +### 3.3 `_C_ops` 无签名算子:补充参数名绑定 + +部分 `_C_ops` 没有 Python signature,而 converter rule 可能通过 `locals().get("x")` 等命名变量取参数。此时在 `tester/operator_compare/implementations.py` 的 `MANUAL_ARGUMENT_NAMES` 中补充位置参数名: + +```python +MANUAL_ARGUMENT_NAMES: dict[str, tuple[str, ...]] = { + "paddle._C_ops.fused_linear_param_grad_add": ( + "x", + "dout", + "dweight", + "dbias", + "multi_precision", + "has_bias", + ), +} +``` + +要求: + +1. key 使用完整 API name。 +2. tuple 顺序与 config 中位置参数顺序一致。 +3. 名称与 converter rule 中读取的 locals 名称一致。 +4. 仅在 `inspect.signature()` 无法绑定时增加手工映射。 + +### 3.4 通用实现无法表达:注册 custom implementation + +当 `paddle` / `torch` 通用路径无法表达某个实验实现时,在 `tester/operator_compare/implementations.py` 中注册 custom runner: + +```python +from tester.operator_compare.implementations import register_custom_implementation + + +def my_runner(case, spec): + api_config = case.tensors["api_config"] + ... + return torch_tensor + + +register_custom_implementation("my_impl", my_runner) +``` + +runner 接口: + +```python +def runner(case: CompareCase, spec: ImplementationSpec) -> torch.Tensor: + ... +``` + +要求: + +- 返回值必须是 `torch.Tensor`。 +- 尽量复用 `case.tensors["api_config"]` 中的 `APIConfig` / `TensorConfig`。 +- custom 名称通过 `--implementations paddle,torch,my_impl` 使用。 + +### 3.5 特殊输出处理 + +通用输出处理在 `to_torch_tensor()` 中完成: + +- `torch.Tensor`:直接返回。 +- `paddle.Tensor`:通过 DLPack 转成 Torch tensor。 +- `list` / `tuple`:取第一个 tensor 输出。 + +如果新算子的有效输出不是第一个 tensor,优先在 `implementations.py` 中集中补充输出选择逻辑,并增加对应测试。 + +## 4. 测试指南 + +新增算子时建议覆盖: + +1. config 能被 `APIConfig` 正确解析。 +2. `build_compare_suite()` 能生成预期 implementation id。 +3. Paddle / Torch 两个实现都能运行成功。 +4. target 相对 standard 的误差符合预期。 +5. 如果新增 `_C_ops` 参数绑定,使用对应 `_C_ops` config 做 CLI smoke。 +6. 如果新增 custom implementation,使用包含 custom implementation id 的 CLI smoke 验证。 + +语法检查: + +```bash +python -m py_compile \ + tools/operator_compare.py \ + tools/operator_compare_profile_worker.py \ + tester/operator_compare/__init__.py \ + tester/operator_compare/artifacts.py \ + tester/operator_compare/config_loader.py \ + tester/operator_compare/implementations.py \ + tester/operator_compare/metrics.py \ + tester/operator_compare/profile.py \ + tester/operator_compare/report.py \ + tester/operator_compare/runner.py \ + tester/operator_compare/spec.py +``` + +真实执行 smoke 示例见下文 `add smoke` 和 `fused_linear_param_grad_add smoke`。 + +## 5. 使用示例 + +### 5.1 单条 config + +```bash +python tools/operator_compare.py \ + --case 'paddle.Tensor.__abs__(Tensor([], "float32"), )' \ + --implementations paddle,torch \ + --standard 'torch|config|default' \ + --output-dir test_log_operator_compare/abs_smoke +``` + +### 5.2 多条 config 文件 + +```bash +python tools/operator_compare.py \ + --config-file tester/api_config/5_accuracy/example.txt \ + --op paddle.Tensor.__abs__ \ + --implementations paddle,torch \ + --output-dir test_log_operator_compare/abs_file +``` + +### 5.3 指定 dtype matrix + +```bash +python tools/operator_compare.py \ + --case 'paddle.Tensor.__abs__(Tensor([2], "float32"), )' \ + --implementations paddle,torch \ + --dtypes float32,float64 \ + --standard 'torch|float32|default' +``` + +### 5.4 `add` smoke + +```bash +python tools/operator_compare.py \ + --case 'paddle.add(Tensor([2], "float32"), Tensor([2], "float32"), )' \ + --implementations paddle,torch \ + --standard 'torch|config|default' \ + --output-dir test_log_operator_compare/add_smoke_flat +``` + +### 5.5 `fused_linear_param_grad_add` smoke + +```bash +python tools/operator_compare.py \ + --case 'paddle._C_ops.fused_linear_param_grad_add(Tensor([8, 4], "float32"), Tensor([8, 4], "float32"), Tensor([4, 4], "float32"), None, False, False, )' \ + --implementations paddle,torch \ + --standard 'torch|config|default' \ + --output-dir test_log_operator_compare/fused_linear_param_grad_add_smoke_flat +``` + +`--standard` 中包含 `|` 时需要加引号。 + +## 6. 参数与输出 + +### 6.1 常用参数 + +| 参数 | 说明 | +| --- | --- | +| `--case` | 单条 PaddleAPITest config 字符串。 | +| `--config-file` | PaddleAPITest config txt 文件。 | +| `--op` | 可选 API name 过滤,用于 config 文件。 | +| `--implementations` | 逗号分隔实现,默认 `paddle,torch`。 | +| `--dtypes` | 逗号分隔 dtype 覆盖矩阵;不指定时使用 config 原 dtype。 | +| `--precisions` | 逗号分隔精度策略名,默认 `default`。 | +| `--standard` | 标准实现 id,例如 `torch|config|default`。 | +| `--metrics-dtype` | 指标计算 dtype,支持 `fp32` / `fp64`。 | +| `--no-fingerprint` | 关闭输出 tensor SHA256 fingerprint。 | +| `--output-dir` | 输出目录。 | + +### 6.2 输出文件 + +```text +env.json # Python/Paddle/Torch/CUDA/env/config 信息 +results.json # 完整结构化结果 +summary.csv # 每个实现相对 standard 的误差摘要 +pairwise_summary.csv # target vs reference pairwise 误差 +reference_pairwise_summary.csv # reference 内部 pairwise 误差 +report.md # Markdown 报告 +figures/ # 图示,依赖 matplotlib/numpy +``` diff --git a/tester/operator_compare/__init__.py b/tester/operator_compare/__init__.py new file mode 100644 index 00000000..9d48db4f --- /dev/null +++ b/tester/operator_compare/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tester/operator_compare/artifacts.py b/tester/operator_compare/artifacts.py new file mode 100644 index 00000000..adeddba4 --- /dev/null +++ b/tester/operator_compare/artifacts.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +import csv +import importlib.metadata +import json +import os +import pathlib +import platform +import sys +from dataclasses import asdict +from datetime import datetime +from typing import Any + +from .spec import CompareSuite, ImplementationResult, PairwiseResult + + +def timestamped_output_dir(root: pathlib.Path, op_name: str) -> pathlib.Path: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + base = root / op_name / timestamp + out_dir = base + suffix = 1 + while out_dir.exists(): + out_dir = pathlib.Path(f"{base}_{suffix:02d}") + suffix += 1 + out_dir.mkdir(parents=True) + return out_dir + + +def metric_dict(metric) -> dict[str, Any] | None: + return None if metric is None else asdict(metric) + + +def package_version(name: str) -> str: + try: + return importlib.metadata.version(name) + except importlib.metadata.PackageNotFoundError: + return "unavailable" + + +def collect_env_info(suite: CompareSuite) -> dict[str, Any]: + info: dict[str, Any] = { + "python_executable": sys.executable, + "python_version": sys.version.replace("\n", " "), + "platform": platform.platform(), + "numpy_version": package_version("numpy"), + "transformer_engine_version": package_version("transformer-engine"), + "transformer_engine_torch_version": package_version("transformer-engine-torch"), + "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"), + "LD_PRELOAD": os.environ.get("LD_PRELOAD"), + "TE_CUBLASLT_PRELOAD": os.environ.get("TE_CUBLASLT_PRELOAD"), + "NVIDIA_TF32_OVERRIDE": os.environ.get("NVIDIA_TF32_OVERRIDE"), + "config": { + "op_name": suite.op_name, + "standard_id": suite.standard_id, + "metrics_dtype": suite.metrics_dtype, + "reference_pairwise_metrics_dtype": suite.reference_pairwise_metrics_dtype, + "enable_fingerprint": suite.enable_fingerprint, + **suite.metadata, + }, + } + try: + import paddle + + info["paddle_version"] = paddle.__version__ + except Exception as err: + info["paddle_error"] = f"{type(err).__name__}: {err}" + try: + import torch + + info.update( + { + "torch_version": torch.__version__, + "torch_cuda_version": torch.version.cuda, + "torch_cudnn_version": torch.backends.cudnn.version(), + "torch_allow_tf32": torch.backends.cuda.matmul.allow_tf32, + } + ) + try: + info["torch_float32_matmul_precision"] = torch.get_float32_matmul_precision() + except Exception as err: + info["torch_float32_matmul_precision_error"] = f"{type(err).__name__}: {err}" + if torch.cuda.is_available(): + info["gpu_name"] = torch.cuda.get_device_name(0) + info["gpu_capability"] = list(torch.cuda.get_device_capability(0)) + except Exception as err: + info["torch_error"] = f"{type(err).__name__}: {err}" + return info + + +def metadata_value(result: ImplementationResult, key: str, default: Any = None) -> Any: + return result.metadata.get(key, default) + + +def case_metadata_value(result: ImplementationResult, key: str, default: Any = None) -> Any: + return result.metadata.get("case_metadata", {}).get(key, default) + + +def result_to_dict(result: ImplementationResult, include_tensor: bool = False) -> dict[str, Any]: + data = { + "case_id": result.case_id, + "id": result.spec.id, + "display_name": result.spec.display_name, + "group": result.spec.group, + "dtype": result.spec.dtype, + "multi_precision": result.spec.multi_precision, + "status": result.status, + "output_dtype": result.output_dtype, + "metrics_vs_standard": metric_dict(result.metrics_vs_standard), + "error": result.error, + "metadata": result.metadata, + } + if include_tensor and result.tensor is not None: + data["tensor_shape"] = list(result.tensor.shape) + return data + + +def pairwise_to_dict(pairwise: PairwiseResult) -> dict[str, Any]: + return { + "case_id": pairwise.case_id, + "actual_id": pairwise.actual.spec.id, + "actual_display_name": pairwise.actual.spec.display_name, + "actual_group": pairwise.actual.spec.group, + "actual_dtype": pairwise.actual.spec.dtype, + "actual_multi_precision": pairwise.actual.spec.multi_precision, + "actual_output_dtype": pairwise.actual.output_dtype, + "expect_id": pairwise.expect.spec.id, + "expect_display_name": pairwise.expect.spec.display_name, + "expect_group": pairwise.expect.spec.group, + "expect_dtype": pairwise.expect.spec.dtype, + "expect_multi_precision": pairwise.expect.spec.multi_precision, + "expect_output_dtype": pairwise.expect.output_dtype, + "metrics": metric_dict(pairwise.metrics), + } + + +def write_json(path: pathlib.Path, data: Any) -> None: + path.write_text( + json.dumps(data, indent=2, ensure_ascii=False, sort_keys=True) + "\n", encoding="utf-8" + ) + + +def write_summary_csv( + path: pathlib.Path, results: list[ImplementationResult], suite: CompareSuite +) -> None: + case_columns = suite.report_config.get("summary_case_metadata_columns", ["m", "k", "n"]) + implementation_columns = suite.report_config.get( + "summary_implementation_metadata_columns", + ["category", "implementation", "input_dtype", "dweight_dtype", "output_fingerprint"], + ) + columns = [ + "case_id", + *case_columns, + "id", + "display_name", + "group", + *implementation_columns, + "dtype", + "multi_precision", + "status", + "output_dtype", + "max_abs", + "mean_abs", + "rmse", + "p99_abs", + "max_rel", + "mean_rel", + "p99_rel", + "max_abs_idx", + "max_rel_idx", + "actual_at_max_abs", + "expect_at_max_abs", + "error", + ] + with path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=columns) + writer.writeheader() + for result in results: + metric = metric_dict(result.metrics_vs_standard) or {} + row = { + "case_id": result.case_id, + "id": result.spec.id, + "display_name": result.spec.display_name, + "group": result.spec.group, + "dtype": result.spec.dtype, + "multi_precision": result.spec.multi_precision, + "status": result.status, + "output_dtype": result.output_dtype, + "max_abs": metric.get("max_abs"), + "mean_abs": metric.get("mean_abs"), + "rmse": metric.get("rmse"), + "p99_abs": metric.get("p99_abs"), + "max_rel": metric.get("max_rel"), + "mean_rel": metric.get("mean_rel"), + "p99_rel": metric.get("p99_rel"), + "max_abs_idx": ";".join(str(i) for i in metric.get("max_abs_idx", [])), + "max_rel_idx": ";".join(str(i) for i in metric.get("max_rel_idx", [])), + "actual_at_max_abs": metric.get("actual_at_max_abs"), + "expect_at_max_abs": metric.get("expect_at_max_abs"), + "error": result.error, + } + row.update({key: case_metadata_value(result, key) for key in case_columns}) + row.update({key: metadata_value(result, key) for key in implementation_columns}) + writer.writerow(row) + + +def write_pairwise_csv(path: pathlib.Path, pairwise_results: list[PairwiseResult]) -> None: + columns = [ + "case_id", + "actual_id", + "actual_display_name", + "actual_group", + "actual_dtype", + "actual_multi_precision", + "actual_output_dtype", + "expect_id", + "expect_display_name", + "expect_group", + "expect_dtype", + "expect_multi_precision", + "expect_output_dtype", + "max_abs", + "mean_abs", + "rmse", + "p99_abs", + "max_rel", + "mean_rel", + "p99_rel", + "max_abs_idx", + "actual_at_max_abs", + "expect_at_max_abs", + ] + with path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=columns) + writer.writeheader() + for pairwise in pairwise_results: + metric = asdict(pairwise.metrics) + writer.writerow( + { + **{ + key: value + for key, value in pairwise_to_dict(pairwise).items() + if key != "metrics" + }, + "max_abs": metric["max_abs"], + "mean_abs": metric["mean_abs"], + "rmse": metric["rmse"], + "p99_abs": metric["p99_abs"], + "max_rel": metric["max_rel"], + "mean_rel": metric["mean_rel"], + "p99_rel": metric["p99_rel"], + "max_abs_idx": ";".join(str(i) for i in metric["max_abs_idx"]), + "actual_at_max_abs": metric["actual_at_max_abs"], + "expect_at_max_abs": metric["expect_at_max_abs"], + } + ) + + +def write_artifacts(out_dir: pathlib.Path, run_data: dict[str, Any]) -> dict[str, pathlib.Path]: + suite: CompareSuite = run_data["suite"] + results: list[ImplementationResult] = run_data["results"] + pairwise_results: list[PairwiseResult] = run_data["pairwise_results"] + reference_pairwise_results: list[PairwiseResult] = run_data["reference_pairwise_results"] + + paths = { + "env": out_dir / "env.json", + "results": out_dir / "results.json", + "summary": out_dir / "summary.csv", + "pairwise": out_dir / "pairwise_summary.csv", + "reference_pairwise": out_dir / "reference_pairwise_summary.csv", + } + write_json(paths["env"], collect_env_info(suite)) + write_json( + paths["results"], + { + "results": [result_to_dict(result) for result in results], + "pairwise_results": [pairwise_to_dict(item) for item in pairwise_results], + "reference_pairwise_results": [ + pairwise_to_dict(item) for item in reference_pairwise_results + ], + "profile": run_data.get("profile"), + }, + ) + write_summary_csv(paths["summary"], results, suite) + write_pairwise_csv(paths["pairwise"], pairwise_results) + write_pairwise_csv(paths["reference_pairwise"], reference_pairwise_results) + return paths diff --git a/tester/operator_compare/config_loader.py b/tester/operator_compare/config_loader.py new file mode 100644 index 00000000..052dc977 --- /dev/null +++ b/tester/operator_compare/config_loader.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from collections.abc import Iterable +from pathlib import Path + +from tester.api_config.config_analyzer import APIConfig + +from .spec import CompareCase + + +def case_from_config_line(config_line: str, case_id: str = "case_0") -> CompareCase: + raw_config = config_line.strip() + api_config = APIConfig(raw_config) + return CompareCase( + id=case_id, + shape=(), + tensors={"api_config": api_config}, + metadata={ + "api_name": api_config.api_name, + "raw_config": raw_config, + }, + ) + + +def cases_from_config_lines(config_lines: Iterable[str]) -> list[CompareCase]: + cases = [] + for index, line in enumerate(config_lines): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + cases.append(case_from_config_line(stripped, case_id=f"case_{len(cases)}")) + return cases + + +def cases_from_config_file(config_file: str | Path) -> list[CompareCase]: + path = Path(config_file) + return cases_from_config_lines(path.read_text().splitlines()) diff --git a/tester/operator_compare/implementations.py b/tester/operator_compare/implementations.py new file mode 100644 index 00000000..05c986a7 --- /dev/null +++ b/tester/operator_compare/implementations.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import copy +import inspect +from collections import OrderedDict +from collections.abc import Callable, Iterable +from typing import Any + +import paddle +import torch +from tester.api_config.config_analyzer import APIConfig, TensorConfig +from tester.paddle_to_torch.converter import Paddle2TorchConverter, get_converter + +from .config_loader import cases_from_config_lines +from .spec import CompareCase, CompareSuite, ImplementationSpec + +CustomRunner = Callable[[CompareCase, ImplementationSpec], torch.Tensor] + +CUSTOM_IMPLEMENTATIONS: dict[str, CustomRunner] = {} +MANUAL_ARGUMENT_NAMES: dict[str, tuple[str, ...]] = { + "paddle._C_ops.fused_linear_param_grad_add": ( + "x", + "dout", + "dweight", + "dbias", + "multi_precision", + "has_bias", + ), +} +SIGNATURE_ARGUMENT_NAMES: dict[str, tuple[str, ...]] = {} + + +def register_custom_implementation(name: str, runner: CustomRunner) -> None: + if name in CUSTOM_IMPLEMENTATIONS: + raise ValueError(f"custom implementation already registered: {name}") + CUSTOM_IMPLEMENTATIONS[name] = runner + + +def implementation_id(name: str, dtype: str | None, precision: str) -> str: + dtype_part = dtype or "config" + return f"{name}|{dtype_part}|{precision}" + + +def expand_implementations( + *, + op_name: str, + implementation_names: Iterable[str], + dtypes: Iterable[str | None] | None = None, + precisions: Iterable[str] | None = None, +) -> list[ImplementationSpec]: + dtype_values = list(dtypes or [None]) + precision_values = list(precisions or ["default"]) + specs: list[ImplementationSpec] = [] + for dtype in dtype_values: + for precision in precision_values: + for name in implementation_names: + kind = implementation_kind(name) + specs.append( + ImplementationSpec( + id=implementation_id(name, dtype, precision), + display_name=name, + group="reference" if kind == "torch" else "target", + runner=runner_for_kind(kind, name), + dtype=dtype, + multi_precision=precision != "default", + metadata={ + "kind": kind, + "implementation": name, + "api_name": op_name, + "precision": precision, + }, + ) + ) + return specs + + +def implementation_kind(name: str) -> str: + if name in {"paddle", "torch"}: + return name + return "custom" + + +def runner_for_kind(kind: str, name: str): + if kind == "paddle": + return run_paddle_case + if kind == "torch": + return run_torch_case + try: + custom_runner = CUSTOM_IMPLEMENTATIONS[name] + except KeyError as err: + raise ValueError(f"unknown custom implementation: {name}") from err + + def runner(case: CompareCase) -> torch.Tensor: + return custom_runner(case, current_spec(case)) + + return runner + + +def build_compare_suite( + *, + config_lines: Iterable[str], + implementation_names: Iterable[str], + standard: str, + dtypes: Iterable[str | None] | None = None, + precisions: Iterable[str] | None = None, + metrics_dtype: str = "fp64", + enable_fingerprint: bool = True, +) -> CompareSuite: + cases = cases_from_config_lines(config_lines) + if not cases: + raise ValueError("no operator compare cases loaded") + op_name = cases[0].metadata["api_name"] + if any(case.metadata["api_name"] != op_name for case in cases): + raise ValueError("all cases in one compare suite must use the same api") + implementations = expand_implementations( + op_name=op_name, + implementation_names=implementation_names, + dtypes=dtypes, + precisions=precisions, + ) + return CompareSuite( + op_name=op_name, + cases=cases, + implementations=implementations, + standard_id=standard, + target_groups={"target"}, + reference_groups={"reference"}, + metrics_dtype=metrics_dtype, + enable_fingerprint=enable_fingerprint, + metadata={"source": "config"}, + report_config={ + "title": f"{op_name} 多实现对比报告", + "method_intro": "测试用例来自 PaddleAPITest config。", + "shape_metadata_keys": ["api_name"], + "shape_label_prefix": "API", + "summary_case_metadata_columns": ["api_name", "raw_config"], + "summary_implementation_metadata_columns": [ + "kind", + "implementation", + "precision", + "output_fingerprint", + ], + }, + ) + + +def clone_api_config(api_config: APIConfig, dtype: str | None) -> APIConfig: + seeded = seeded_api_config(api_config, dtype) + cloned = copy.deepcopy(seeded) + prepare_tensor_configs(cloned.args, dtype) + prepare_tensor_configs(cloned.kwargs.values(), dtype) + copy_numpy_tensors(seeded.args, cloned.args) + copy_numpy_tensors(seeded.kwargs.values(), cloned.kwargs.values()) + return cloned + + +def seeded_api_config(api_config: APIConfig, dtype: str | None) -> APIConfig: + cache = getattr(api_config, "_operator_compare_seeded_configs", None) + if cache is None: + cache = {} + setattr(api_config, "_operator_compare_seeded_configs", cache) + cache_key = dtype or "config" + if cache_key not in cache: + seeded = copy.deepcopy(api_config) + prepare_tensor_configs(seeded.args, dtype) + prepare_tensor_configs(seeded.kwargs.values(), dtype) + materialize_numpy_tensors(seeded.args, seeded) + materialize_numpy_tensors(seeded.kwargs.values(), seeded) + cache[cache_key] = seeded + return cache[cache_key] + + +def prepare_tensor_configs(values: Iterable[Any], dtype: str | None) -> None: + for value in values: + if isinstance(value, TensorConfig): + if dtype is not None: + value.dtype = dtype + value.numpy_tensor = None + value.paddle_tensor = None + value.torch_tensor = None + if not hasattr(value, "shuffle_dims"): + value.shuffle_dims = None + elif isinstance(value, (list, tuple)): + prepare_tensor_configs(value, dtype) + + +def materialize_numpy_tensors(values: Iterable[Any], api_config: APIConfig) -> None: + for value in values: + if isinstance(value, TensorConfig): + value.get_numpy_tensor(api_config) + elif isinstance(value, (list, tuple)): + materialize_numpy_tensors(value, api_config) + + +def copy_numpy_tensors(source_values: Iterable[Any], target_values: Iterable[Any]) -> None: + for source, target in zip(source_values, target_values): + if isinstance(source, TensorConfig) and isinstance(target, TensorConfig): + target.numpy_tensor = copy_tensor_value(source.numpy_tensor) + elif isinstance(source, (list, tuple)) and isinstance(target, (list, tuple)): + copy_numpy_tensors(source, target) + + +def copy_tensor_value(value: Any) -> Any: + if hasattr(value, "copy"): + return value.copy() + return copy.deepcopy(value) + + +def run_paddle_case(case: CompareCase) -> torch.Tensor: + spec = current_spec(case) + api_config = clone_api_config(case.tensors["api_config"], spec.dtype) + args = materialize_paddle_args(api_config.args, api_config) + kwargs = OrderedDict( + (key, materialize_paddle_value(value, api_config)) + for key, value in api_config.kwargs.items() + ) + output = eval(api_config.api_name)(*args, **kwargs) + return to_torch_tensor(output) + + +def run_torch_case(case: CompareCase) -> torch.Tensor: + spec = current_spec(case) + api_config = clone_api_config(case.tensors["api_config"], spec.dtype) + convert_result = get_converter().convert(api_config.api_name) + if not convert_result.is_supported: + raise RuntimeError( + convert_result.error_message or f"unsupported torch mapping: {api_config.api_name}" + ) + torch_args = materialize_torch_args(api_config.args, api_config) + torch_kwargs = bind_torch_kwargs(api_config, torch_args) + output = Paddle2TorchConverter.execute(convert_result, torch_args, torch_kwargs) + return to_torch_tensor(output) + + +def current_spec(case: CompareCase) -> ImplementationSpec: + try: + return case.tensors["_current_spec"] + except KeyError as err: + raise RuntimeError("operator compare runner missing current implementation spec") from err + + +def bind_torch_kwargs(api_config: APIConfig, torch_args: list[Any]) -> OrderedDict[str, Any]: + named_values = bind_api_arguments(api_config, torch_args) + named_values.update( + (key, materialize_torch_value(value, api_config)) + for key, value in api_config.kwargs.items() + ) + return OrderedDict((key, value) for key, value in named_values.items() if value is not None) + + +def bind_api_arguments(api_config: APIConfig, torch_args: list[Any]) -> OrderedDict[str, Any]: + argument_names = argument_names_for_api(api_config.api_name) + return OrderedDict( + (name, torch_args[index]) + for index, name in enumerate(argument_names) + if index < len(torch_args) + ) + + +def argument_names_for_api(api_name: str) -> tuple[str, ...]: + if api_name in MANUAL_ARGUMENT_NAMES: + return MANUAL_ARGUMENT_NAMES[api_name] + if api_name not in SIGNATURE_ARGUMENT_NAMES: + SIGNATURE_ARGUMENT_NAMES[api_name] = signature_argument_names(api_name) + return SIGNATURE_ARGUMENT_NAMES[api_name] + + +def signature_argument_names(api_name: str) -> tuple[str, ...]: + try: + signature = inspect.signature(eval(api_name)) + except (TypeError, ValueError): + return () + return tuple( + name + for name, parameter in signature.parameters.items() + if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD) + ) + + +def materialize_paddle_args(values: Iterable[Any], api_config: APIConfig) -> list[Any]: + return [materialize_paddle_value(value, api_config) for value in values] + + +def materialize_paddle_value(value: Any, api_config: APIConfig) -> Any: + if isinstance(value, TensorConfig): + tensor = value.get_paddle_tensor(api_config) + value.clear_paddle_tensor() + return tensor + if isinstance(value, list): + return [materialize_paddle_value(item, api_config) for item in value] + if isinstance(value, tuple): + return tuple(materialize_paddle_value(item, api_config) for item in value) + return value + + +def materialize_torch_args(values: Iterable[Any], api_config: APIConfig) -> list[Any]: + return [materialize_torch_value(value, api_config) for value in values] + + +def materialize_torch_value(value: Any, api_config: APIConfig) -> Any: + if isinstance(value, TensorConfig): + tensor = value.get_torch_tensor(api_config) + value.clear_torch_tensor() + return tensor + if isinstance(value, list): + return [materialize_torch_value(item, api_config) for item in value] + if isinstance(value, tuple): + return tuple(materialize_torch_value(item, api_config) for item in value) + return value + + +def to_torch_tensor(output: Any) -> torch.Tensor: + if isinstance(output, torch.Tensor): + return output + if isinstance(output, paddle.Tensor): + return torch.utils.dlpack.from_dlpack(paddle.utils.dlpack.to_dlpack(output)) + if isinstance(output, (list, tuple)): + for item in output: + if isinstance(item, (paddle.Tensor, torch.Tensor)): + return to_torch_tensor(item) + raise TypeError(f"output {type(output).__name__} contains no tensor") + raise TypeError(f"output type {type(output).__name__} is not supported") diff --git a/tester/operator_compare/metrics.py b/tester/operator_compare/metrics.py new file mode 100644 index 00000000..327ecbd7 --- /dev/null +++ b/tester/operator_compare/metrics.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import hashlib +from typing import Any + +import torch + +METRICS_DTYPES = { + "fp32": torch.float32, + "fp64": torch.float64, +} + + +def tensor_fingerprint(tensor: torch.Tensor) -> str: + byte_tensor = tensor.detach().contiguous().view(torch.uint8) + digest = hashlib.sha256(byte_tensor.cpu().numpy().tobytes()).hexdigest() + return f"sha256:{digest}" + + +def unravel_index(flat_index: int, shape: tuple[int, ...]) -> list[int]: + indices: list[int] = [] + remainder = flat_index + for size in reversed(shape): + indices.append(remainder % size) + remainder //= size + return list(reversed(indices)) + + +def tensor_value(tensor: torch.Tensor, indices: list[int]) -> float: + value: Any = tensor + for index in indices: + value = value[index] + return float(value.item()) + + +def compute_metrics( + actual: torch.Tensor, + expect: torch.Tensor, + eps: float = 1e-12, + metrics_dtype: str = "fp32", +) -> dict[str, Any]: + if actual.shape != expect.shape: + raise ValueError( + f"shape mismatch: actual={tuple(actual.shape)}, expect={tuple(expect.shape)}" + ) + if metrics_dtype not in METRICS_DTYPES: + raise ValueError( + f"metrics_dtype must be one of {sorted(METRICS_DTYPES)}, got {metrics_dtype!r}" + ) + + dtype = METRICS_DTYPES[metrics_dtype] + actual_compute = actual.to(dtype) + expect_compute = expect.to(dtype) + diff = actual_compute - expect_compute + abs_diff = diff.abs() + rel_diff = abs_diff / torch.clamp(expect_compute.abs(), min=eps) + + flat_abs_idx = int(abs_diff.reshape(-1).argmax().item()) + flat_rel_idx = int(rel_diff.reshape(-1).argmax().item()) + abs_idx = unravel_index(flat_abs_idx, tuple(abs_diff.shape)) + rel_idx = unravel_index(flat_rel_idx, tuple(rel_diff.shape)) + + return { + "max_abs": float(abs_diff.max().item()), + "mean_abs": float(abs_diff.mean().item()), + "rmse": float(torch.sqrt((diff * diff).mean()).item()), + "p99_abs": float(torch.quantile(abs_diff.reshape(-1), 0.99).item()), + "max_rel": float(rel_diff.max().item()), + "mean_rel": float(rel_diff.mean().item()), + "p99_rel": float(torch.quantile(rel_diff.reshape(-1), 0.99).item()), + "max_abs_idx": abs_idx, + "max_rel_idx": rel_idx, + "actual_at_max_abs": tensor_value(actual_compute, abs_idx), + "expect_at_max_abs": tensor_value(expect_compute, abs_idx), + } diff --git a/tester/operator_compare/profile.py b/tester/operator_compare/profile.py new file mode 100644 index 00000000..d4b4dde9 --- /dev/null +++ b/tester/operator_compare/profile.py @@ -0,0 +1,290 @@ +from __future__ import annotations + +import csv +import json +import pathlib +import shutil +import sqlite3 +import subprocess +import sys +from typing import Any + +from .spec import CompareSuite + + +def write_json(path: pathlib.Path, data: Any) -> None: + path.write_text( + json.dumps(data, indent=2, ensure_ascii=False, sort_keys=True) + "\n", encoding="utf-8" + ) + + +def tail_text(value: str | bytes | None, limit: int = 4000) -> str: + if value is None: + return "" + if isinstance(value, bytes): + return value.decode("utf-8", errors="replace")[-limit:] + return value[-limit:] + + +def summarize_kernel_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]: + grouped: dict[str, list[int]] = {} + for row in rows: + grouped.setdefault(str(row["kernel_name"]), []).append(int(row.get("duration_ns") or 0)) + summary = [] + for name, durations in sorted(grouped.items(), key=lambda item: (-sum(item[1]), item[0])): + total = sum(durations) + count = len(durations) + summary.append( + { + "kernel_name": name, + "count": count, + "total_time_ns": total, + "mean_time_ns": total / count if count else 0, + "min_time_ns": min(durations) if durations else 0, + "max_time_ns": max(durations) if durations else 0, + } + ) + return summary + + +def string_id_map(conn: sqlite3.Connection) -> dict[int, str]: + tables = {row[0] for row in conn.execute("select name from sqlite_master where type='table'")} + if "StringIds" not in tables: + return {} + columns = [row[1] for row in conn.execute("pragma table_info(StringIds)")] + id_col = "id" if "id" in columns else columns[0] + value_col = "value" if "value" in columns else columns[-1] + result = {} + for key, value in conn.execute(f"select {id_col}, {value_col} from StringIds"): + try: + result[int(key)] = str(value) + except Exception: + continue + return result + + +def parse_nsys_sqlite(sqlite_path: pathlib.Path) -> dict[str, Any]: + conn = sqlite3.connect(str(sqlite_path)) + conn.row_factory = sqlite3.Row + try: + tables = sorted( + row[0] for row in conn.execute("select name from sqlite_master where type='table'") + ) + if "CUPTI_ACTIVITY_KIND_KERNEL" not in tables: + return {"status": "parsed_no_kernel_table", "tables": tables, "kernels": []} + + id_to_string = string_id_map(conn) + columns = [row[1] for row in conn.execute("pragma table_info(CUPTI_ACTIVITY_KIND_KERNEL)")] + name_columns = ["demangledName", "shortName", "name", "mangledName"] + direct_name_col = next((name for name in name_columns if name in columns), None) + id_name_col = next( + ( + name + for name in ["demangledName", "shortName", "name", "mangledName"] + if name in columns + ), + None, + ) + start_col = "start" if "start" in columns else None + end_col = "end" if "end" in columns else None + rows = [] + for row in conn.execute("select * from CUPTI_ACTIVITY_KIND_KERNEL"): + if direct_name_col and isinstance(row[direct_name_col], str): + kernel_name = row[direct_name_col] + elif id_name_col and row[id_name_col] is not None: + name_id = row[id_name_col] + try: + kernel_name = id_to_string.get(int(name_id), str(name_id)) + except (TypeError, ValueError): + kernel_name = str(name_id) + else: + kernel_name = "unknown" + duration = int(row[end_col] - row[start_col]) if start_col and end_col else 0 + rows.append({"kernel_name": kernel_name, "duration_ns": duration}) + return {"status": "parsed", "tables": tables, "kernels": summarize_kernel_rows(rows)} + finally: + conn.close() + + +def profile_worker_command( + worker: pathlib.Path, case_line: str, implementation: str, repeat: int +) -> list[str]: + return [ + sys.executable, + str(worker), + "--case", + case_line, + "--implementation", + implementation, + "--repeat", + str(repeat), + ] + + +def write_kernel_summary_csv(path: pathlib.Path, rows: list[dict[str, Any]]) -> None: + columns = [ + "implementation", + "kernel_name", + "count", + "total_time_ns", + "mean_time_ns", + "min_time_ns", + "max_time_ns", + "source_sqlite", + ] + with path.open("w", newline="", encoding="utf-8") as handle: + writer = csv.DictWriter(handle, fieldnames=columns) + writer.writeheader() + for row in rows: + writer.writerow({key: row.get(key) for key in columns}) + + +def run_kernel_profiles( + out_dir: pathlib.Path, + suite: CompareSuite, + case_line: str, + repo_root: pathlib.Path, + implementations: list[str], + repeat: int, + timeout_seconds: int = 600, +) -> dict[str, Any]: + profile_dir = out_dir / "profile" + profile_dir.mkdir(parents=True, exist_ok=True) + worker = repo_root / "tools" / "operator_compare_profile_worker.py" + nsys = shutil.which("nsys") + status_path = profile_dir / "profile_status.json" + if nsys is None: + example_implementation = implementations[0] if implementations else suite.standard_id + data = { + "status": "skipped", + "reason": "nsys not found in PATH", + "manual_command_template": " ".join( + profile_worker_command(worker, case_line, example_implementation, repeat) + ), + } + write_json(status_path, data) + return data + + all_kernel_rows: list[dict[str, Any]] = [] + runs = [] + for implementation in implementations: + base = profile_dir / f"nsys_{implementation.replace('|', '_')}" + rep_path = pathlib.Path(f"{base}.nsys-rep") + sqlite_path = pathlib.Path(f"{base}.sqlite") + worker_cmd = profile_worker_command(worker, case_line, implementation, repeat) + profile_cmd = [ + nsys, + "profile", + "--force-overwrite=true", + "--trace=cuda,nvtx,cublas,cudnn", + "--sample=none", + "--output", + str(base), + *worker_cmd, + ] + run_info: dict[str, Any] = { + "implementation": implementation, + "profile_command": profile_cmd, + } + try: + completed = subprocess.run( + profile_cmd, + cwd=str(repo_root), + text=True, + capture_output=True, + timeout=timeout_seconds, + ) + except subprocess.TimeoutExpired as err: + run_info.update( + { + "status": "profile_timeout", + "timeout_seconds": timeout_seconds, + "stdout": tail_text(err.stdout), + "stderr": tail_text(err.stderr), + } + ) + runs.append(run_info) + continue + run_info.update( + { + "profile_returncode": completed.returncode, + "stdout": tail_text(completed.stdout), + "stderr": tail_text(completed.stderr), + } + ) + if completed.returncode != 0: + run_info["status"] = "profile_failed" + runs.append(run_info) + continue + export_cmd = [ + nsys, + "export", + "--type", + "sqlite", + "--force-overwrite=true", + "--output", + str(sqlite_path), + str(rep_path), + ] + try: + exported = subprocess.run( + export_cmd, + cwd=str(repo_root), + text=True, + capture_output=True, + timeout=timeout_seconds, + ) + except subprocess.TimeoutExpired as err: + run_info.update( + { + "status": "export_timeout", + "timeout_seconds": timeout_seconds, + "export_stdout": tail_text(err.stdout), + "export_stderr": tail_text(err.stderr), + } + ) + runs.append(run_info) + continue + run_info.update( + { + "export_command": export_cmd, + "export_returncode": exported.returncode, + "export_stdout": tail_text(exported.stdout), + "export_stderr": tail_text(exported.stderr), + } + ) + if exported.returncode != 0 or not sqlite_path.exists(): + run_info["status"] = "export_failed" + runs.append(run_info) + continue + parsed = parse_nsys_sqlite(sqlite_path) + run_info.update( + { + "status": parsed.get("status"), + "sqlite": str(sqlite_path), + "nsys_rep": str(rep_path), + "tables": parsed.get("tables", []), + } + ) + for row in parsed.get("kernels", []): + all_kernel_rows.append( + {**row, "implementation": implementation, "source_sqlite": str(sqlite_path)} + ) + runs.append(run_info) + + successful_statuses = {"parsed", "parsed_no_kernel_table"} + successful_runs = [run for run in runs if run.get("status") in successful_statuses] + if not runs: + status = "skipped" + elif len(successful_runs) == len(runs): + status = "completed" + elif successful_runs: + status = "partial" + else: + status = "failed" + + write_kernel_summary_csv(profile_dir / "kernel_summary.csv", all_kernel_rows) + data = {"status": status, "runs": runs, "kernel_summary": all_kernel_rows} + write_json(profile_dir / "kernel_summary.json", all_kernel_rows) + write_json(status_path, data) + return data diff --git a/tester/operator_compare/report.py b/tester/operator_compare/report.py new file mode 100644 index 00000000..81c46115 --- /dev/null +++ b/tester/operator_compare/report.py @@ -0,0 +1,568 @@ +from __future__ import annotations + +import pathlib +from typing import Any + +from .spec import CompareSuite, ImplementationResult, PairwiseResult + +DEFAULT_LARGE_ERROR_THRESHOLD = 1e-1 +DEFAULT_CATEGORY_ORDER: dict[str, int] = {} +DEFAULT_IMPLEMENTATION_ORDER: dict[str, int] = {} +DEFAULT_DTYPE_ORDER = {"bf16": 0, "fp32": 1, "fp64": 2} +DEFAULT_CATEGORY_COLORS: dict[str, str] = {} + + +def fmt(value: float | str | None) -> str: + if value in (None, ""): + return "-" + number = float(value) + if number == 0: + return "0" + return f"{number:.4e}" + + +def safe_name(text: str) -> str: + return text.replace("/", "_").replace("|", "_").replace(" ", "_").replace(",", "_") + + +def metadata(result: ImplementationResult, key: str, default: Any = None) -> Any: + return result.metadata.get(key, default) + + +def case_metadata(result: ImplementationResult, key: str, default: Any = None) -> Any: + return result.metadata.get("case_metadata", {}).get(key, default) + + +def report_setting(suite: CompareSuite, key: str, default: Any = None) -> Any: + return suite.report_config.get(key, default) + + +def shape_metadata_keys(suite: CompareSuite) -> list[str]: + return list(report_setting(suite, "shape_metadata_keys", ["m", "k", "n"])) + + +def shape_tuple(result: ImplementationResult, suite: CompareSuite | None = None) -> tuple[Any, ...]: + keys = shape_metadata_keys(suite) if suite is not None else ["m", "k", "n"] + return tuple(case_metadata(result, key, 0) for key in keys) + + +def shape_label(shape: tuple[Any, ...]) -> str: + return ",".join(str(v) for v in shape) + + +def shape_file_suffix(suite: CompareSuite, shape: tuple[Any, ...]) -> str: + keys = shape_metadata_keys(suite) + return "_".join(f"{key}{value}" for key, value in zip(keys, shape)) + + +def ok_results(results: list[ImplementationResult]) -> list[ImplementationResult]: + return [ + item for item in results if item.status == "ok" and item.metrics_vs_standard is not None + ] + + +def result_sort_key( + result: ImplementationResult, suite: CompareSuite | None = None +) -> tuple[int, int, int, int, int, str]: + category = str(metadata(result, "category", "")) + implementation = str(metadata(result, "implementation", result.spec.id)) + input_dtype = str(metadata(result, "input_dtype", result.spec.dtype)) + dweight_dtype = str(metadata(result, "dweight_dtype", result.spec.dtype)) + category_order = ( + report_setting(suite, "category_order", DEFAULT_CATEGORY_ORDER) + if suite + else DEFAULT_CATEGORY_ORDER + ) + implementation_order = ( + report_setting(suite, "implementation_order", DEFAULT_IMPLEMENTATION_ORDER) + if suite + else DEFAULT_IMPLEMENTATION_ORDER + ) + dtype_order = ( + report_setting(suite, "dtype_order", DEFAULT_DTYPE_ORDER) if suite else DEFAULT_DTYPE_ORDER + ) + return ( + category_order.get(category, 99), + implementation_order.get(implementation, 99), + dtype_order.get(input_dtype, 99), + dtype_order.get(dweight_dtype, 99), + int(bool(result.spec.multi_precision)), + result.spec.id, + ) + + +def metric_value(result: ImplementationResult, key: str) -> float: + metric = result.metrics_vs_standard + return float(getattr(metric, key)) if metric is not None else 0.0 + + +def maybe_bold(text: str, result: ImplementationResult, suite: CompareSuite | None = None) -> str: + threshold = ( + report_setting(suite, "large_error_threshold", DEFAULT_LARGE_ERROR_THRESHOLD) + if suite + else DEFAULT_LARGE_ERROR_THRESHOLD + ) + return f"**{text}**" if metric_value(result, "max_abs") >= threshold else text + + +def implementation_config(result: ImplementationResult) -> str: + return ( + f"`{metadata(result, 'implementation', result.spec.id)}` " + f"(dtype={metadata(result, 'input_dtype', result.spec.dtype)}, " + f"dweight={metadata(result, 'dweight_dtype', result.spec.dtype)}, " + f"mp{1 if result.spec.multi_precision else 0})" + ) + + +def import_matplotlib(): + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + return plt + + +def positive_floor(values: list[float]) -> float: + positives = [value for value in values if value > 0] + if not positives: + return 1e-12 + return max(min(positives) / 10, 1e-12) + + +def plot_heatmaps( + out_dir: pathlib.Path, suite: CompareSuite, pairwise_results: list[PairwiseResult] +) -> tuple[list[pathlib.Path], str | None]: + try: + plt = import_matplotlib() + import numpy as np + from matplotlib.colors import LogNorm + except Exception as err: + return [], f"{type(err).__name__}: {err}" + + figures_dir = out_dir / "figures" + figures_dir.mkdir(exist_ok=True) + paths: list[pathlib.Path] = [] + case_ids = sorted({item.case_id for item in pairwise_results}) + reference_ids = suite.report_config.get("reference_order") or sorted( + {item.expect.spec.id for item in pairwise_results} + ) + target_ids = suite.report_config.get("target_order") or sorted( + {item.actual.spec.id for item in pairwise_results} + ) + + for case_id in case_ids: + rows = [item for item in pairwise_results if item.case_id == case_id] + matrix = np.zeros((len(target_ids), len(reference_ids)), dtype=float) + for item in rows: + if item.actual.spec.id in target_ids and item.expect.spec.id in reference_ids: + matrix[ + target_ids.index(item.actual.spec.id), reference_ids.index(item.expect.spec.id) + ] = item.metrics.max_abs + fig, ax = plt.subplots( + figsize=(max(8.0, len(reference_ids) * 1.1), max(4.0, len(target_ids) * 0.6)) + ) + positive = matrix[matrix > 0] + image = ( + ax.imshow(matrix, cmap="YlOrRd", norm=LogNorm(vmin=max(positive.min(), 1e-12))) + if positive.size + else ax.imshow(matrix, cmap="YlOrRd") + ) + ax.set_xticks(range(len(reference_ids))) + ax.set_xticklabels(reference_ids, rotation=45, ha="right", fontsize=8) + ax.set_yticks(range(len(target_ids))) + ax.set_yticklabels(target_ids, fontsize=8) + ax.set_title(f"{suite.op_name} pairwise max_abs: {case_id}") + fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04).set_label("max_abs") + fig.tight_layout() + path = figures_dir / f"pairwise_heatmap_{safe_name(case_id)}.png" + fig.savefig(path, dpi=160) + plt.close(fig) + paths.append(path) + return paths, None + + +def plot_vs_standard( + out_dir: pathlib.Path, + suite: CompareSuite, + rows: list[ImplementationResult], + shape: tuple[Any, ...], +) -> pathlib.Path | None: + rows = sorted(ok_results(rows), key=lambda item: result_sort_key(item, suite)) + if not rows: + return None + plt = import_matplotlib() + figures_dir = out_dir / "figures" + figures_dir.mkdir(exist_ok=True) + values = [metric_value(row, "max_abs") for row in rows] + floor = positive_floor(values) + plot_values = [value if value > 0 else floor for value in values] + labels = [ + f"{metadata(row, 'implementation', row.spec.id)}\n{metadata(row, 'input_dtype', row.spec.dtype)}/{metadata(row, 'dweight_dtype', row.spec.dtype)} mp{1 if row.spec.multi_precision else 0}" + for row in rows + ] + category_colors = report_setting(suite, "category_colors", DEFAULT_CATEGORY_COLORS) + colors = [category_colors.get(str(metadata(row, "category", "")), "#7f7f7f") for row in rows] + fig, ax = plt.subplots(figsize=(18.0, max(6.0, 0.82 * len(rows) + 2.5)), dpi=160) + bars = ax.barh( + range(len(labels)), + plot_values, + color=colors, + edgecolor="#333333", + linewidth=0.7, + alpha=0.88, + ) + ax.set_yticks(range(len(labels)), labels=labels, fontsize=8) + ax.set_xscale("log") + threshold = report_setting(suite, "large_error_threshold", DEFAULT_LARGE_ERROR_THRESHOLD) + standard_label = report_setting(suite, "vs_standard_ylabel", "vs standard") + shape_prefix = report_setting(suite, "shape_label_prefix", "shape") + ax.axvline( + threshold, color="#b00020", linestyle="--", linewidth=1.2, label="large error threshold" + ) + ax.set_xlabel(f"max_abs {standard_label}") + ax.set_title(f"All implementations vs standard, {shape_prefix}={shape_label(shape)}") + ax.grid(axis="x", which="both", linestyle="--", linewidth=0.5, alpha=0.45) + xmax = max(plot_values + [threshold]) + for bar, row, value in zip(bars, rows, values): + ax.text( + bar.get_width() * 1.08, + bar.get_y() + bar.get_height() / 2, + fmt(value), + va="center", + fontsize=8, + fontweight="bold" if metric_value(row, "max_abs") >= threshold else "normal", + ) + ax.set_xlim(left=floor / 2, right=max(xmax * 8, threshold * 10)) + ax.legend(loc="lower right", fontsize=8) + fig.subplots_adjust(left=0.34, right=0.985, top=0.93, bottom=0.08) + path = figures_dir / f"vs_standard_overview_{shape_file_suffix(suite, shape)}.png" + fig.savefig(path, bbox_inches="tight", pad_inches=0.04) + plt.close(fig) + return path + + +def plot_reduce_trend( + out_dir: pathlib.Path, + suite: CompareSuite, + results: list[ImplementationResult], + metric_name: str, +) -> pathlib.Path | None: + rows = ok_results(results) + shapes = sorted({shape_tuple(row, suite) for row in rows}) + if len(shapes) < 2: + return None + plt = import_matplotlib() + figures_dir = out_dir / "figures" + figures_dir.mkdir(exist_ok=True) + fig, ax = plt.subplots(figsize=(12.5, 7.0), dpi=160) + groups: dict[str, list[ImplementationResult]] = {} + for row in rows: + trend_implementations = set(report_setting(suite, "reduce_trend_implementations", [])) + if not trend_implementations or metadata(row, "implementation") in trend_implementations: + key = f"{metadata(row, 'implementation')} {metadata(row, 'input_dtype')}/{metadata(row, 'dweight_dtype')} mp{1 if row.spec.multi_precision else 0}" + groups.setdefault(key, []).append(row) + plotted = False + for label, group_rows in groups.items(): + by_shape = {shape_tuple(row, suite): row for row in group_rows} + x_key = report_setting(suite, "reduce_trend_x_metadata_key", shape_metadata_keys(suite)[0]) + points = [ + ( + case_metadata(by_shape[shape], x_key, shape[0]), + metric_value(by_shape[shape], metric_name), + ) + for shape in shapes + if shape in by_shape + ] + if len(points) < 2: + continue + xs = [point[0] for point in points] + ys = [point[1] for point in points] + plot_ys = [value if value > 0 else positive_floor(ys) for value in ys] + is_fused = label.startswith("paddle_fused") + ax.plot( + xs, + plot_ys, + marker="D" if is_fused else "o", + linewidth=3.0 if is_fused else 1.6, + markersize=6.5 if is_fused else 4.2, + label=label, + ) + plotted = True + if not plotted: + plt.close(fig) + return None + ax.set_xscale("log") + ax.set_yscale("log") + ax.set_xlabel(report_setting(suite, "reduce_trend_x_label", "shape")) + ax.set_ylabel(f"{metric_name} {report_setting(suite, 'vs_standard_ylabel', 'vs standard')}") + ax.set_title(f"Reduce trend: M vs {metric_name}") + ax.grid(which="both", linestyle="--", linewidth=0.5, alpha=0.45) + ax.legend(loc="best", fontsize=7) + fig.tight_layout() + path = figures_dir / f"reduce_trend_{metric_name}.png" + fig.savefig(path, bbox_inches="tight", pad_inches=0.04) + plt.close(fig) + return path + + +def write_vs_standard_table( + lines: list[str], suite: CompareSuite, rows: list[ImplementationResult] +) -> None: + lines.append( + "| 类型 | 实现 | input dtype | dweight dtype | multi precision | output dtype | max_abs | rmse | p99_abs | max_rel | p99_rel |" + ) + lines.append("| --- | --- | --- | --- | --- | --- | ---: | ---: | ---: | ---: | ---: |") + for row in sorted(ok_results(rows), key=lambda item: result_sort_key(item, suite)): + metric = row.metrics_vs_standard + assert metric is not None + lines.append( + f"| `{metadata(row, 'category')}` | `{metadata(row, 'implementation', row.spec.id)}` | " + f"`{metadata(row, 'input_dtype', row.spec.dtype)}` | `{metadata(row, 'dweight_dtype', row.spec.dtype)}` | " + f"`{row.spec.multi_precision}` | `{row.output_dtype}` | {maybe_bold(fmt(metric.max_abs), row, suite)} | " + f"{maybe_bold(fmt(metric.rmse), row, suite)} | {maybe_bold(fmt(metric.p99_abs), row, suite)} | " + f"{maybe_bold(fmt(metric.max_rel), row, suite)} | {maybe_bold(fmt(metric.p99_rel), row, suite)} |" + ) + + +def write_bitwise_table( + lines: list[str], + suite: CompareSuite, + fused_rows: list[ImplementationResult], + comparison_rows: list[ImplementationResult], +) -> None: + lines.append("| Paddle fused 配置 | 逐位一致实现 | 逐位不一致实现 |") + lines.append("| --- | --- | --- |") + for fused in sorted(ok_results(fused_rows), key=lambda item: result_sort_key(item, suite)): + fused_fp = metadata(fused, "output_fingerprint") + comparable = [ + row + for row in ok_results(comparison_rows) + if metadata(row, "input_dtype") == metadata(fused, "input_dtype") + and row.spec.multi_precision == fused.spec.multi_precision + ] + identical = [ + implementation_config(row) + for row in comparable + if fused_fp and metadata(row, "output_fingerprint") == fused_fp + ] + different = [ + implementation_config(row) + for row in comparable + if metadata(row, "output_fingerprint") + and metadata(row, "output_fingerprint") != fused_fp + ] + lines.append( + f"| {implementation_config(fused)} | {';'.join(identical) if identical else '-'} | {';'.join(different) if different else '-'} |" + ) + + +def write_profile_section(lines: list[str], profile: dict[str, Any] | None) -> None: + lines.append("## 6. Kernel profile") + lines.append("") + if not profile: + lines.append("未启用 kernel profile。") + return + if profile.get("status") == "skipped": + lines.append(f"未采集 kernel 信息:`{profile.get('reason')}`。") + if profile.get("manual_command_template"): + lines.append("") + lines.append("可手动使用 nsys 包裹如下 workload 命令:") + lines.append("") + lines.append("```bash") + lines.append(str(profile["manual_command_template"])) + lines.append("```") + return + rows = profile.get("kernel_summary") or [] + if not rows: + lines.append( + f"未解析到 kernel summary;profile 状态为 `{profile.get('status')}`。可查看 `profile/profile_status.json` 和原始 nsys/sqlite 产物。" + ) + runs = profile.get("runs") or [] + if runs: + lines.append("") + lines.append("| implementation | status | detail |") + lines.append("| --- | --- | --- |") + for run in runs: + detail = ( + run.get("stderr") + or run.get("export_stderr") + or run.get("stdout") + or run.get("export_stdout") + or "-" + ) + detail = str(detail).replace("|", "\\|").replace("\n", " ")[:240] + lines.append( + f"| `{run.get('implementation')}` | `{run.get('status')}` | {detail} |" + ) + return + lines.append("| implementation | kernel | count | total ns | mean ns |") + lines.append("| --- | --- | ---: | ---: | ---: |") + for row in rows[:30]: + lines.append( + f"| `{row.get('implementation')}` | `{row.get('kernel_name')}` | {row.get('count')} | {fmt(row.get('total_time_ns'))} | {fmt(row.get('mean_time_ns'))} |" + ) + + +def render_report(out_dir: pathlib.Path, run_data: dict[str, Any]) -> pathlib.Path: + suite: CompareSuite = run_data["suite"] + pairwise_results: list[PairwiseResult] = run_data["pairwise_results"] + reference_pairwise_results: list[PairwiseResult] = run_data["reference_pairwise_results"] + results: list[ImplementationResult] = run_data["results"] + + figure_paths, figure_error = plot_heatmaps(out_dir, suite, pairwise_results) + vs_standard_error = None + vs_standard_paths: list[pathlib.Path] = [] + reduce_paths: list[pathlib.Path] = [] + try: + for shape in sorted({shape_tuple(row, suite) for row in results}): + path = plot_vs_standard( + out_dir, suite, [row for row in results if shape_tuple(row, suite) == shape], shape + ) + if path: + vs_standard_paths.append(path) + for metric_name in ["max_abs", "rmse"]: + path = plot_reduce_trend(out_dir, suite, results, metric_name) + if path: + reduce_paths.append(path) + except Exception as err: + vs_standard_error = f"{type(err).__name__}: {err}" + + global_max = ( + max(pairwise_results, key=lambda item: item.metrics.max_abs) if pairwise_results else None + ) + reference_max = ( + max(reference_pairwise_results, key=lambda item: item.metrics.max_abs) + if reference_pairwise_results + else None + ) + + lines = [f"# {report_setting(suite, 'title', suite.op_name + ' 精度对比报告')}", ""] + lines.append("## 1. 测试方法") + lines.append("") + lines.append(report_setting(suite, "method_intro", "测试对象为 operator compare suite。")) + formula = report_setting(suite, "formula") + if formula: + lines.append("") + lines.append("```text") + lines.append(str(formula)) + lines.append("```") + lines.append("") + lines.append(f"- 标准实现:`{suite.standard_id}`。") + lines.append(f"- 指标 dtype:`{suite.metrics_dtype}`。") + lines.append("- TF32:关闭。") + for key, value in suite.metadata.items(): + lines.append(f"- `{key}`:`{value}`。") + + lines.append("") + lines.append("## 2. 实现列表") + lines.append("") + lines.append("| id | 名称 | group | category | dtype | dweight dtype | multi_precision |") + lines.append("| --- | --- | --- | --- | --- | --- | --- |") + for spec in suite.implementations: + lines.append( + f"| `{spec.id}` | {spec.display_name} | `{spec.group}` | `{spec.metadata.get('category')}` | `{spec.metadata.get('input_dtype', spec.dtype)}` | `{spec.metadata.get('dweight_dtype', spec.dtype)}` | `{spec.multi_precision}` |" + ) + + lines.append("") + lines.append(f"## 3. {report_setting(suite, 'vs_standard_title', '相对标准实现误差')}") + lines.append("") + if vs_standard_error: + lines.append(f"图形生成失败:`{vs_standard_error}`。") + lines.append("") + if reduce_paths: + lines.append("### 3.1 跨 shape reduce 趋势") + lines.append("") + for path in reduce_paths: + lines.append(f"![{path.stem}]({path.relative_to(out_dir)})") + lines.append("") + for index, shape in enumerate(sorted({shape_tuple(row, suite) for row in results}), 2): + shape_rows = [row for row in results if shape_tuple(row, suite) == shape] + lines.append( + f"### 3.{index} {report_setting(suite, 'shape_label_prefix', 'shape')} = {shape_label(shape)}" + ) + lines.append("") + image = next( + (path for path in vs_standard_paths if shape_file_suffix(suite, shape) in path.name), + None, + ) + if image: + lines.append(f"![全部实现误差总览]({image.relative_to(out_dir)})") + lines.append("") + write_vs_standard_table(lines, suite, shape_rows) + lines.append("") + lines.append(f"#### {report_setting(suite, 'bitwise_title', '逐位一致性')}") + lines.append("") + primary_category = report_setting(suite, "bitwise_primary_category") + comparison_categories = set(report_setting(suite, "bitwise_comparison_categories", [])) + write_bitwise_table( + lines, + suite, + [row for row in shape_rows if metadata(row, "category") == primary_category], + [row for row in shape_rows if metadata(row, "category") in comparison_categories], + ) + lines.append("") + + lines.append("## 4. Pairwise 矩阵") + lines.append("") + if figure_error: + lines.append(f"未生成 heatmap:`{figure_error}`。") + elif figure_paths: + for path in figure_paths: + lines.append(f"![{path.stem}]({path.relative_to(out_dir)})") + lines.append("") + else: + lines.append("没有可绘制的 pairwise 数据。") + if global_max: + lines.append( + f"- 全局最大目标实现 vs 参考实现偏差:`max_abs={fmt(global_max.metrics.max_abs)}`,case `{global_max.case_id}`,`{global_max.actual.spec.id}` vs `{global_max.expect.spec.id}`。" + ) + if reference_max: + lines.append( + f"- 参考实现之间最大直接偏差:`max_abs={fmt(reference_max.metrics.max_abs)}`,case `{reference_max.case_id}`,`{reference_max.actual.spec.id}` vs `{reference_max.expect.spec.id}`。" + ) + + lines.append("") + lines.append("## 5. 失败实现") + lines.append("") + failed = [result for result in results if result.status != "ok"] + if not failed: + lines.append("无失败实现。") + else: + lines.append("| case | category | id | error |") + lines.append("| --- | --- | --- | --- |") + for result in failed: + error = (result.error or "").replace("|", "\\|") + lines.append( + f"| `{result.case_id}` | `{metadata(result, 'category')}` | `{result.spec.id}` | {error} |" + ) + + lines.append("") + write_profile_section(lines, run_data.get("profile")) + + lines.append("") + lines.append("## 7. 结论") + lines.append("") + conclusions = [] + for category, title in report_setting(suite, "conclusion_categories", []): + rows = [row for row in ok_results(results) if metadata(row, "category") == category] + if rows: + row = max(rows, key=lambda item: metric_value(item, "max_abs")) + conclusions.append( + f"{title} 相对最高标准的最大偏差为 `max_abs={fmt(metric_value(row, 'max_abs'))}`,配置为 {implementation_config(row)}。" + ) + exact = [row for row in ok_results(results) if metric_value(row, "max_abs") == 0] + if exact: + exact_text = ";".join(implementation_config(row) for row in exact[:8]) + if len(exact) > 8: + exact_text += f";等 {len(exact)} 个配置" + conclusions.append(f"与最高标准完全一致的配置包括:{exact_text}。") + conclusions.append( + "详细数据见 `summary.csv`、`pairwise_summary.csv`、`reference_pairwise_summary.csv`、`results.json` 和 `env.json`。" + ) + for index, conclusion in enumerate(conclusions, 1): + lines.append(f"{index}. {conclusion}") + + report_path = out_dir / "report.md" + report_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return report_path diff --git a/tester/operator_compare/runner.py b/tester/operator_compare/runner.py new file mode 100644 index 00000000..619715ae --- /dev/null +++ b/tester/operator_compare/runner.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from typing import Any + +import torch + +from .metrics import compute_metrics, tensor_fingerprint +from .spec import ( + CompareSuite, + ImplementationResult, + ImplementationSpec, + MetricResult, + PairwiseResult, +) + + +def metric_from_dict(metrics: dict[str, Any]) -> MetricResult: + return MetricResult(**metrics) + + +def run_implementation( + case_id: str, spec: ImplementationSpec, case, enable_fingerprint: bool +) -> ImplementationResult: + metadata = dict(spec.metadata) + metadata["case_metadata"] = dict(getattr(case, "metadata", {})) + try: + case.tensors["_current_spec"] = spec + tensor = spec.runner(case) + if not isinstance(tensor, torch.Tensor): + raise TypeError( + f"implementation {spec.id} returned {type(tensor).__name__}, expected torch.Tensor" + ) + if enable_fingerprint: + metadata["output_fingerprint"] = tensor_fingerprint(tensor) + return ImplementationResult( + case_id=case_id, + spec=spec, + status="ok", + tensor=tensor.detach(), + output_dtype=str(tensor.dtype), + metadata=metadata, + ) + except Exception as err: + return ImplementationResult( + case_id=case_id, + spec=spec, + status="failed", + error=f"{type(err).__name__}: {err}", + metadata=metadata, + ) + finally: + case.tensors.pop("_current_spec", None) + + +def comparable_pair(actual: ImplementationResult, expect: ImplementationResult) -> bool: + if actual.status != "ok" or expect.status != "ok": + return False + if actual.tensor is None or expect.tensor is None: + return False + if actual.spec.dtype != expect.spec.dtype: + return False + if actual.spec.multi_precision != expect.spec.multi_precision: + return False + return True + + +def run_compare_suite(suite: CompareSuite) -> dict[str, Any]: + all_results: list[ImplementationResult] = [] + pairwise_results: list[PairwiseResult] = [] + reference_pairwise_results: list[PairwiseResult] = [] + + standard_by_case: dict[str, ImplementationResult] = {} + for case in suite.cases: + case_results = [ + run_implementation(case.id, spec, case, suite.enable_fingerprint) + for spec in suite.implementations + ] + all_results.extend(case_results) + + standard = next( + (result for result in case_results if result.spec.id == suite.standard_id), None + ) + if standard is None: + raise ValueError(f"standard implementation {suite.standard_id!r} not registered") + if standard.status != "ok" or standard.tensor is None: + raise RuntimeError( + f"standard implementation failed for case {case.id}: {standard.error}" + ) + standard_by_case[case.id] = standard + + for result in case_results: + if result.status == "ok" and result.tensor is not None: + result.metrics_vs_standard = metric_from_dict( + compute_metrics( + result.tensor, standard.tensor, metrics_dtype=suite.metrics_dtype + ) + ) + + targets = [result for result in case_results if result.spec.group in suite.target_groups] + references = [ + result for result in case_results if result.spec.group in suite.reference_groups + ] + for actual in targets: + for expect in references: + if not comparable_pair(actual, expect): + continue + metrics = metric_from_dict( + compute_metrics(actual.tensor, expect.tensor, metrics_dtype=suite.metrics_dtype) + ) + pairwise_results.append(PairwiseResult(case.id, actual, expect, metrics)) + + for actual_index, actual in enumerate(references): + for expect in references[actual_index + 1 :]: + if not comparable_pair(actual, expect): + continue + metrics = metric_from_dict( + compute_metrics( + actual.tensor, + expect.tensor, + metrics_dtype=suite.reference_pairwise_metrics_dtype, + ) + ) + reference_pairwise_results.append(PairwiseResult(case.id, actual, expect, metrics)) + + return { + "suite": suite, + "results": all_results, + "pairwise_results": pairwise_results, + "reference_pairwise_results": reference_pairwise_results, + "standard_by_case": standard_by_case, + } diff --git a/tester/operator_compare/spec.py b/tester/operator_compare/spec.py new file mode 100644 index 00000000..c600e39a --- /dev/null +++ b/tester/operator_compare/spec.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +TensorRunner = Callable[["CompareCase"], Any] + + +@dataclass +class CompareCase: + id: str + shape: tuple[int, ...] + tensors: dict[str, Any] + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class ImplementationSpec: + id: str + display_name: str + group: str + runner: TensorRunner + dtype: str | None = None + multi_precision: bool | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class MetricResult: + max_abs: float + mean_abs: float + rmse: float + p99_abs: float + max_rel: float + mean_rel: float + p99_rel: float + max_abs_idx: list[int] + max_rel_idx: list[int] + actual_at_max_abs: float + expect_at_max_abs: float + + +@dataclass +class ImplementationResult: + case_id: str + spec: ImplementationSpec + status: str + tensor: Any | None = None + output_dtype: str | None = None + metrics_vs_standard: MetricResult | None = None + error: str | None = None + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class PairwiseResult: + case_id: str + actual: ImplementationResult + expect: ImplementationResult + metrics: MetricResult + + +@dataclass +class CompareSuite: + op_name: str + cases: list[CompareCase] + implementations: list[ImplementationSpec] + standard_id: str + target_groups: set[str] + reference_groups: set[str] + metrics_dtype: str = "fp32" + reference_pairwise_metrics_dtype: str = "fp64" + enable_fingerprint: bool = True + metadata: dict[str, Any] = field(default_factory=dict) + report_config: dict[str, Any] = field(default_factory=dict) diff --git a/tools/operator_compare.py b/tools/operator_compare.py new file mode 100644 index 00000000..f9b826c4 --- /dev/null +++ b/tools/operator_compare.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import pathlib +import sys + +REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +DEFAULT_OUTPUT_ROOT = REPO_ROOT / "test_log_operator_compare" + + +def comma_list(value: str) -> list[str]: + return [item.strip() for item in value.split(",") if item.strip()] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run operator implementation comparisons and render a report." + ) + parser.add_argument("--case", default=None, help="One PaddleAPITest config line to compare.") + parser.add_argument( + "--config-file", default=None, help="PaddleAPITest config txt file to compare." + ) + parser.add_argument("--op", default=None, help="Optional API name filter for --config-file.") + parser.add_argument( + "--implementations", + default="paddle,torch", + help="Comma-separated implementations: paddle, torch, or registered custom names.", + ) + parser.add_argument( + "--dtypes", + default=None, + help="Comma-separated dtype override matrix. Defaults to config dtype.", + ) + parser.add_argument("--precisions", default="default", help="Comma-separated precision names.") + parser.add_argument( + "--standard", + default=None, + help="Standard implementation id. Defaults to torch||default when torch is enabled, otherwise first implementation.", + ) + parser.add_argument("--metrics-dtype", default="fp64", choices=["fp32", "fp64"]) + parser.add_argument( + "--no-fingerprint", action="store_true", help="Disable output tensor SHA256 fingerprints." + ) + parser.add_argument( + "--output-dir", + default=None, + help="Exact output directory. Defaults to test_log_operator_compare//.", + ) + args = parser.parse_args() + if not args.case and not args.config_file: + parser.error("one of --case or --config-file is required") + return args + + +def load_config_lines(args: argparse.Namespace) -> list[str]: + lines = [] + if args.case: + lines.append(args.case) + if args.config_file: + from tester.operator_compare.config_loader import cases_from_config_file + + loaded_cases = cases_from_config_file(args.config_file) + for case in loaded_cases: + if args.op is None or case.metadata["api_name"] == args.op: + lines.append(case.metadata["raw_config"]) + return lines + + +def default_standard(implementation_names: list[str], dtypes: list[str | None]) -> str: + dtype = dtypes[0] if dtypes else None + dtype_part = dtype or "config" + implementation = "torch" if "torch" in implementation_names else implementation_names[0] + return f"{implementation}|{dtype_part}|default" + + +def build_suite(args: argparse.Namespace): + from tester.operator_compare.implementations import build_compare_suite + + implementation_names = comma_list(args.implementations) + dtypes = comma_list(args.dtypes) if args.dtypes else [None] + standard = args.standard or default_standard(implementation_names, dtypes) + return build_compare_suite( + config_lines=load_config_lines(args), + implementation_names=implementation_names, + standard=standard, + dtypes=dtypes, + precisions=comma_list(args.precisions), + metrics_dtype=args.metrics_dtype, + enable_fingerprint=not args.no_fingerprint, + ) + + +def main() -> None: + args = parse_args() + + from tester.operator_compare.artifacts import timestamped_output_dir, write_artifacts + from tester.operator_compare.report import render_report + from tester.operator_compare.runner import run_compare_suite + + suite = build_suite(args) + out_dir = ( + pathlib.Path(args.output_dir).resolve() + if args.output_dir + else timestamped_output_dir(DEFAULT_OUTPUT_ROOT, suite.op_name) + ) + out_dir.mkdir(parents=True, exist_ok=True) + + run_data = run_compare_suite(suite) + write_artifacts(out_dir, run_data) + report_path = render_report(out_dir, run_data) + + print(f"Output directory: {out_dir}") + print(f"Report file: {report_path}") + + +if __name__ == "__main__": + main() diff --git a/tools/operator_compare_profile_worker.py b/tools/operator_compare_profile_worker.py new file mode 100644 index 00000000..fb8bc894 --- /dev/null +++ b/tools/operator_compare_profile_worker.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import pathlib +import sys + +REPO_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run one operator implementation under NVTX ranges for nsys profiling." + ) + parser.add_argument("--case", required=True, help="One PaddleAPITest config line to profile.") + parser.add_argument( + "--implementation", + required=True, + help="Full implementation id, such as paddle|config|default.", + ) + parser.add_argument("--repeat", type=int, default=3) + parser.add_argument("--metrics-dtype", default="fp64", choices=["fp32", "fp64"]) + return parser.parse_args() + + +def main() -> None: + from tester.operator_compare.implementations import build_compare_suite + + args = parse_args() + implementation_name = args.implementation.split("|", 1)[0] + suite = build_compare_suite( + config_lines=[args.case], + implementation_names=[implementation_name], + standard=args.implementation, + dtypes=[None], + metrics_dtype=args.metrics_dtype, + enable_fingerprint=False, + ) + case = suite.cases[0] + spec = next((item for item in suite.implementations if item.id == args.implementation), None) + if spec is None: + raise ValueError(f"implementation not found: {args.implementation}") + + import torch + + spec.runner(case) + if torch.cuda.is_available(): + torch.cuda.synchronize() + for idx in range(args.repeat): + range_name = f"{args.implementation}|case={case.id}|iter={idx}" + torch.cuda.nvtx.range_push(range_name) + out = spec.runner(case) + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + print(range_name, tuple(out.shape), out.dtype, flush=True) + + +if __name__ == "__main__": + main()