4 changes: 3 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/triton/quantize.py
@@ -558,7 +558,9 @@ def _kernel_dequantize_mx4(
other=0.0,
)
# Remove fp32 exponent bias.
exp = exp.to(tl.int16) - FP32_EXP_BIAS
# Mask with 0xFF to interpret as unsigned before subtracting bias,
# since biased exponents >= 128 are sign-extended when cast from int8.
exp = (exp.to(tl.int16) & 0xFF) - FP32_EXP_BIAS

# Convert exponent to scale and apply to input.
# Requires higher precision to avoid rounding out small values.
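The comment in the patch above states the root cause: biased exponents >= 128 have the high bit set, so casting the int8 storage to int16 sign-extends them into large negative values. Below is a minimal standalone sketch in plain PyTorch (not the Triton kernel itself; only the constant name FP32_EXP_BIAS is taken from the patch) showing the naive cast and the masked cast side by side.

```python
import torch

# Illustration only: a biased exponent of 137 (unbiased exponent 10) stored as
# a raw byte reads as -119 when interpreted as int8.
FP32_EXP_BIAS = 127
exp = torch.tensor([-119], dtype=torch.int8)  # bit pattern 0x89 == 137 unsigned

# Naive widening cast sign-extends the high bit: -119 - 127 == -246 (wrong).
naive = exp.to(torch.int16) - FP32_EXP_BIAS

# Masking with 0xFF after the cast recovers the unsigned value:
# 0xFF89 & 0x00FF == 137, and 137 - 127 == 10 (correct).
masked = (exp.to(torch.int16) & 0xFF) - FP32_EXP_BIAS

print(naive.item(), masked.item())  # -246 10
```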
5 changes: 3 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/triton/quantize_ref.py
@@ -254,9 +254,10 @@ def py_dequantize_mx4(
num_groups = a.numel() // ((group_size // 2) + 1)
packed_input = a[:, :-1]
shared_exp = a[:, -1:]
# Remove fp32 exponent bias
# Remove fp32 exponent bias.
# View as uint8 first to avoid sign-extension for biased exponents >= 128.
FP32_EXP_BIAS = 127
shared_exp = shared_exp - FP32_EXP_BIAS
shared_exp = shared_exp.view(torch.uint8).to(torch.int32) - FP32_EXP_BIAS
# First pull shared exponent off the end of each row.
M, K_2 = packed_input.shape

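The reference-path fix rests on the same observation: the shared-exponent byte sits in int8 storage, so it must be reinterpreted as unsigned before the bias is subtracted. Here is a minimal standalone sketch of that reinterpretation, assuming nothing beyond stock PyTorch; it is an illustration of the `.view(torch.uint8)` change above, not the library code itself.

```python
import torch

# Illustration only: the shared-exponent byte 0x89 (137 unsigned) stored in an
# int8 tensor looks like -119 until it is reinterpreted as uint8.
FP32_EXP_BIAS = 127
shared_exp = torch.tensor([[-119]], dtype=torch.int8)

# Widening the signed view keeps the sign-extended -119: exponent -246 (wrong).
sign_extended = shared_exp.to(torch.int32) - FP32_EXP_BIAS

# Reinterpreting the same byte as uint8 first yields 137: exponent 10 (correct).
reinterpreted = shared_exp.view(torch.uint8).to(torch.int32) - FP32_EXP_BIAS

print(sign_extended.item(), reinterpreted.item())  # -246 10
```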
66 changes: 64 additions & 2 deletions fbgemm_gpu/test/quantize/mx4_test.py
@@ -158,7 +158,7 @@ class TestMXQuantizationConversion(unittest.TestCase):
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_mx4(self, power: int, sizes: int) -> None:
group_size = 2**power
device = torch.device("cuda")
device = torch.device(torch.accelerator.current_accelerator() or "cuda")

input = all_encodings(8, sizes, device=device)
assert input.numel() % group_size == 0
@@ -366,7 +366,9 @@ def test_mx4_to_float_correctness(
)
def test_mx4_index_overflow(self) -> None:
"""Tests that mx4 quantization kernels can handle inputs that would overflow int32 indices."""
large_input = torch.zeros(2**32, dtype=torch.float32).to("cuda")
large_input = torch.zeros(2**32, dtype=torch.float32).to(
torch.accelerator.current_accelerator() or "cuda"
)
mx_quantized = fp32_to_mx4(large_input, 32)
mx_dequantized = mx4_to_fp32(mx_quantized, 32)
# We just need to check that everything ran without an illegal memory access.
@@ -424,6 +426,66 @@ def test_mx4_large_cases(
# I give quite a bit of wiggle room to make sure this isn't flaky.
torch.testing.assert_close(input, mx_dequantized, rtol=1.0, atol=magnitude / 2)

@unittest.skipIf(*gpu_unavailable)
def test_mx4_high_bit_scale_exponent_triton(self) -> None:
"""Regression test: scale exponents >= 128 (biased) must not be sign-extended.

Scale exponents are stored as uint8 bit patterns in int8 storage.
Values >= 128 have the high bit set and were incorrectly sign-extended
when cast to int16 in the Triton dequantize kernel, producing wildly
wrong (negative) exponents.
"""
device = torch.device(torch.accelerator.current_accelerator() or "cuda")
group_size = 32

# Values with exponents >= 1 produce biased scale exponents >= 128.
# Use powers of 2 so quantization is exact and roundtrip is lossless
# for the scale component.
for magnitude in [2.0, 4.0, 64.0, 1024.0]:
input_tensor = torch.full(
[1, group_size], magnitude, device=device, dtype=torch.float32
)

quantized = fp32_to_mx4(
input_tensor, group_size, rounding_mode=RoundingMode.nearest
)
output_triton = mx4_to_fp32(quantized, group_size, use_triton=True)
output_triton = output_triton.reshape(input_tensor.shape)

torch.testing.assert_close(
input_tensor,
output_triton,
rtol=0.0,
atol=0.0,
msg=f"Triton dequantize failed for magnitude={magnitude}",
)

def test_mx4_high_bit_scale_exponent_ref(self) -> None:
"""Regression test: py_dequantize_mx4 sign-extension for scale exponents >= 128.

The Python reference dequantize viewed packed data as int8 and
subtracted FP32_EXP_BIAS directly, causing sign-extension for biased
exponents >= 128.
"""
group_size = 32

for magnitude in [2.0, 4.0, 64.0, 1024.0]:
input_tensor = torch.full([1, group_size], magnitude, dtype=torch.float32)

quantized = py_quantize_mx4(
input_tensor, group_size, rounding_mode=RoundingMode.nearest
)
output_ref = py_dequantize_mx4(quantized, group_size)
output_ref = output_ref.reshape(input_tensor.shape)

torch.testing.assert_close(
input_tensor,
output_ref,
rtol=0.0,
atol=0.0,
msg=f"py_dequantize_mx4 failed for magnitude={magnitude}",
)


if __name__ == "__main__":
unittest.main()