diff --git a/backends/cuda/tests/test_int4_matmul.py b/backends/cuda/tests/test_int4_matmul.py
index 2f33f888ac1..ed0ca47f3f6 100644
--- a/backends/cuda/tests/test_int4_matmul.py
+++ b/backends/cuda/tests/test_int4_matmul.py
@@ -19,7 +19,6 @@
 import unittest
 
 import torch
-
 from executorch.backends.cuda.triton.kernels.int4_matmul import (
     dequant_w4_to_bf16,
     int4_matmul,
@@ -28,6 +27,41 @@
 
 ATOL = 0.01
 DEVICE = "cuda"
+SNR_THRESHOLD_DB = 50.0
+
+
+def _assert_snr(test_case, actual, expected, label, threshold_db=SNR_THRESHOLD_DB):
+    """Assert signal-to-noise ratio (in dB) of `actual` vs `expected` >= threshold.
+
+    SNR = 20*log10(||expected||_2 / ||actual - expected||_2)
+
+    Why SNR rather than torch.allclose(atol/rtol):
+    * Size-invariant: ||signal|| and ||noise|| both scale with sqrt(N) and
+      with sqrt(K) (CLT + random-walk rounding), so the ratio is independent
+      of tensor size and reduction depth. The same threshold works for
+      K=64 and K=4096, M=1 and M=1024.
+    * Robust to bf16 ULP outliers: with K=2048 and output magnitudes ~200,
+      a single element can differ by ~1.0 just from differing reduction
+      orders (Triton fused vs cuBLAS). atol/rtol false-fails on these;
+      SNR averages them out.
+    * Sensitive to real bugs: wrong stride, flipped nibble, off-by-one
+      group_idx, or a missing mask all collapse SNR to <20 dB. The 50 dB
+      threshold (≈0.3% RMS error) sits comfortably between observed clean
+      noise floor (~80-90 dB) and any genuine functional break.
+    """
+    a = actual.float()
+    b = expected.float()
+    diff = a - b
+    signal = b.norm()
+    noise = diff.norm()
+    snr_db = (20.0 * torch.log10(signal / noise.clamp(min=1e-9))).item()
+    test_case.assertGreater(
+        snr_db,
+        threshold_db,
+        f"{label}: SNR={snr_db:.1f} dB (threshold {threshold_db:.1f} dB), "
+        f"max_abs_err={diff.abs().max().item():.4f}, "
+        f"signal_norm={signal.item():.2f}, noise_norm={noise.item():.4f}",
+    )
 
 
 def _quantize_simple(w_bf16, group_size):
@@ -118,12 +152,7 @@ def _run_matmul(self, M, N, K, group_size):
 
         self.assertEqual(out.shape, (M, N))
         self.assertEqual(out.dtype, torch.bfloat16)
-        self.assertTrue(
-            torch.allclose(out.float(), ref.float(), atol=ATOL, rtol=0.01),
-            f"int4_matmul M={M} [{N}x{K}] gs={group_size}: "
-            f"max_abs_err={(out.float() - ref.float()).abs().max().item():.4f}, "
-            f"max_rel_err={((out.float() - ref.float()).abs() / ref.float().abs().clamp(min=1e-6)).max().item():.4f}",
-        )
+        _assert_snr(self, out, ref, f"int4_matmul M={M} [{N}x{K}] gs={group_size}")
 
     # --- Decode (M=1) ---
     def test_decode_square(self):
@@ -189,13 +218,7 @@ def _run_matvec(self, N, K, group_size):
 
         self.assertEqual(out.shape, (1, N))
         self.assertEqual(out.dtype, torch.bfloat16)
-        # atol=1.0 for large accumulation across K, rtol=0.01 for relative
-        self.assertTrue(
-            torch.allclose(out.float(), ref.float(), atol=1.0, rtol=0.01),
-            f"int4_matvec [{N}x{K}] gs={group_size}: "
-            f"max_err={(out.float() - ref.float()).abs().max().item():.4f}, "
-            f"max_rel={((out.float()-ref.float()).abs()/(ref.float().abs().clamp(min=0.1))).max().item():.4f}",
-        )
+        _assert_snr(self, out, ref, f"int4_matvec [{N}x{K}] gs={group_size}")
 
     def test_qkv_proj(self):
         self._run_matvec(2048, 2048, 128)
@@ -226,10 +249,7 @@ def test_matches_int4_matmul(self):
 
         out_mv = int4_matvec(x, packed, scale, gs)
         out_mm = int4_matmul(x, packed, scale, gs)
-        self.assertTrue(
-            torch.allclose(out_mv.float(), out_mm.float(), atol=1.0, rtol=0.01),
-            f"matvec vs matmul: max_err={(out_mv.float() - out_mm.float()).abs().max().item():.4f}",
-        )
+        _assert_snr(self, out_mv, out_mm, "matvec vs matmul")
 
 
 class TestDequantThenMatmul(unittest.TestCase):
@@ -248,13 +268,7 @@ def _run(self, M, N, K, group_size):
 
         w_bf16 = dequant_w4_to_bf16(packed, scale, group_size)
         out_dequant = torch.nn.functional.linear(x, w_bf16)
-        self.assertTrue(
-            torch.allclose(
-                out_fused.float(), out_dequant.float(), atol=ATOL, rtol=0.01
-            ),
-            f"fused vs dequant M={M} [{N}x{K}]: "
-            f"max_abs_err={(out_fused.float() - out_dequant.float()).abs().max().item():.4f}",
-        )
+        _assert_snr(self, out_fused, out_dequant, f"fused vs dequant M={M} [{N}x{K}]")
 
     def test_decode(self):
         self._run(1, 2048, 2048, 128)