diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
index 4d87fe680a..e1ed08213a 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -558,7 +558,9 @@ def _kernel_dequantize_mx4(
         other=0.0,
     )
     # Remove fp32 exponent bias.
-    exp = exp.to(tl.int16) - FP32_EXP_BIAS
+    # Mask with 0xFF to interpret as unsigned before subtracting bias,
+    # since biased exponents >= 128 are sign-extended when cast from int8.
+    exp = (exp.to(tl.int16) & 0xFF) - FP32_EXP_BIAS

     # Convert exponent to scale and apply to input.
     # Requires higher precision to avoid rounding out small values.
diff --git a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
index 06d0cb0f74..5be3e089fd 100644
--- a/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
+++ b/fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -254,9 +254,10 @@ def py_dequantize_mx4(
     num_groups = a.numel() // ((group_size // 2) + 1)
     packed_input = a[:, :-1]
     shared_exp = a[:, -1:]
-    # Remove fp32 exponent bias
+    # Remove fp32 exponent bias.
+    # View as uint8 first to avoid sign-extension for biased exponents >= 128.
     FP32_EXP_BIAS = 127
-    shared_exp = shared_exp - FP32_EXP_BIAS
+    shared_exp = shared_exp.view(torch.uint8).to(torch.int32) - FP32_EXP_BIAS

     # First pull shared exponent off the end of each row.
     M, K_2 = packed_input.shape
diff --git a/fbgemm_gpu/test/quantize/mx4_test.py b/fbgemm_gpu/test/quantize/mx4_test.py
index 8601b05f2b..8303628dcb 100644
--- a/fbgemm_gpu/test/quantize/mx4_test.py
+++ b/fbgemm_gpu/test/quantize/mx4_test.py
@@ -158,7 +158,7 @@ class TestMXQuantizationConversion(unittest.TestCase):
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
     def test_mx4(self, power: int, sizes: int) -> None:
         group_size = 2**power
-        device = torch.device("cuda")
+        device = torch.device(torch.accelerator.current_accelerator() or "cuda")
         input = all_encodings(8, sizes, device=device)

         assert input.numel() % group_size == 0
@@ -366,7 +366,9 @@ def test_mx4_to_float_correctness(
     )
     def test_mx4_index_overflow(self) -> None:
         """Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices."""
-        large_input = torch.zeros(2**32, dtype=torch.float32).to("cuda")
+        large_input = torch.zeros(2**32, dtype=torch.float32).to(
+            torch.accelerator.current_accelerator() or "cuda"
+        )
         mx_quantized = fp32_to_mx4(large_input, 32)
         mx_dequantized = mx4_to_fp32(mx_quantized, 32)
         # We just need to check that everything ran without an illegal memory access.
@@ -424,6 +426,66 @@ def test_mx4_large_cases(
         # I give quite a bit of wiggle room to make sure this isnt flaky.
         torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)

+    @unittest.skipIf(*gpu_unavailable)
+    def test_mx4_high_bit_scale_exponent_triton(self) -> None:
+        """Regression test: scale exponents >= 128 (biased) must not be sign-extended.
+
+        Scale exponents are stored as uint8 bit patterns in int8 storage.
+        Values >= 128 have the high bit set and were incorrectly sign-extended
+        when cast to int16 in the Triton dequantize kernel, producing wildly
+        wrong (negative) exponents.
+        """
+        device = torch.device(torch.accelerator.current_accelerator() or "cuda")
+        group_size = 32
+
+        # Values with exponents >= 1 produce biased scale exponents >= 128.
+        # Use powers of 2 so quantization is exact and roundtrip is lossless
+        # for the scale component.
+        for magnitude in [2.0, 4.0, 64.0, 1024.0]:
+            input_tensor = torch.full(
+                [1, group_size], magnitude, device=device, dtype=torch.float32
+            )
+
+            quantized = fp32_to_mx4(
+                input_tensor, group_size, rounding_mode=RoundingMode.nearest
+            )
+            output_triton = mx4_to_fp32(quantized, group_size, use_triton=True)
+            output_triton = output_triton.reshape(input_tensor.shape)
+
+            torch.testing.assert_close(
+                input_tensor,
+                output_triton,
+                rtol=0.0,
+                atol=0.0,
+                msg=f"Triton dequantize failed for magnitude={magnitude}",
+            )
+
+    def test_mx4_high_bit_scale_exponent_ref(self) -> None:
+        """Regression test: py_dequantize_mx4 sign-extension for scale exponents >= 128.
+
+        The Python reference dequantize viewed packed data as int8 and
+        subtracted FP32_EXP_BIAS directly, causing sign-extension for biased
+        exponents >= 128.
+        """
+        group_size = 32
+
+        for magnitude in [2.0, 4.0, 64.0, 1024.0]:
+            input_tensor = torch.full([1, group_size], magnitude, dtype=torch.float32)
+
+            quantized = py_quantize_mx4(
+                input_tensor, group_size, rounding_mode=RoundingMode.nearest
+            )
+            output_ref = py_dequantize_mx4(quantized, group_size)
+            output_ref = output_ref.reshape(input_tensor.shape)
+
+            torch.testing.assert_close(
+                input_tensor,
+                output_ref,
+                rtol=0.0,
+                atol=0.0,
+                msg=f"py_dequantize_mx4 failed for magnitude={magnitude}",
+            )
+

 if __name__ == "__main__":
     unittest.main()
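For reference, a minimal standalone sketch (not part of the patch) of the sign-extension behavior this change addresses, assuming only plain PyTorch; the variable names are hypothetical:

    import torch

    FP32_EXP_BIAS = 127

    # A biased exponent of 128 is the uint8 bit pattern 0b1000_0000, which reads
    # back as -128 once the storage is typed int8 (as in the packed MX4 buffer).
    stored = torch.tensor([128], dtype=torch.uint8).view(torch.int8)

    # A naive widening cast sign-extends: (-128) - 127 = -255 instead of 128 - 127 = 1.
    naive = stored.to(torch.int16) - FP32_EXP_BIAS

    # Reinterpreting the bits as uint8 first (mirroring the fix) recovers the intended value.
    fixed = stored.view(torch.uint8).to(torch.int16) - FP32_EXP_BIAS

    print(naive.item(), fixed.item())  # -255 1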