switch correctness checks to SNR-based assertion for cuda quant int4_matmul#19300
switch correctness checks to SNR-based assertion for cuda quant int4_matmul#19300
Conversation
Replace torch.allclose(atol/rtol) with an SNR (signal-to-noise ratio)
assertion across all int4_matmul / int4_matvec / dequant-vs-fused tests.
Why:
- test_prefill_short was flaking on CI (A10G) with max_abs_err=1.0000.
Root cause: bf16 GEMM with K=2048 reduction produces output magnitudes
up to ~200; at that scale, the bf16 ULP gap is 0.5-1.0. Triton fused
kernel and cuBLAS reduce in different orders (and Triton autotune
picks different tile configs on different hardware), so 1-ULP
element-wise differences are unavoidable. atol/rtol false-fails on
these outliers; SNR averages them out.
- atol/rtol thresholds also depend on size: a value tuned for K=2048
is too loose for K=64 and too tight for K=4096. SNR is size-invariant
(||signal|| and ||noise|| both scale with sqrt(N) and sqrt(K),
canceling in the ratio).
What:
- Add _assert_snr(test_case, actual, expected, label) helper that
asserts 20*log10(||expected|| / ||actual-expected||) >= 60 dB.
- Replace 4 call sites: TestInt4Matmul, TestInt4Matvec (x2),
TestDequantThenMatmul.
- 60 dB ~ 0.1% RMS error: well below observed clean noise (80-90 dB)
and well above any real functional bug (<20 dB SNR for wrong
stride / flipped nibble / off-by-one group_idx / missing mask).
Test plan:
python -m pytest backends/cuda/tests/test_int4_matmul.py -v
-> 35/35 passed
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19300
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 4 New Failures, 7 Pending, 3 Unrelated FailuresAs of commit 85d06de with merge base acffcb0 ( NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
Replace torch.allclose(atol/rtol) with an SNR (signal-to-noise ratio) assertion across all int4_matmul / int4_matvec / dequant-vs-fused tests.
Why:
What:
Test plan:
python -m pytest backends/cuda/tests/test_int4_matmul.py -v
-> 35/35 passed