Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions benchmarks/benchmark_det_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2026 RL-Kernel Contributors
"""Overhead of batch-invariant det_gemm (CUDA and Triton paths) vs cuBLAS (WS1 #146).

det_gemm (CUDA: SM90 TMA + mma.sync tensor cores; naive scalar fallback below
SM90) and the Triton path are batch-invariant and SLOWER than cuBLAS by design
(no split-K/stream-K, fixed accumulation order, FP32 accumulation, no TF32).
The benchmark reports overhead relative to the fair baseline (cuBLAS with TF32
disabled), not a speedup. This is a correctness-and-invariance-first milestone;
occupancy/throughput tuning of the tensor-core path is deferred per #146.
"""
import argparse

import torch

from rl_engine.kernels.ops.cuda.matmul import deterministic_gemm
from rl_engine.kernels.ops.pytorch.matmul import native_gemm

# The Triton implementation is optional: remember whether it imported so the
# benchmark can report NaN for that column instead of failing at import time.
try:
    from rl_engine.kernels.ops.triton.matmul import deterministic_gemm_triton

    _HAS_TRITON = True
except ImportError:
    _HAS_TRITON = False

# Device on which the benchmark operands are allocated.
DEV = "cuda"
# Number of untimed warm-up calls and of timed calls per measurement.
WARMUP, ITERS = 10, 50

# Benchmarked problem sizes as (label, M, K, N): an (M, K) operand is
# multiplied by a (K, N) operand.
# NOTE(review): the labels look like transformer projection layers — confirm
# which model configuration these dimensions were taken from.
SHAPES = [
("qkv", 4096, 4096, 12288),
("o_proj", 4096, 4096, 4096),
("mlp_up", 4096, 4096, 14336),
("mlp_dn", 4096, 14336, 4096),
("lm_head", 4096, 4096, 32000),
]


def _time(fn, a, b):
    """Return the mean duration of one ``fn(a, b)`` call, measured with CUDA events.

    ``WARMUP`` untimed calls are issued first and the device is synchronized.
    Then ``ITERS`` back-to-back calls are bracketed by two timing-enabled CUDA
    events, and the elapsed time between them (milliseconds, per
    ``torch.cuda.Event.elapsed_time``) is divided by ``ITERS``.
    """
    # Warm-up: keep one-time costs out of the timed region.
    for _warm in range(WARMUP):
        fn(a, b)
    torch.cuda.synchronize()

    start_evt = torch.cuda.Event(enable_timing=True)
    stop_evt = torch.cuda.Event(enable_timing=True)

    start_evt.record()
    for _rep in range(ITERS):
        fn(a, b)
    stop_evt.record()

    # Both events must have completed before the elapsed time can be read.
    torch.cuda.synchronize()
    total_ms = start_evt.elapsed_time(stop_evt)
    return total_ms / ITERS


def run():
    """Benchmark every entry of ``SHAPES`` and return one result row per shape.

    For each ``(name, M, K, N)`` a random bfloat16 ``(M, K)`` and ``(K, N)``
    operand pair is allocated on ``DEV`` and timed with ``_time`` through:

    * ``torch.matmul`` with ``allow_tf32`` enabled,
    * ``native_gemm`` with ``allow_tf32`` disabled (the baseline),
    * ``deterministic_gemm`` (det CUDA path),
    * ``deterministic_gemm_triton`` when the Triton import succeeded.

    The process-wide ``torch.backends.cuda.matmul.allow_tf32`` flag is
    restored to its previous value before returning, including on error, so
    the benchmark does not leak a changed matmul configuration to the caller.

    Returns:
        A list of tuples ``(name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri,
        overhead)`` where the ``t_*`` values are mean per-call times from
        ``_time``, ``t_tri`` is NaN when Triton is unavailable, and
        ``overhead`` is ``t_cuda / t_fp32``.
    """
    # Remember the global flag so it can be put back once the sweep is done.
    saved_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
    rows = []
    try:
        for name, M, K, N in SHAPES:
            a = torch.randn(M, K, device=DEV, dtype=torch.bfloat16)
            b = torch.randn(K, N, device=DEV, dtype=torch.bfloat16)

            # NOTE(review): the operands are bfloat16, while allow_tf32 is
            # documented as governing float32 matmuls. Unless native_gemm
            # upcasts internally, the two cuBLAS columns may be measuring the
            # same kernel — confirm.
            torch.backends.cuda.matmul.allow_tf32 = True
            t_tf32 = _time(lambda x, y: torch.matmul(x, y), a, b)

            # Everything below runs with TF32 disabled (the fair baseline).
            torch.backends.cuda.matmul.allow_tf32 = False
            t_fp32 = _time(native_gemm, a, b)
            t_cuda = _time(deterministic_gemm, a, b)
            t_tri = (
                _time(deterministic_gemm_triton, a, b)
                if _HAS_TRITON
                else float("nan")
            )

            rows.append(
                (name, M, K, N, t_tf32, t_fp32, t_cuda, t_tri, t_cuda / t_fp32)
            )
    finally:
        torch.backends.cuda.matmul.allow_tf32 = saved_allow_tf32
    return rows


def to_markdown(rows, dev, cap):
    """Render benchmark rows as a Markdown report.

    Args:
        rows: Iterable of 9-tuples ``(name, M, K, N, t_tf32, t_fp32, t_cuda,
            t_triton, overhead)``; the four timings are printed with three
            decimals and the overhead ratio with one.
        dev: Device name shown in the title.
        cap: Compute capability as a ``(major, minor)`` pair; it is shown in
            the title and selects the wording of the footnote.

    Returns:
        The report as a single newline-joined string: title, table, footnote.
    """
    title_and_header = [
        f"## det_gemm overhead — {dev} (SM{cap[0]}{cap[1]})",
        "",
        "| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |",
        "|---|---|---|---|---|---|---|---|---|",
    ]
    table_rows = [
        f"| {n} | {M} | {K} | {N} | {t1:.3f} | {t2:.3f} | {t3:.3f} | {t4:.3f} | {ov:.1f}x |"
        for n, M, K, N, t1, t2, t3, t4, ov in rows
    ]

    # The TMA + tensor-core kernel is only used from SM90 upwards; below that
    # the det CUDA path is the naive scalar fallback, so the footnote must not
    # claim the tensor-core schedule for those devices.
    if cap[0] >= 9:
        path_note = (
            "The det CUDA path uses "
            "SM90 TMA + mma.sync tensor cores with a fixed single-CTA-per-tile "
            "schedule (no split-K) for bitwise batch-invariance; "
        )
    else:
        path_note = (
            "On this pre-SM90 device the det CUDA path uses the naive scalar "
            "fallback (no TMA / tensor cores, no split-K) for bitwise "
            "batch-invariance; "
        )
    footnote = (
        "_Overhead = det CUDA vs cuBLAS (TF32 disabled). "
        + path_note
        + "both det paths trade speed for invariance. "
        "Throughput tuning is deferred per #146._"
    )

    return "\n".join(title_and_header + table_rows + ["", footnote])


def main():
    """Command-line entry point: run the benchmark and emit the Markdown report.

    Prints the CUDA device name and compute capability, runs the benchmark,
    prints the rendered report, and additionally writes it to the file given
    by ``--out`` when that option is supplied.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--out", default=None)
    args = ap.parse_args()

    name, cap = torch.cuda.get_device_name(), torch.cuda.get_device_capability()
    print(name, cap)

    md = to_markdown(run(), name, cap)
    print("\n" + md)

    if args.out:
        # The report contains non-ASCII text (an em dash in the title), so the
        # encoding is pinned instead of depending on the locale default.
        with open(args.out, "w", encoding="utf-8") as f:
            f.write(md + "\n")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
11 changes: 11 additions & 0 deletions benchmarks/results/det_gemm_h100_tma.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
## det_gemm overhead — NVIDIA H100 80GB HBM3 (SM90)

| shape | M | K | N | cuBLAS tf32 | cuBLAS fp32 | det CUDA | det Triton | overhead |
|---|---|---|---|---|---|---|---|---|
| qkv | 4096 | 4096 | 12288 | 0.538 | 0.538 | 3.280 | 1.421 | 6.1x |
| o_proj | 4096 | 4096 | 4096 | 0.190 | 0.190 | 1.164 | 0.478 | 6.1x |
| mlp_up | 4096 | 4096 | 14336 | 0.656 | 0.704 | 3.800 | 1.688 | 5.4x |
| mlp_dn | 4096 | 14336 | 4096 | 0.629 | 0.685 | 3.779 | 1.787 | 5.5x |
| lm_head | 4096 | 4096 | 32000 | 1.513 | 1.528 | 8.269 | 3.897 | 5.4x |

_Overhead = det CUDA vs cuBLAS (TF32 disabled). The det CUDA path uses SM90 TMA + mma.sync tensor cores with a fixed single-CTA-per-tile schedule (no split-K) for bitwise batch-invariance; both det paths trade speed for invariance. Throughput tuning is deferred per #146._
Loading
Loading