Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlx/backend/cuda/device/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ inline __device__ void store_vector(
// Type limits utils
///////////////////////////////////////////////////////////////////////////////

// Maximum values of the 8-bit e4m3 and 4-bit e2m1 floating point types,
// expressed as plain floats.
// NOTE(review): the product of the two appears to be the normalizer applied to
// the nvfp4 global scale -- confirm against the quantization kernels.
constexpr float F8E4M3_MAX = 448.0f;
constexpr float F4E2M1_MAX = 6.0f;

template <typename T, typename = void>
struct Limits {
static constexpr __host__ __device__ T max() {
Expand Down
1 change: 1 addition & 0 deletions mlx/backend/cuda/quantized/qmm/qmm.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ void qmv(
const array& w,
const array& scales,
const std::optional<array>& biases,
const std::optional<array>& global_scale,
array& out,
int bits,
int group_size,
Expand Down
23 changes: 20 additions & 3 deletions mlx/backend/cuda/quantized/qmm/qmv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ __device__ __forceinline__ void qmv_kernel_impl(
const Q* w,
const S* scales,
const T* biases,
const float* global_scale,
T* out,
int row,
int w_batch,
Expand Down Expand Up @@ -155,6 +156,14 @@ __device__ __forceinline__ void qmv_kernel_impl(

// Write result for current warp, which maps to rows 1-to-1.
if (warp.thread_rank() == 0) {
if constexpr (
cuda::std::is_same_v<Q, cutlass::float_e2m1_t> &&
cuda::std::is_same_v<S, cutlass::float_e4m3_t>) {
// Only nvfp4 supports global scale.
if (global_scale) {
sum *= (*global_scale / (F8E4M3_MAX * F4E2M1_MAX));
}
}
out[row] = static_cast<T>(sum);
}
}
Expand All @@ -173,6 +182,7 @@ __global__ void qmv_kernel(
const Q* w,
const S* scales,
const T* biases,
const float* global_scale,
T* out,
int n,
int k,
Expand All @@ -195,7 +205,7 @@ __global__ void qmv_kernel(
int w_batch = broadcast_w ? 0 : l;

qmv_kernel_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
x, w, scales, biases, out, row, w_batch, n, k);
x, w, scales, biases, global_scale, out, row, w_batch, n, k);
}

template <
Expand Down Expand Up @@ -235,7 +245,7 @@ __global__ void gather_qmv_kernel(
out += block.group_index().y * n + m * n * l;

qmv_kernel_impl<elems_per_thread, group_size, has_bias, has_residue_k>(
x, w, scales, biases, out, row, w_idx, n, k);
x, w, scales, biases, nullptr, out, row, w_idx, n, k);
}

template <
Expand All @@ -250,6 +260,7 @@ void qmv(
const Q* w,
const S* scales,
const T* biases,
const float* global_scale,
T* out,
int m,
int n,
Expand All @@ -264,7 +275,8 @@ void qmv(
dim3 num_blocks{
uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)};
dim3 block_dims{WARP_SIZE, rows_per_block};
void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w};
void* args[] = {
&x, &w, &scales, &biases, &global_scale, &out, &n, &k, &broadcast_w};

dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) {
auto* kernel = &qmv_kernel<
Expand Down Expand Up @@ -396,6 +408,7 @@ void qmv(
const array& w,
const array& scales,
const std::optional<array>& biases,
const std::optional<array>& global_scale,
array& out,
int bits,
int group_size,
Expand All @@ -421,13 +434,17 @@ void qmv(
if (biases) {
encoder.set_input_array(*biases);
}
if (global_scale) {
encoder.set_input_array(*global_scale);
}
encoder.set_output_array(out);
constexpr bool has_bias = !cutlass::has_negative_zero_v<Q>;
cu::qmv<group_size, has_bias>(
gpu_ptr<T>(x),
gpu_ptr<Q>(w),
gpu_ptr<S>(scales),
biases ? gpu_ptr<T>(*biases) : nullptr,
global_scale ? gpu_ptr<float>(*global_scale) : nullptr,
gpu_ptr<T>(out),
m,
n,
Expand Down
52 changes: 25 additions & 27 deletions mlx/backend/cuda/quantized/qqmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,28 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& encoder = cu::get_command_encoder(s);
auto& device = encoder.device();
bool w_quantized = (inputs[1].dtype() == uint32);
int base_size = w_quantized ? 3 : 2;

// - 2 inputs: x, w (non-quantized w)
// - 3 inputs: x, w, scales_w (quantized w)
int base_size = w_quantized ? 3 : 2;
assert(
inputs.size() == base_size ||
(mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));

// For nvfp4, global scales are optional but must be both present or both
// absent. If present, they add 2 more inputs (global_scale_x, global_scale_w).
bool has_global_scales =
mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
std::optional<array> global_scale_x = std::nullopt;
std::optional<array> global_scale_w = std::nullopt;
if (has_global_scales) {
global_scale_x = inputs[inputs.size() - 2];
global_scale_w = inputs[inputs.size() - 1];
}

if (w_quantized && inputs[0].shape(-2) == 1) {
out.set_data(cu::malloc_async(out.nbytes(), encoder));

// For nvfp4, get global scale for x from inputs if present
bool has_global_scale =
mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
std::optional<array> global_scale = std::nullopt;
if (has_global_scale) {
global_scale = inputs[inputs.size() - 2];
}

bool donate_x = inputs[0].is_donatable();
array x = ensure_row_contiguous(inputs[0], encoder, s);
// If x is a copy it should be donatable
Expand All @@ -104,11 +109,20 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
encoder.add_temporary(xhat);
}
fp_quantize_dequantize(
x, xhat, group_size_, bits_, global_scale, encoder, s);
x, xhat, group_size_, bits_, global_scale_x, encoder, s);

const array& w = inputs[1];
const array& scales = inputs[2];
qmv(xhat, w, scales, std::nullopt, out, bits_, group_size_, mode_, encoder);
qmv(xhat,
w,
scales,
std::nullopt,
global_scale_w,
out,
bits_,
group_size_,
mode_,
encoder);
return;
}

Expand All @@ -119,22 +133,6 @@ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
"[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
}

// - 2 inputs: x, w (non-quantized w)
// - 3 inputs: x, w, scales_w (quantized w)

// For nvfp4, global scales are optional but must be both present or both
// absent. If present, they add 2 more inputs (global_scale_x, global_scale_w).
bool has_global_scales =
mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;

// For nvfp4, get global scales from inputs if present
std::optional<array> global_scale_x = std::nullopt;
std::optional<array> global_scale_w = std::nullopt;
if (has_global_scales) {
global_scale_x = inputs[inputs.size() - 2];
global_scale_w = inputs[inputs.size() - 1];
}

// Quantize inputs (or use pre-quantized)
auto [x_q, scale_x_pre] = quantize_input(
inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);
Expand Down
3 changes: 0 additions & 3 deletions mlx/backend/cuda/quantized/qqmm_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ inline std::tuple<dim3, dim3> get_swizzle_launch_args(

namespace cu {

constexpr float F8E4M3_MAX = 448.0f;
constexpr float F4E2M1_MAX = 6.0f;

__global__ void compute_qqmm_pointers(
float* alpha_out,
float* beta_out,
Expand Down
11 changes: 10 additions & 1 deletion mlx/backend/cuda/quantized/quantized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,16 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (can_use_fp_qmv) {
fp_qmv(x, w, scales, out, bits_, group_size_, encoder, s);
} else {
qmv(x, w, scales, biases, out, bits_, group_size_, mode_, encoder);
qmv(x,
w,
scales,
biases,
std::nullopt,
out,
bits_,
group_size_,
mode_,
encoder);
}
};

Expand Down
63 changes: 41 additions & 22 deletions python/tests/test_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,31 +179,50 @@ def test_qqmv(self):
tests = product(
[256, 512, 67], # M
[64, 256], # N
["nvfp4", "mxfp8"], # mode
)
modes = ["nvfp4", "mxfp8"]
for M, N in tests:
for mode in modes:
with self.subTest(shape=(M, N), mode=mode):
x_shape = (1, N)
w_shape = (M, N)
for M, N, mode in tests:
with self.subTest(shape=(M, N), mode=mode):
x_shape = (1, N)
w_shape = (M, N)

# TODO: Fix qmv with global scale in Metal/CPU backends.
has_global_scale = (
mode == "nvfp4"
and mx.cuda.is_available()
and mx.default_device() == mx.gpu
)

x = mx.random.normal(shape=x_shape, key=k1)
x_hat = mx.dequantize(
*mx.quantize(x, mode=mode), mode=mode, dtype=mx.float32
)
x = mx.random.normal(shape=x_shape, key=k1)
global_scale_x = mx.max(mx.abs(x)) if has_global_scale else None
x_hat = mx.dequantize(
*mx.quantize(x, mode=mode, global_scale=global_scale_x),
mode=mode,
dtype=mx.float32,
global_scale=global_scale_x,
)

w = mx.random.normal(shape=w_shape, key=k2)
w_q, scales = mx.quantize(w, mode=mode)
w_hat = mx.dequantize(w_q, scales, mode=mode, dtype=mx.float32)
y_q = mx.qqmm(
x,
w_q,
scales,
mode=mode,
)
y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
w = mx.random.normal(shape=w_shape, key=k2)
global_scale_w = mx.max(mx.abs(w)) if has_global_scale else None
w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w)
w_hat = mx.dequantize(
w_q,
scales,
mode=mode,
global_scale=global_scale_w,
dtype=mx.float32,
)
y_q = mx.qqmm(
x,
w_q,
scales,
mode=mode,
global_scale_x=global_scale_x,
global_scale_w=global_scale_w,
)
y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2)
self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3)

def test_qqmm_metal_global_scale_rejected(self):
# Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not
Expand Down
Loading