Skip to content

Commit 0de13f3

Browse files
Fix sign-extension bug in fbgemm MX4 Python reference dequantize (#5706)
Summary: X-link: facebookresearch/FBGEMM#2643. py_dequantize_mx4 viewed packed data as int8 and subtracted FP32_EXP_BIAS directly. For biased exponents >= 128, the int8 value is negative, producing incorrect results. Fix by viewing as uint8, then casting to int32, before subtracting the bias. This is the same class of bug as D101680517 and the earlier Triton kernel fix. GH PR: #5706. Reviewed By: q10. Differential Revision: D102195911.
1 parent 91bf688 commit 0de13f3

2 files changed

Lines changed: 29 additions & 2 deletions

File tree

fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,10 @@ def py_dequantize_mx4(
254254
num_groups = a.numel() // ((group_size // 2) + 1)
255255
packed_input = a[:, :-1]
256256
shared_exp = a[:, -1:]
257-
# Remove fp32 exponent bias
257+
# Remove fp32 exponent bias.
258+
# View as uint8 first to avoid sign-extension for biased exponents >= 128.
258259
FP32_EXP_BIAS = 127
259-
shared_exp = shared_exp - FP32_EXP_BIAS
260+
shared_exp = shared_exp.view(torch.uint8).to(torch.int32) - FP32_EXP_BIAS
260261
# First pull shared exponent off the end of each row.
261262
M, K_2 = packed_input.shape
262263

fbgemm_gpu/test/quantize/mx4_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,32 @@ def test_mx4_high_bit_scale_exponent_triton(self) -> None:
460460
msg=f"Triton dequantize failed for magnitude={magnitude}",
461461
)
462462

463+
def test_mx4_high_bit_scale_exponent_ref(self) -> None:
464+
"""Regression test: py_dequantize_mx4 sign-extension for scale exponents >= 128.
465+
466+
The Python reference dequantize viewed packed data as int8 and
467+
subtracted FP32_EXP_BIAS directly, causing sign-extension for biased
468+
exponents >= 128.
469+
"""
470+
group_size = 32
471+
472+
for magnitude in [2.0, 4.0, 64.0, 1024.0]:
473+
input_tensor = torch.full([1, group_size], magnitude, dtype=torch.float32)
474+
475+
quantized = py_quantize_mx4(
476+
input_tensor, group_size, rounding_mode=RoundingMode.nearest
477+
)
478+
output_ref = py_dequantize_mx4(quantized, group_size)
479+
output_ref = output_ref.reshape(input_tensor.shape)
480+
481+
torch.testing.assert_close(
482+
input_tensor,
483+
output_ref,
484+
rtol=0.0,
485+
atol=0.0,
486+
msg=f"py_dequantize_mx4 failed for magnitude={magnitude}",
487+
)
488+
463489

464490
if __name__ == "__main__":
465491
unittest.main()

0 commit comments

Comments
 (0)