Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions backends/cuda/tests/test_int4_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import unittest

import torch

from executorch.backends.cuda.triton.kernels.int4_matmul import (
dequant_w4_to_bf16,
int4_matmul,
Expand All @@ -28,6 +27,41 @@

# Absolute tolerance for element-wise comparisons.
# NOTE(review): looks unused now that assertions go through _assert_snr —
# confirm against the full file before removing.
ATOL = 0.01
# All tensors in these tests are allocated on the CUDA device.
DEVICE = "cuda"
SNR_THRESHOLD_DB = 50.0


def _assert_snr(test_case, actual, expected, label, threshold_db=SNR_THRESHOLD_DB):
"""Assert signal-to-noise ratio (in dB) of `actual` vs `expected` >= threshold.

SNR = 20*log10(||expected||_2 / ||actual - expected||_2)

Why SNR rather than torch.allclose(atol/rtol):
* Size-invariant: ||signal|| and ||noise|| both scale with sqrt(N) and
with sqrt(K) (CLT + random-walk rounding), so the ratio is independent
of tensor size and reduction depth. The same threshold works for
K=64 and K=4096, M=1 and M=1024.
* Robust to bf16 ULP outliers: with K=2048 and output magnitudes ~200,
a single element can differ by ~1.0 just from differing reduction
orders (Triton fused vs cuBLAS). atol/rtol false-fails on these;
SNR averages them out.
* Sensitive to real bugs: wrong stride, flipped nibble, off-by-one
group_idx, or a missing mask all collapse SNR to <20 dB. The 50 dB
threshold (≈0.3% RMS error) sits comfortably between observed clean
noise floor (~80-90 dB) and any genuine functional break.
"""
a = actual.float()
b = expected.float()
diff = a - b
signal = b.norm()
noise = diff.norm()
snr_db = (20.0 * torch.log10(signal / noise.clamp(min=1e-9))).item()
test_case.assertGreater(
snr_db,
threshold_db,
f"{label}: SNR={snr_db:.1f} dB (threshold {threshold_db:.1f} dB), "
f"max_abs_err={diff.abs().max().item():.4f}, "
f"signal_norm={signal.item():.2f}, noise_norm={noise.item():.4f}",
)


def _quantize_simple(w_bf16, group_size):
Expand Down Expand Up @@ -118,12 +152,7 @@ def _run_matmul(self, M, N, K, group_size):

self.assertEqual(out.shape, (M, N))
self.assertEqual(out.dtype, torch.bfloat16)
self.assertTrue(
torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
)
_assert_snr(self, out, ref, f"int4_matmul M={M} [{N}x{K}] gs={group_size}")

# --- Decode (M=1) ---
def test_decode_square(self):
Expand Down Expand Up @@ -189,13 +218,7 @@ def _run_matvec(self, N, K, group_size):

self.assertEqual(out.shape, (1, N))
self.assertEqual(out.dtype, torch.bfloat16)
# atol=1.0 for large accumulation across K, rtol=0.01 for relative
self.assertTrue(
torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
f"int4_matvec [{N}x{K}] gs={group_size}: "
f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}",
)
_assert_snr(self, out, ref, f"int4_matvec [{N}x{K}] gs={group_size}")

def test_qkv_proj(self):
    """Matvec shaped like a QKV projection: square 2048x2048, group size 128."""
    self._run_matvec(N=2048, K=2048, group_size=128)
Expand Down Expand Up @@ -226,10 +249,7 @@ def test_matches_int4_matmul(self):
out_mv = int4_matvec(x, packed, scale, gs)
out_mm = int4_matmul(x, packed, scale, gs)

self.assertTrue(
torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
)
_assert_snr(self, out_mv, out_mm, "matvec vs matmul")


class TestDequantThenMatmul(unittest.TestCase):
Expand All @@ -248,13 +268,7 @@ def _run(self, M, N, K, group_size):
w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
out_dequant = torch.nn.functional.linear(x, w_bf16)

self.assertTrue(
torch.allclose(
out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
),
f"fused vs dequant M={M} [{N}x{K}]: "
f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
)
_assert_snr(self, out_fused, out_dequant, f"fused vs dequant M={M} [{N}x{K}]")

def test_decode(self):
    """Decode-shaped check (single token): M=1, N=K=2048, group size 128."""
    self._run(M=1, N=2048, K=2048, group_size=128)
Expand Down
Loading