diff --git a/mlx/backend/cuda/device/utils.cuh b/mlx/backend/cuda/device/utils.cuh index 2e017ba21b..2f90713766 100644 --- a/mlx/backend/cuda/device/utils.cuh +++ b/mlx/backend/cuda/device/utils.cuh @@ -167,6 +167,9 @@ inline __device__ void store_vector( // Type limits utils /////////////////////////////////////////////////////////////////////////////// +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + template struct Limits { static constexpr __host__ __device__ T max() { diff --git a/mlx/backend/cuda/quantized/qmm/qmm.h b/mlx/backend/cuda/quantized/qmm/qmm.h index 698fde0f6e..8d998cda40 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.h +++ b/mlx/backend/cuda/quantized/qmm/qmm.h @@ -122,6 +122,7 @@ void qmv( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, array& out, int bits, int group_size, diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 7a293d2ce8..4ec6b95a7a 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -94,6 +94,7 @@ __device__ __forceinline__ void qmv_kernel_impl( const Q* w, const S* scales, const T* biases, + const float* global_scale, T* out, int row, int w_batch, @@ -155,6 +156,14 @@ __device__ __forceinline__ void qmv_kernel_impl( // Write result for current warp, which maps to rows 1-to-1. if (warp.thread_rank() == 0) { + if constexpr ( + cuda::std::is_same_v && + cuda::std::is_same_v) { + // Only nvfp4 supports global scale. + if (global_scale) { + sum *= (*global_scale / (F8E4M3_MAX * F4E2M1_MAX)); + } + } out[row] = static_cast(sum); } } @@ -173,6 +182,7 @@ __global__ void qmv_kernel( const Q* w, const S* scales, const T* biases, + const float* global_scale, T* out, int n, int k, @@ -195,7 +205,7 @@ __global__ void qmv_kernel( int w_batch = broadcast_w ? 0 : l; qmv_kernel_impl( - x, w, scales, biases, out, row, w_batch, n, k); + x, w, scales, biases, global_scale, out, row, w_batch, n, k); } template < @@ -235,7 +245,7 @@ __global__ void gather_qmv_kernel( out += block.group_index().y * n + m * n * l; qmv_kernel_impl( - x, w, scales, biases, out, row, w_idx, n, k); + x, w, scales, biases, nullptr, out, row, w_idx, n, k); } template < @@ -250,6 +260,7 @@ void qmv( const Q* w, const S* scales, const T* biases, + const float* global_scale, T* out, int m, int n, @@ -264,7 +275,8 @@ void qmv( dim3 num_blocks{ uint32_t(cuda::ceil_div(n, rows_per_block)), uint32_t(m), uint32_t(l)}; dim3 block_dims{WARP_SIZE, rows_per_block}; - void* args[] = {&x, &w, &scales, &biases, &out, &n, &k, &broadcast_w}; + void* args[] = { + &x, &w, &scales, &biases, &global_scale, &out, &n, &k, &broadcast_w}; dispatch_bool(k % (WARP_SIZE * elems_per_thread), [&](auto has_residue_k) { auto* kernel = &qmv_kernel< @@ -396,6 +408,7 @@ void qmv( const array& w, const array& scales, const std::optional& biases, + const std::optional& global_scale, array& out, int bits, int group_size, @@ -421,6 +434,9 @@ void qmv( if (biases) { encoder.set_input_array(*biases); } + if (global_scale) { + encoder.set_input_array(*global_scale); + } encoder.set_output_array(out); constexpr bool has_bias = !cutlass::has_negative_zero_v; cu::qmv( @@ -428,6 +444,7 @@ void qmv( gpu_ptr(w), gpu_ptr(scales), biases ? gpu_ptr(*biases) : nullptr, + global_scale ? gpu_ptr(*global_scale) : nullptr, gpu_ptr(out), m, n, diff --git a/mlx/backend/cuda/quantized/qqmm.cpp b/mlx/backend/cuda/quantized/qqmm.cpp index a4e019d662..eaec2ac8f4 100644 --- a/mlx/backend/cuda/quantized/qqmm.cpp +++ b/mlx/backend/cuda/quantized/qqmm.cpp @@ -76,23 +76,28 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { auto& encoder = cu::get_command_encoder(s); auto& device = encoder.device(); bool w_quantized = (inputs[1].dtype() == uint32); - int base_size = w_quantized ? 3 : 2; + // - 2 inputs: x, w (non-quantized w) + // - 3 inputs: x, w, scales_w (quantized w) + int base_size = w_quantized ? 3 : 2; assert( inputs.size() == base_size || (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2)); + // For nvfp4, global scales are optional but must be both present or both + // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) + bool has_global_scales = + mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; + std::optional global_scale_x = std::nullopt; + std::optional global_scale_w = std::nullopt; + if (has_global_scales) { + global_scale_x = inputs[inputs.size() - 2]; + global_scale_w = inputs[inputs.size() - 1]; + } + if (w_quantized && inputs[0].shape(-2) == 1) { out.set_data(cu::malloc_async(out.nbytes(), encoder)); - // For nvfp4, get global scale for x from inputs if present - bool has_global_scale = - mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - std::optional global_scale = std::nullopt; - if (has_global_scale) { - global_scale = inputs[inputs.size() - 2]; - } - bool donate_x = inputs[0].is_donatable(); array x = ensure_row_contiguous(inputs[0], encoder, s); // If x is a copy it should be donatable @@ -104,11 +109,20 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { encoder.add_temporary(xhat); } fp_quantize_dequantize( - x, xhat, group_size_, bits_, global_scale, encoder, s); + x, xhat, group_size_, bits_, global_scale_x, encoder, s); const array& w = inputs[1]; const array& scales = inputs[2]; - qmv(xhat, w, scales, std::nullopt, out, bits_, group_size_, mode_, encoder); + qmv(xhat, + w, + scales, + std::nullopt, + global_scale_w, + out, + bits_, + group_size_, + mode_, + encoder); return; } @@ -119,22 +133,6 @@ void QQMatmul::eval_gpu(const std::vector& inputs, array& out) { "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher."); } - // - 2 inputs: x, w (non-quantized w) - // - 3 inputs: x, w, scales_w (quantized w) - - // For nvfp4, global scales are optional but must be both present or both - // absent If present, they add 2 more inputs (global_scale_x, global_scale_w) - bool has_global_scales = - mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size; - - // For nvfp4, get global scales from inputs if present - std::optional global_scale_x = std::nullopt; - std::optional global_scale_w = std::nullopt; - if (has_global_scales) { - global_scale_x = inputs[inputs.size() - 2]; - global_scale_w = inputs[inputs.size() - 1]; - } - // Quantize inputs (or use pre-quantized) auto [x_q, scale_x_pre] = quantize_input( inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x); diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index 100df140a0..96a1fff7bb 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -70,9 +70,6 @@ inline std::tuple get_swizzle_launch_args( namespace cu { -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - __global__ void compute_qqmm_pointers( float* alpha_out, float* beta_out, diff --git a/mlx/backend/cuda/quantized/quantized.cpp b/mlx/backend/cuda/quantized/quantized.cpp index 645b24cca6..2a1a268c91 100644 --- a/mlx/backend/cuda/quantized/quantized.cpp +++ b/mlx/backend/cuda/quantized/quantized.cpp @@ -84,7 +84,16 @@ void QuantizedMatmul::eval_gpu(const std::vector& inputs, array& out) { if (can_use_fp_qmv) { fp_qmv(x, w, scales, out, bits_, group_size_, encoder, s); } else { - qmv(x, w, scales, biases, out, bits_, group_size_, mode_, encoder); + qmv(x, + w, + scales, + biases, + std::nullopt, + out, + bits_, + group_size_, + mode_, + encoder); } }; diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 8f199dea14..0de3baec00 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -179,31 +179,50 @@ def test_qqmv(self): tests = product( [256, 512, 67], # M [64, 256], # N + ["nvfp4", "mxfp8"], # mode ) - modes = ["nvfp4", "mxfp8"] - for M, N in tests: - for mode in modes: - with self.subTest(shape=(M, N), mode=mode): - x_shape = (1, N) - w_shape = (M, N) + for M, N, mode in tests: + with self.subTest(shape=(M, N), mode=mode): + x_shape = (1, N) + w_shape = (M, N) + + # TODO: Fix qmv with global scale in Metal/CPU backends. + has_global_scale = ( + mode == "nvfp4" + and mx.cuda.is_available() + and mx.default_device() == mx.gpu + ) - x = mx.random.normal(shape=x_shape, key=k1) - x_hat = mx.dequantize( - *mx.quantize(x, mode=mode), mode=mode, dtype=mx.float32 - ) + x = mx.random.normal(shape=x_shape, key=k1) + global_scale_x = mx.max(mx.abs(x)) if has_global_scale else None + x_hat = mx.dequantize( + *mx.quantize(x, mode=mode, global_scale=global_scale_x), + mode=mode, + dtype=mx.float32, + global_scale=global_scale_x, + ) - w = mx.random.normal(shape=w_shape, key=k2) - w_q, scales = mx.quantize(w, mode=mode) - w_hat = mx.dequantize(w_q, scales, mode=mode, dtype=mx.float32) - y_q = mx.qqmm( - x, - w_q, - scales, - mode=mode, - ) - y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2) - self.assertEqual(y_q.shape, y_hat.shape) - self.assertLess((y_q - y_hat).abs().max(), 1e-3) + w = mx.random.normal(shape=w_shape, key=k2) + global_scale_w = mx.max(mx.abs(w)) if has_global_scale else None + w_q, scales = mx.quantize(w, mode=mode, global_scale=global_scale_w) + w_hat = mx.dequantize( + w_q, + scales, + mode=mode, + global_scale=global_scale_w, + dtype=mx.float32, + ) + y_q = mx.qqmm( + x, + w_q, + scales, + mode=mode, + global_scale_x=global_scale_x, + global_scale_w=global_scale_w, + ) + y_hat = x_hat @ mx.swapaxes(w_hat, -1, -2) + self.assertEqual(y_q.shape, y_hat.shape) + self.assertLess((y_q - y_hat).abs().max(), 1e-3) def test_qqmm_metal_global_scale_rejected(self): # Tensor-scale nvfp4 (global_scale_x / global_scale_w) is not