diff --git a/benchmarks/python/comparative/README.md b/benchmarks/python/comparative/README.md index dc0e8306a7..6f77c0da28 100644 --- a/benchmarks/python/comparative/README.md +++ b/benchmarks/python/comparative/README.md @@ -11,5 +11,16 @@ tensor on the cpu. `compare.py` runs several benchmarks and compares the speed-up or lack thereof in comparison to PyTorch. +Use `--json` to produce machine-readable results suitable for storing as CI +artifacts. `--repeats` runs each selected benchmark in a fresh process multiple +times and reports the median as well as the individual samples: + +```shell +python compare.py \ + --filter "^cumsum --size 1024x1024 --axis 0 --cpu$" \ + --repeats 3 \ + --json +``` + Each bench script can be run with `--print-pid` to print the PID and wait for a key in order to ease attaching a debugger. diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py index 68b4a5bd32..da12e3d66d 100644 --- a/benchmarks/python/comparative/compare.py +++ b/benchmarks/python/comparative/compare.py @@ -3,7 +3,11 @@ #!/usr/bin/env python import argparse +import json +import platform import re +import statistics +import sys from pathlib import Path from subprocess import run @@ -11,28 +15,70 @@ BENCH_TORCH = Path(__file__).parent / "bench_torch.py" -def run_or_raise(*args, **kwargs): +def run_or_raise(command): + result = run(command, capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"Command failed: {' '.join(map(str, command))}\n" + f"stdout: {result.stdout}\nstderr: {result.stderr}" + ) try: - result = run(*args, capture_output=True, **kwargs) return float(result.stdout) except ValueError: raise ValueError( - f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}" + f"Command returned invalid timing data: {' '.join(map(str, command))}\n" + f"stdout: {result.stdout}\nstderr: {result.stderr}" ) -def compare(args): - t_mlx = run_or_raise(["python", BENCH_MLX] + args) - t_torch = run_or_raise(["python", BENCH_TORCH] + args) - - print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t") - +def run_repeated(command, repeats): + samples = [run_or_raise(command) for _ in range(repeats)] + return { + "median_seconds": statistics.median(samples), + "min_seconds": min(samples), + "max_seconds": max(samples), + "samples_seconds": samples, + } + + +def compare(args, repeats, results): + mlx = run_repeated([sys.executable, BENCH_MLX] + args, repeats) + torch = run_repeated([sys.executable, BENCH_TORCH] + args, repeats) + speedup = (torch["median_seconds"] - mlx["median_seconds"]) / torch[ + "median_seconds" + ] + result = { + "benchmark": args[0], + "arguments": args[1:], + "mlx": mlx, + "torch": torch, + "relative_speedup": speedup, + } + if results is None: + print(speedup, " ".join(args), sep="\t") + else: + results.append(result) -def compare_mlx_dtypes(args, dt1, dt2): - t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1]) - t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2]) - print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t") +def compare_mlx_dtypes(args, dt1, dt2, repeats, results): + first = run_repeated([sys.executable, BENCH_MLX] + args + ["--dtype", dt1], repeats) + second = run_repeated( + [sys.executable, BENCH_MLX] + args + ["--dtype", dt2], repeats + ) + speedup = (second["median_seconds"] - first["median_seconds"]) / second[ + "median_seconds" + ] + result = { + "benchmark": args[0], + "arguments": args[1:], + "dtypes": [dt1, dt2], + "mlx": {dt1: first, dt2: second}, + "relative_speedup": speedup, + } + if results is None: + print(speedup, " ".join(args), sep="\t") + else: + results.append(result) def make_regex_search(regexes): @@ -77,18 +123,38 @@ def predicate(x): help="Compare mlx benchmarks between the 2 provided data types", nargs=2, ) + parser.add_argument( + "--json", action="store_true", help="Emit machine-readable benchmark results" + ) + parser.add_argument( + "--repeats", + type=int, + default=1, + help="Run each benchmark this many times and report the median", + ) args, rest = parser.parse_known_args() + if args.repeats < 1: + parser.error("--repeats must be at least 1") _filter = make_predicate(args.filter, args.negative_filter) + results = [] if args.json else None if args.mlx_dtypes: compare_filtered = lambda x: ( - compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1]) + compare_mlx_dtypes( + x.split() + rest, + args.mlx_dtypes[0], + args.mlx_dtypes[1], + args.repeats, + results, + ) if _filter(x) else None ) else: - compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None + compare_filtered = lambda x: ( + compare(x.split() + rest, args.repeats, results) if _filter(x) else None + ) # Binary ops compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu") @@ -282,3 +348,21 @@ def predicate(x): compare_filtered("topk --size 32768x128 --axis 1") compare_filtered("topk --size 128x128 --axis 0 --cpu") compare_filtered("topk --size 128x128 --axis 1 --cpu") + + if args.json: + print( + json.dumps( + { + "schema_version": 1, + "metadata": { + "platform": platform.platform(), + "python_version": platform.python_version(), + "python_executable": sys.executable, + "repeats": args.repeats, + }, + "results": results, + }, + indent=2, + sort_keys=True, + ) + ) diff --git a/python/tests/test_benchmark_compare.py b/python/tests/test_benchmark_compare.py new file mode 100644 index 0000000000..3180db8a3f --- /dev/null +++ b/python/tests/test_benchmark_compare.py @@ -0,0 +1,156 @@ +# Copyright © 2026 Apple Inc. + +import importlib.util +import json +import subprocess +import sys +import unittest +from pathlib import Path +from unittest import mock + +REPO_ROOT = Path(__file__).resolve().parents[2] +COMPARE_PATH = REPO_ROOT / "benchmarks/python/comparative/compare.py" +TORCH_AVAILABLE = importlib.util.find_spec("torch") is not None + + +def load_compare_module(): + spec = importlib.util.spec_from_file_location("benchmark_compare", COMPARE_PATH) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +class BenchmarkCompareUnitTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.compare = load_compare_module() + + def test_run_repeated_reports_samples_and_summary(self): + with mock.patch.object( + self.compare, "run_or_raise", side_effect=[0.3, 0.1, 0.2] + ): + result = self.compare.run_repeated(["benchmark"], 3) + + self.assertEqual(result["samples_seconds"], [0.3, 0.1, 0.2]) + self.assertEqual(result["median_seconds"], 0.2) + self.assertEqual(result["min_seconds"], 0.1) + self.assertEqual(result["max_seconds"], 0.3) + + def test_run_or_raise_includes_child_process_error(self): + command = [ + sys.executable, + "-c", + "import sys; print('details', file=sys.stderr); sys.exit(7)", + ] + + with self.assertRaisesRegex(RuntimeError, "details"): + self.compare.run_or_raise(command) + + def test_run_or_raise_rejects_invalid_timing_output(self): + command = [sys.executable, "-c", "print('not-a-timing')"] + + with self.assertRaisesRegex(ValueError, "invalid timing data"): + self.compare.run_or_raise(command) + + def test_dtype_comparison_records_both_summaries(self): + first = { + "median_seconds": 1.0, + "min_seconds": 0.9, + "max_seconds": 1.1, + "samples_seconds": [1.0], + } + second = { + "median_seconds": 2.0, + "min_seconds": 1.9, + "max_seconds": 2.1, + "samples_seconds": [2.0], + } + results = [] + + with mock.patch.object( + self.compare, "run_repeated", side_effect=[first, second] + ): + self.compare.compare_mlx_dtypes( + ["cumsum", "--cpu"], "float32", "float16", 1, results + ) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["dtypes"], ["float32", "float16"]) + self.assertEqual(results[0]["mlx"]["float32"], first) + self.assertEqual(results[0]["mlx"]["float16"], second) + self.assertEqual(results[0]["relative_speedup"], 0.5) + + def test_predicate_combines_positive_and_negative_filters(self): + predicate = self.compare.make_predicate([r"^cumsum", r"--cpu$"], [r"axis 1"]) + + self.assertTrue(predicate("cumsum --size 8x8 --axis 0 --cpu")) + self.assertFalse(predicate("cumsum --size 8x8 --axis 1 --cpu")) + self.assertFalse(predicate("sum_axis --size 8x8 --axis 0 --cpu")) + + +class BenchmarkCompareCliTest(unittest.TestCase): + def test_invalid_repeat_count_fails_before_running_benchmarks(self): + result = subprocess.run( + [sys.executable, COMPARE_PATH, "--repeats", "0"], + capture_output=True, + cwd=REPO_ROOT, + text=True, + timeout=30, + ) + + self.assertEqual(result.returncode, 2) + self.assertIn("--repeats must be at least 1", result.stderr) + + +@unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is required for comparative benchmarks") +class BenchmarkCompareIntegrationTest(unittest.TestCase): + def run_compare(self, *args): + return subprocess.run( + [sys.executable, COMPARE_PATH, *args], + capture_output=True, + check=True, + cwd=REPO_ROOT, + text=True, + timeout=120, + ) + + def test_json_output_from_real_mlx_and_torch_benchmarks(self): + result = self.run_compare( + "--filter", + r"^cumsum --size 128x1024 --axis 0$", + "--cpu", + "--repeats", + "2", + "--json", + ) + report = json.loads(result.stdout) + + self.assertEqual(report["schema_version"], 1) + self.assertEqual(report["metadata"]["repeats"], 2) + self.assertEqual(len(report["results"]), 1) + + benchmark = report["results"][0] + self.assertEqual(benchmark["benchmark"], "cumsum") + self.assertEqual( + benchmark["arguments"], + ["--size", "128x1024", "--axis", "0", "--cpu"], + ) + self.assertGreater(benchmark["mlx"]["median_seconds"], 0) + self.assertGreater(benchmark["torch"]["median_seconds"], 0) + self.assertEqual(len(benchmark["mlx"]["samples_seconds"]), 2) + self.assertEqual(len(benchmark["torch"]["samples_seconds"]), 2) + + def test_default_output_from_real_mlx_and_torch_benchmarks(self): + result = self.run_compare( + "--filter", + r"^cumsum --size 128x1024 --axis 0$", + "--cpu", + ) + + speedup, benchmark = result.stdout.rstrip().split("\t", maxsplit=1) + float(speedup) + self.assertEqual(benchmark, "cumsum --size 128x1024 --axis 0 --cpu") + + +if __name__ == "__main__": + unittest.main()