Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions problems/mxfp4-gemv/def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
from typing import List, Dict, Tuple, Any

from problem import Problem

BLOCK_SIZE = 32


class mxfp4_gemv(Problem):
    """Matrix-vector product whose operands are supplied as MXFP4 bytes.

    Inputs are quantized with TorchAO's ``to_mx`` using the
    ``torch.float4_e2m1fn_x2`` element dtype and ``BLOCK_SIZE``-element
    blocks. The reference dequantizes them again with ``to_dtype`` and
    multiplies the results in float32.
    """

    # NOTE(review): presumably tells the harness that results are checked with
    # tolerances (see verify_result) rather than bit-exactly; confirm in Problem.
    is_exact = False

    # Kernel signature: four read-only byte buffers, the float output, then
    # the two dimensions.
    parameters = [
        {"name": "q_a", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "scale_a", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "q_x", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "scale_x", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "y", "type": "float", "pointer": True, "const": False},
        {"name": "m", "type": "size_t", "pointer": False, "const": False},
        {"name": "k", "type": "size_t", "pointer": False, "const": False},
    ]

    def __init__(self):
        """Register the problem under the slug ``mxfp4-gemv``."""
        super().__init__(name="mxfp4-gemv")

    @staticmethod
    def _mx_tensor_api():
        """Lazily import TorchAO's MX helpers and return ``(to_mx, to_dtype)``.

        Raises:
            RuntimeError: if the import fails for any reason; the underlying
                exception is chained as the cause.
        """
        try:
            from torchao.prototype.mx_formats.mx_tensor import to_mx, to_dtype
        except Exception as import_error:
            message = (
                "TorchAO MXTensor APIs are required. Install a torchao build with "
                "torchao.prototype.mx_formats.mx_tensor support."
            )
            raise RuntimeError(message) from import_error
        else:
            return to_mx, to_dtype

    def reference_solution(
        self, q_a: torch.Tensor, scale_a: torch.Tensor, q_x: torch.Tensor, scale_x: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize both operands to float32 and return their product.

        Args:
            q_a: uint8 payload of the matrix, as produced by ``create_inputs``.
            scale_a: uint8 E8M0 block scales for the matrix.
            q_x: uint8 payload of the vector (built from a ``(1, k)`` tensor).
            scale_x: uint8 E8M0 block scales for the vector.

        Returns:
            ``torch.matmul`` of the dequantized matrix and the dequantized
            vector with its leading dimension squeezed away.
        """
        to_dtype = self._mx_tensor_api()[1]

        def dequantize(payload: torch.Tensor, scale_bytes: torch.Tensor) -> torch.Tensor:
            # Scales travel as raw bytes; reinterpret them as E8M0 for TorchAO.
            restored = to_dtype(
                payload,
                scale_bytes.view(torch.float8_e8m0fnu),
                torch.float4_e2m1fn_x2,
                BLOCK_SIZE,
                torch.float32,
            )
            return restored.float()

        with torch.no_grad():
            matrix = dequantize(q_a, scale_a)
            # The vector was quantized as a single-row matrix; drop that row axis.
            vector = dequantize(q_x, scale_x).squeeze(0)
            return torch.matmul(matrix, vector)

    def _make_case(self, m: int, k: int, name: str) -> Dict[str, Any]:
        """Build one test-case descriptor for an ``m x k`` problem.

        Args:
            m: number of matrix rows.
            k: number of matrix columns / vector length.
            name: label for the case; also mixed into the seed string.

        Returns:
            Dict with keys ``"name"``, ``"dims"`` (``(m, k)``) and
            ``"create_inputs"`` (a callable producing the four byte tensors).

        Raises:
            ValueError: if ``k`` is not a multiple of ``BLOCK_SIZE``.
        """
        if k % BLOCK_SIZE != 0:
            raise ValueError(f"K must be divisible by {BLOCK_SIZE}, got K={k}")

        seed = Problem.get_seed(f"{self.name}_{name}_M={m}_K={k}")

        # Defaults capture the current values so the closure is self-contained.
        def create_inputs(m=m, k=k, seed=seed):
            to_mx, _ = self._mx_tensor_api()
            rng = torch.Generator(device="cuda").manual_seed(seed)

            def draw(shape):
                return torch.randn(
                    shape, device="cuda", dtype=self.param_dtype("y"), generator=rng
                )

            def as_bytes(tensor):
                return tensor.contiguous().view(torch.uint8)

            # Draw the matrix before the vector: both share one generator, so
            # the order determines the values.
            matrix_hp = draw((m, k))
            vector_hp = draw((1, k))

            matrix_scales, matrix_lp = to_mx(matrix_hp, torch.float4_e2m1fn_x2, BLOCK_SIZE)
            vector_scales, vector_lp = to_mx(vector_hp, torch.float4_e2m1fn_x2, BLOCK_SIZE)

            return (
                as_bytes(matrix_lp),
                as_bytes(matrix_scales),
                as_bytes(vector_lp),
                as_bytes(vector_scales),
            )

        return {"name": name, "dims": (m, k), "create_inputs": create_inputs}

    def generate_test_cases(self) -> List[Dict[str, Any]]:
        """Return the benchmark cases, each named ``"<m> x <k>"``."""
        shapes = (
            (1024, 1024),
            (2048, 2048),
            (4096, 4096),
            (8192, 4096),
            (4096, 8192),
        )
        return [self._make_case(rows, cols, f"{rows} x {cols}") for rows, cols in shapes]

    def generate_sample(self) -> Dict[str, Any]:
        """Return a single small 32 x 32 case."""
        return self._make_case(m=32, k=32, name="sample_32x32")

    def verify_result(
        self, expected_output: torch.Tensor, actual_output: torch.Tensor
    ) -> Tuple[bool, Dict[str, Any]]:
        """Compare a candidate output against the reference with tolerances.

        Returns:
            ``(True, {})`` when ``torch.allclose`` accepts the result with
            ``rtol=2e-2`` and ``atol=5e-2``; otherwise ``False`` together with
            the maximum and mean absolute differences.
        """
        if torch.allclose(actual_output, expected_output, rtol=2e-2, atol=5e-2):
            return True, {}

        abs_error = (actual_output - expected_output).abs()
        report: Dict[str, Any] = {
            "max_difference": abs_error.max().item(),
            "mean_difference": abs_error.mean().item(),
        }
        return False, report

    def get_flops(self, test_case: Dict[str, Any]) -> int:
        """Return ``2 * m * k`` for the case's ``"dims"``."""
        rows, cols = test_case["dims"]
        return 2 * rows * cols

    def get_extra_params(self, test_case: Dict[str, Any]) -> List[Any]:
        """Return the trailing scalar arguments ``[m, k]`` for the case."""
        rows, cols = test_case["dims"]
        return [rows, cols]
31 changes: 31 additions & 0 deletions problems/mxfp4-gemv/problem.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
slug: "mxfp4-gemv"
title: "MXFP4 GEMV"
difficulty: "HARD"
author: "sarthak"
tags: ["quantization", "mxfp4", "matmul", "vector"]
gpus: ["B200"]
---

Compute matrix-vector multiplication where both matrix $A$ and vector $x$ are stored in MXFP4 format.

$$
y_i = \sum_{\ell=0}^{K-1} A_{\mathrm{dequant},i\ell} \, x_{\mathrm{dequant},\ell}.
$$

Equivalently, $y = A_{\mathrm{dequant}} x_{\mathrm{dequant}}$ where $A_{\mathrm{dequant}} \in \mathbb{R}^{M \times K}$ and $x_{\mathrm{dequant}} \in \mathbb{R}^{K}$.

## Input
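- $q_a$: MXFP4 (FP4 E2M1) payload bytes for matrix $A$ of logical shape $M \times K$ (row-major); two 4-bit elements are packed per byte, so the buffer holds $M \times K/2$ bytes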
- $scale_a$: per-block E8M0 scale bytes for $A$, logical shape $M \times K/32$
- $q_x$: MXFP4 (FP4 E2M1) payload bytes for vector $x$, represented as logical shape $1 \times K$ (two 4-bit elements packed per byte, $K/2$ bytes in total)
- $scale_x$: per-block E8M0 scale bytes for $x$, logical shape $1 \times K/32$
- $M$, $K$: dimensions ($K$ divisible by 32)

## Output
- $y$: FP32 vector of shape $M$

## Notes
- The reference dequantizes MXFP4 inputs with TorchAO MXTensor semantics and performs FP32 `matmul`.
- Scale tensors in this problem are stored in plain row-major order, one E8M0 byte per 32-element block (not swizzled).
- Correctness is based on dequantized semantics, not bitwise equality of quantized payloads.
120 changes: 120 additions & 0 deletions problems/mxfp8-gemv/def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import torch
from typing import List, Dict, Tuple, Any

from problem import Problem

BLOCK_SIZE = 32


class mxfp8_gemv(Problem):
    """Matrix-vector product whose operands are supplied as MXFP8 bytes.

    Inputs are quantized with TorchAO's ``to_mx`` using the
    ``torch.float8_e4m3fn`` element dtype and ``BLOCK_SIZE``-element blocks.
    The reference dequantizes them again with ``to_dtype`` and multiplies the
    results in float32.
    """

    # NOTE(review): presumably tells the harness that results are checked with
    # tolerances (see verify_result) rather than bit-exactly; confirm in Problem.
    is_exact = False

    # Kernel signature: four read-only byte buffers, the float output, then
    # the two dimensions.
    parameters = [
        {"name": "q_a", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "scale_a", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "q_x", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "scale_x", "type": "uint8_t", "pointer": True, "const": True},
        {"name": "y", "type": "float", "pointer": True, "const": False},
        {"name": "m", "type": "size_t", "pointer": False, "const": False},
        {"name": "k", "type": "size_t", "pointer": False, "const": False},
    ]

    def __init__(self):
        """Register the problem under the slug ``mxfp8-gemv``."""
        super().__init__(name="mxfp8-gemv")

    @staticmethod
    def _mx_tensor_api():
        """Lazily import TorchAO's MX helpers and return ``(to_mx, to_dtype)``.

        Raises:
            RuntimeError: if the import fails for any reason; the underlying
                exception is chained as the cause.
        """
        try:
            from torchao.prototype.mx_formats.mx_tensor import to_mx, to_dtype
        except Exception as import_error:
            message = (
                "TorchAO MXTensor APIs are required. Install a torchao build with "
                "torchao.prototype.mx_formats.mx_tensor support."
            )
            raise RuntimeError(message) from import_error
        else:
            return to_mx, to_dtype

    def reference_solution(
        self, q_a: torch.Tensor, scale_a: torch.Tensor, q_x: torch.Tensor, scale_x: torch.Tensor
    ) -> torch.Tensor:
        """Dequantize both operands to float32 and return their product.

        Args:
            q_a: uint8 payload of the matrix, as produced by ``create_inputs``.
            scale_a: uint8 E8M0 block scales for the matrix.
            q_x: uint8 payload of the vector (built from a ``(1, k)`` tensor).
            scale_x: uint8 E8M0 block scales for the vector.

        Returns:
            ``torch.matmul`` of the dequantized matrix and the dequantized
            vector with its leading dimension squeezed away.
        """
        to_dtype = self._mx_tensor_api()[1]

        def dequantize(payload: torch.Tensor, scale_bytes: torch.Tensor) -> torch.Tensor:
            # Both buffers travel as raw bytes; reinterpret them as the FP8
            # element type and the E8M0 scale type before handing to TorchAO.
            restored = to_dtype(
                payload.view(torch.float8_e4m3fn),
                scale_bytes.view(torch.float8_e8m0fnu),
                torch.float8_e4m3fn,
                BLOCK_SIZE,
                torch.float32,
            )
            return restored.float()

        with torch.no_grad():
            matrix = dequantize(q_a, scale_a)
            # The vector was quantized as a single-row matrix; drop that row axis.
            vector = dequantize(q_x, scale_x).squeeze(0)
            return torch.matmul(matrix, vector)

    def _make_case(self, m: int, k: int, name: str) -> Dict[str, Any]:
        """Build one test-case descriptor for an ``m x k`` problem.

        Args:
            m: number of matrix rows.
            k: number of matrix columns / vector length.
            name: label for the case; also mixed into the seed string.

        Returns:
            Dict with keys ``"name"``, ``"dims"`` (``(m, k)``) and
            ``"create_inputs"`` (a callable producing the four byte tensors).

        Raises:
            ValueError: if ``k`` is not a multiple of ``BLOCK_SIZE``.
        """
        if k % BLOCK_SIZE != 0:
            raise ValueError(f"K must be divisible by {BLOCK_SIZE}, got K={k}")

        seed = Problem.get_seed(f"{self.name}_{name}_M={m}_K={k}")

        # Defaults capture the current values so the closure is self-contained.
        def create_inputs(m=m, k=k, seed=seed):
            to_mx, _ = self._mx_tensor_api()
            rng = torch.Generator(device="cuda").manual_seed(seed)

            def draw(shape):
                return torch.randn(
                    shape, device="cuda", dtype=self.param_dtype("y"), generator=rng
                )

            def as_bytes(tensor):
                return tensor.contiguous().view(torch.uint8)

            # Draw the matrix before the vector: both share one generator, so
            # the order determines the values.
            matrix_hp = draw((m, k))
            vector_hp = draw((1, k))

            matrix_scales, matrix_lp = to_mx(matrix_hp, torch.float8_e4m3fn, BLOCK_SIZE)
            vector_scales, vector_lp = to_mx(vector_hp, torch.float8_e4m3fn, BLOCK_SIZE)

            return (
                as_bytes(matrix_lp),
                as_bytes(matrix_scales),
                as_bytes(vector_lp),
                as_bytes(vector_scales),
            )

        return {"name": name, "dims": (m, k), "create_inputs": create_inputs}

    def generate_test_cases(self) -> List[Dict[str, Any]]:
        """Return the benchmark cases, each named ``"<m> x <k>"``."""
        shapes = (
            (1024, 1024),
            (2048, 2048),
            (4096, 4096),
            (8192, 4096),
            (4096, 8192),
        )
        return [self._make_case(rows, cols, f"{rows} x {cols}") for rows, cols in shapes]

    def generate_sample(self) -> Dict[str, Any]:
        """Return a single small 32 x 32 case."""
        return self._make_case(m=32, k=32, name="sample_32x32")

    def verify_result(
        self, expected_output: torch.Tensor, actual_output: torch.Tensor
    ) -> Tuple[bool, Dict[str, Any]]:
        """Compare a candidate output against the reference with tolerances.

        Returns:
            ``(True, {})`` when ``torch.allclose`` accepts the result with
            ``rtol=2e-2`` and ``atol=5e-2``; otherwise ``False`` together with
            the maximum and mean absolute differences.
        """
        if torch.allclose(actual_output, expected_output, rtol=2e-2, atol=5e-2):
            return True, {}

        abs_error = (actual_output - expected_output).abs()
        report: Dict[str, Any] = {
            "max_difference": abs_error.max().item(),
            "mean_difference": abs_error.mean().item(),
        }
        return False, report

    def get_flops(self, test_case: Dict[str, Any]) -> int:
        """Return ``2 * m * k`` for the case's ``"dims"``."""
        rows, cols = test_case["dims"]
        return 2 * rows * cols

    def get_extra_params(self, test_case: Dict[str, Any]) -> List[Any]:
        """Return the trailing scalar arguments ``[m, k]`` for the case."""
        rows, cols = test_case["dims"]
        return [rows, cols]
31 changes: 31 additions & 0 deletions problems/mxfp8-gemv/problem.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
---
slug: "mxfp8-gemv"
title: "MXFP8 GEMV"
difficulty: "HARD"
author: "sarthak"
tags: ["quantization", "mxfp8", "matmul", "vector"]
gpus: ["B200"]
---

Compute matrix-vector multiplication where both matrix $A$ and vector $x$ are stored in MXFP8 format.

$$
y_i = \sum_{\ell=0}^{K-1} A_{\mathrm{dequant},i\ell} \, x_{\mathrm{dequant},\ell}.
$$

Equivalently, $y = A_{\mathrm{dequant}} x_{\mathrm{dequant}}$ where $A_{\mathrm{dequant}} \in \mathbb{R}^{M \times K}$ and $x_{\mathrm{dequant}} \in \mathbb{R}^{K}$.

## Input
- $q_a$: MXFP8 (FP8 E4M3) payload bytes for matrix $A$ of shape $M \times K$ (row-major, one byte per element)
- $scale_a$: per-block E8M0 scale bytes for $A$, logical shape $M \times K/32$
- $q_x$: MXFP8 (FP8 E4M3) payload bytes for vector $x$, represented as logical shape $1 \times K$ (one byte per element)
- $scale_x$: per-block E8M0 scale bytes for $x$, logical shape $1 \times K/32$
- $M$, $K$: dimensions ($K$ divisible by 32)

## Output
- $y$: FP32 vector of shape $M$

## Notes
- The reference dequantizes MXFP8 inputs with TorchAO MXTensor semantics and performs FP32 `matmul`.
- Scale tensors in this problem are stored in plain row-major order, one E8M0 byte per 32-element block (not swizzled).
- Correctness is based on dequantized semantics, not bitwise equality of quantized payloads.