diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index fac31193bd..10adffc23d 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -58,6 +58,12 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/qmm.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/qmm_naive.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/qmm_sm80.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/qmm_sm90.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/qmv.cu + ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm/fp_qmv.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp @@ -67,7 +73,6 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmm) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary) # fp4 is not available on < 12.8 @@ -168,12 +173,6 @@ message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}") set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}") -# Skip Hopper-only kernels when not building for sm90a. -if(("90a" IN_LIST MLX_CUDA_ARCHITECTURES) OR ("90a-real" IN_LIST - MLX_CUDA_ARCHITECTURES)) - target_compile_definitions(mlx PRIVATE MLX_CUDA_SM90A_ENABLED) -endif() - # Search CUDA libs from installed python packages. if(WIN32) # Resolve paths of unfound DLL at runtime. diff --git a/mlx/backend/cuda/cutlass_utils.cuh b/mlx/backend/cuda/cutlass_utils.cuh index 4770f01a2d..ee92de74b4 100644 --- a/mlx/backend/cuda/cutlass_utils.cuh +++ b/mlx/backend/cuda/cutlass_utils.cuh @@ -2,8 +2,10 @@ #pragma once +#include "mlx/backend/cuda/utils.h" #include "mlx/dtype.h" +#include #include #include #include @@ -43,4 +45,24 @@ struct CTypeToCutlassType { template using cutlass_type_t = typename CTypeToCutlassType::type; +// Convert Dtype to CUTLASS C++ types. +inline const char* dtype_to_cutlass_type(const Dtype& dtype) { + if (dtype == float16) { + return "cutlass::half_t"; + } + if (dtype == bfloat16) { + return "cutlass::bfloat16_t"; + } + return dtype_to_cuda_type(dtype); +} + +// Convert cute shape to string. +inline auto cta_tiler_to_string(auto cta_tiler) { + return fmt::format( + "cute::Shape, cute::Int<{}>, cute::Int<{}>>", + int(cute::size<0>(cta_tiler)), + int(cute::size<1>(cta_tiler)), + int(cute::size<2>(cta_tiler))); +} + } // namespace mlx::core diff --git a/mlx/backend/cuda/device/gemm_sm70.cuh b/mlx/backend/cuda/device/gemm_sm70.cuh new file mode 100644 index 0000000000..f66b890750 --- /dev/null +++ b/mlx/backend/cuda/device/gemm_sm70.cuh @@ -0,0 +1,285 @@ +// Copyright © 2026 Apple Inc. + +#include + +// clang-format off + +namespace mlx::core::cu { + +using namespace cute; + +template +struct SharedStorage { + ArrayEngine> A; + ArrayEngine> B; +}; + +template +inline constexpr auto make_smem_layout(TileM bM, TileK bK) { + // TODO: Calculate swizzle based on tile shape. + if constexpr (KMajor) { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout>, + Stride<_8,Stride<_1,_64>>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } else { + auto swizzle = composition(Swizzle<3,3,3>{}, + Layout, Stride<_8,_1>>{}); + return tile_to_shape(swizzle, make_shape(bM, bK)); + } +} + +template +inline constexpr auto make_smem_layouts(CtaTiler cta_tiler) { + // Note: Kernel launcher assumes num_threads being same for all parameters. + auto [bM, bN, bK] = cta_tiler; + auto sA_layout = make_smem_layout(bM, bK); + auto sB_layout = make_smem_layout(bN, bK); + return cute::make_tuple(sA_layout, sB_layout); +} + +template +inline constexpr auto make_tiled_mma(CtaTiler cta_tiler) { + // Note: Kernel launcher assumes num_threads being same for all parameters. + using Atom = cuda::std::conditional_t< + SM80, + cuda::std::conditional_t< + cuda::std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + cuda::std::conditional_t< + cuda::std::is_same_v, + SM80_16x8x16_F32BF16BF16F32_TN, + UniversalFMA + > + >, + UniversalFMA>; + if constexpr (!SM80 || cuda::std::is_same_v) { + return make_tiled_mma(Atom{}, Layout>{}); + } else { + if constexpr (size<0>(cta_tiler) >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } + } +} + +template +inline constexpr auto make_tiled_copy(NumThreads num_threads, TileM bM, TileK bK) { + // TODO: Only do 1-element read for the tile of residue. + auto n_read = Int{}; + auto atom = Copy_Atom>>, T>{}; + if constexpr (KMajor) { + auto k_threads = bK / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(Int{}, k_threads), LayoutRight{}), + make_layout(make_shape(Int<1>{}, n_read))); + } else { + auto m_threads = bM / n_read; + return make_tiled_copy( + atom, + make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), + make_layout(make_shape(n_read, Int<1>{}))); + } +} + +template +CUTE_DEVICE void gemm_sm70_mainloop( + CtaTiler cta_tiler, + TensorA gA, + TensorB gB, + TensorC gC, + int m_max_coord, + int n_max_coord, + int k_residue, + int thread_idx) { + // Get the types of operands. + using Element = typename decltype(gA)::value_type; + + // Shift tensor so we handle residue of K in the 0th tile. + gA = domain_offset(make_coord(0, k_residue, 0), gA); + gB = domain_offset(make_coord(0, k_residue, 0), gB); + + // Define smem layouts. + auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); + + // Shared memory buffer. + extern __shared__ char smem_buf[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(smem_buf); + Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) + + // Define MMA. + auto mma = make_tiled_mma(cta_tiler); + auto num_threads = size(mma); + + // Define copy atoms. + auto [bM, bN, bK] = cta_tiler; + TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); + TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); + + // Partition the copying of A/B/C tiles across the threads. + ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); + Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) + Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); + Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) + + // MMA. + ThrMMA thr_mma = mma.get_slice(thread_idx); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + // Predicates for m/n bounds. + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); // (CPY_N,CPY_K) + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) + Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); // (M,N) + Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) + Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) + CUTE_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; + } + CUTE_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < n_max_coord; + } + + // GMEM => RMEM. + auto fetch_gmem = [&](int tile) { + copy_if(copy_a, tApA, tAgA(_,_,_,tile), tArA); + copy_if(copy_b, tBpB, tBgB(_,_,_,tile), tBrB); + }; + // RMEM => SMEM. + auto store_smem = [&]() { + __syncthreads(); + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + }; + + // Clear the rmem tiles to account for predicated off loads. + clear(tArA); + clear(tBrB); + + // Prefetch first tile. + Tensor tAgA_k = tAgA(_,_,_,0); + Tensor tBgB_k = tBgB(_,_,_,0); + CUTE_UNROLL + for (int k = 0; k < size<2>(tArA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -k_residue) { + copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k)); + } + } + CUTE_UNROLL + for (int k = 0; k < size<2>(tBrB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -k_residue) { + copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k)); + } + } + + // Clear accumulators. + clear(tCrC); + + // Loop over CTA tiles. + auto K_TILE_MAX = size<3>(tAgA); + for (int tile = 0; tile < K_TILE_MAX; ++tile) { + store_smem(); + // Avoid fetching full 0th-tile when there is residue. + if (K_TILE_MAX > 1) { + fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); + } + gemm(mma, tCsA, tCsB, tCrC); + } + + // Epilogue. + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) { + if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { + tCgC(i) = Element(tCrC(i)); + } + } +} + +template +inline constexpr auto make_matrix_stride(int m, int k) { + if constexpr (KMajor) { + return cute::make_stride(k, cute::Int<1>{}, m * k); + } else { + return cute::make_stride(cute::Int<1>{}, m, m * k); + } +} + +template +__global__ +__launch_bounds__(decltype(size(make_tiled_mma(CtaTiler{})))::value) +void gemm_sm70_kernel( + const Element* A, + const Element* B, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + Element* C, + int m, int n, int k, int l) { + int thread_idx = int(threadIdx.x); + int m_coord = int(blockIdx.x); + int n_coord = int(blockIdx.y); + int l_coord = int(blockIdx.z); + + // Define layouts (mixed). + auto dA = make_matrix_stride(m, k); // (dM,dK,dL) + auto dB = make_matrix_stride(n, k); // (dN,dK,dL) + auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) + + // Represent the full tensors. + Tensor mA_mkl = make_tensor(make_gmem_ptr(A), make_shape(m, k, l), dA); // (M,K,L) + Tensor mB_nkl = make_tensor(make_gmem_ptr(B), make_shape(n, k, l), dB); // (N,K,L) + Tensor mC_mnl = make_tensor(make_gmem_ptr(C), make_shape(m, n, l), dC); // (M,N,L) + + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + + // Get batch slice. + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) + Tensor mC = mC_mnl(_,_,l_coord); // (M,N) + + // Get the appropriate blocks for this thread block. + auto cta_tiler = CtaTiler{}; + auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Compute tile residues for predication. + int m_max_coord = m - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord + int n_max_coord = n - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord + int k_residue = k - size<1>(gA) * size<2>(gA); + + gemm_sm70_mainloop( + cta_tiler, + gA, + gB, + gC, + m_max_coord, n_max_coord, k_residue, + thread_idx); +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/qmm_naive.cuh b/mlx/backend/cuda/device/qmm_naive.cuh index f9ee549962..01e5f444d5 100644 --- a/mlx/backend/cuda/device/qmm_naive.cuh +++ b/mlx/backend/cuda/device/qmm_naive.cuh @@ -1,6 +1,7 @@ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/device/cute_dequant.cuh" +#include "mlx/backend/cuda/device/gemm_sm70.cuh" #include @@ -10,83 +11,6 @@ namespace mlx::core::cu { using namespace cute; -template -struct SharedStorage { - ArrayEngine> A; - ArrayEngine> B; -}; - -template -inline constexpr auto make_smem_layout(TileM bM, TileN bK) { - // TODO: Calculate swizzle based on tile shape. - if constexpr (KMajor) { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout>, - Stride<_8,Stride<_1,_64>>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } else { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout, Stride<_1,_64>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } -} - -template -inline constexpr auto make_smem_layouts(CtaTiler cta_tiler) { - // Note: Kernel launcher assumes cosize being same for all KMajor. - auto [bM, bN, bK] = cta_tiler; - auto sA_layout = make_smem_layout(bM, bK); - auto sB_layout = make_smem_layout(bN, bK); - return cute::make_tuple(sA_layout, sB_layout); -} - -template -inline constexpr auto make_tiled_mma(CtaTiler cta_tiler) { - // Note: Kernel launcher assumes num_threads being same for all parameters. - using Atom = cuda::std::conditional_t< - SM80, - cuda::std::conditional_t< - cuda::std::is_same_v, - SM80_16x8x16_F32F16F16F32_TN, - cuda::std::conditional_t< - cuda::std::is_same_v, - SM80_16x8x16_F32BF16BF16F32_TN, - UniversalFMA - > - >, - UniversalFMA>; - if constexpr (!SM80 || cuda::std::is_same_v) { - return make_tiled_mma(Atom{}, Layout>{}); - } else { - if constexpr (size<0>(cta_tiler) >= 32) { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); - } else { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); - } - } -} - -template -inline constexpr auto make_tiled_copy(NumThreads num_threads, TileM bM, TileN bK) { - // TODO: Only do 1-element read for the tile of residue. - auto n_read = Int{}; - auto atom = Copy_Atom>>, T>{}; - if constexpr (KMajor) { - auto k_threads = bK / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(Int{}, k_threads), LayoutRight{}), - make_layout(make_shape(Int<1>{}, n_read))); - } else { - auto m_threads = bM / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), - make_layout(make_shape(n_read, Int<1>{}))); - } -} - template (cta_tiler); + auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); // Shared memory buffer. extern __shared__ char smem_buf[]; @@ -133,13 +57,13 @@ CUTE_DEVICE void qmm_naive_mainloop( Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) // Define MMA. - auto mma = make_tiled_mma(CtaTiler{}); + auto mma = make_tiled_mma(cta_tiler); auto num_threads = size(mma); // Define copy atoms. auto [bM, bN, bK] = cta_tiler; - TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); - TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); + TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); + TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); // Partition the copying of A/B/C tiles across the threads. ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); @@ -277,15 +201,6 @@ CUTE_DEVICE void qmm_naive_mainloop( } } -template -inline constexpr auto make_matrix_stride(int m, int k) { - if constexpr (KMajor) { - return make_stride(k, Int<1>{}, m * k); - } else { - return make_stride(Int<1>{}, m, m * k); - } -} - template inline constexpr auto make_scales_layout(int n, int k, int l) { auto group_size = Int{}; diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh b/mlx/backend/cuda/device/qmm_sm80.cuh similarity index 72% rename from mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh rename to mlx/backend/cuda/device/qmm_sm80.cuh index d07909747b..e6efe04022 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh +++ b/mlx/backend/cuda/device/qmm_sm80.cuh @@ -1,12 +1,12 @@ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/device/cute_dequant.cuh" -#include "mlx/dtype_utils.h" + +#include // clang-format off -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { +namespace mlx::core::cu { using namespace cute; @@ -25,12 +25,15 @@ union SharedStorage { } epilogue; }; -inline constexpr auto make_smem_layouts(auto cta_tiler) { +template +inline constexpr auto make_smem_layouts(CtaTiler cta_tiler) { + // Note: Kernel launcher assumes cosize being same for all KMajor. + auto [bM, bN, bK] = cta_tiler; + // Define the A/B smem layouts (static). auto swizzle_ab = composition(Swizzle<3,3,3>{}, Layout>, Stride<_8,Stride<_1,_64>>>{}); - auto [bM, bN, bK] = cta_tiler; auto bP = Int<3>{}; // pipeline auto sA_layout = tile_to_shape(swizzle_ab, make_shape(bM, bK, bP)); auto sB_layout = tile_to_shape(swizzle_ab, make_shape(bN, bK, bP)); @@ -39,11 +42,25 @@ inline constexpr auto make_smem_layouts(auto cta_tiler) { // TODO: Find a better swizzle. auto sC_layout = tile_to_shape(swizzle_ab, make_shape(bM, bN)); - return std::make_tuple(sA_layout, sB_layout, sC_layout); + return cute::make_tuple(sA_layout, sB_layout, sC_layout); +} + +template +inline constexpr auto make_tiled_mma() { + // Note: Kernel launcher assumes num_threads being same for all parameters. + using Atom = cuda::std::conditional_t< + cuda::std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + SM80_16x8x16_F32BF16BF16F32_TN>; + if constexpr (TileM >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } } -template typename Atom> -inline constexpr auto make_tiled_copy(auto num_threads) { +template typename Atom, typename NumThreads> +inline constexpr auto make_tiled_copy(NumThreads num_threads) { return make_tiled_copy( Copy_Atom>, T>{}, make_layout(make_shape(Int{}, Int<8>{}), LayoutRight{}), @@ -55,8 +72,7 @@ template + typename TensorC> CUTE_DEVICE void qmm_sm80_mainloop( CtaTiler cta_tiler, TensorA gA, @@ -64,13 +80,12 @@ CUTE_DEVICE void qmm_sm80_mainloop( TensorS gS, TensorZ gZ, TensorC gC, - TiledMma mma, int m_max_coord, int thread_idx) { // Get the types of operands. - using Element = decltype(gA)::value_type; - using Quant = decltype(gB)::value_type; - using Scale = decltype(gS)::value_type; + using Element = typename decltype(gA)::value_type; + using Quant = typename decltype(gB)::value_type; + using Scale = typename decltype(gS)::value_type; // Define smem layouts. auto [sA_layout, sB_layout, sC_layout] = make_smem_layouts(cta_tiler); @@ -86,11 +101,14 @@ CUTE_DEVICE void qmm_sm80_mainloop( Tensor sB = make_tensor(make_smem_ptr(smem.mainloop.B.begin()), sB_layout); // (BLK_N,BLK_K) Tensor sC = make_tensor(make_smem_ptr(smem.epilogue.C.begin()), sC_layout); // (BLK_M,BLK_N) + // Define MMA. + auto mma = make_tiled_mma(cta_tiler)>(); + auto num_threads = size(mma); + // Define copy atoms. constexpr int element_bits = sizeof_bits_v; constexpr int quant_bits = sizeof_bits_v; constexpr int qload = 128 / (element_bits / quant_bits); - auto num_threads = size(mma); TiledCopy g2s_copy_a = make_tiled_copy(num_threads); TiledCopy g2s_copy_b = make_tiled_copy(num_threads); TiledCopy s2g_copy_c = make_tiled_copy(num_threads); @@ -254,93 +272,87 @@ CUTE_DEVICE void qmm_sm80_mainloop( copy_if(s2g_copy_c, tCpC, s2g_tCsC, s2g_tCgC); } -inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { +template +inline constexpr auto make_scales_layout(int n, int k, int l) { + auto group_size = Int{}; return make_layout( make_shape(n, make_shape(group_size, k / group_size), l), make_stride(k / group_size, Stride<_0,_1>{}, n * k / group_size)); } -template -inline constexpr auto make_cta_tiler(auto group_size) { - auto bM = Int{}; - auto bN = Int<128>{}; - auto bK = Int{}; - return make_shape(bM, bN, bK); -} - -template -inline constexpr auto make_tiled_mma() { - using Atom = std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32F16F16F32_TN, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32BF16BF16F32_TN, - UniversalFMA>>; - if constexpr (TileM >= 32) { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); - } else { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); +template +__global__ +__launch_bounds__(decltype(size(make_tiled_mma()))::value) +void qmm_sm80_kernel( + const Element* A, + const Quant* B, + const Scale* S, + const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + Element* C, + int m, int n, int k, int l, + bool broadcast_b) { + int thread_idx = int(threadIdx.x); + int m_coord = int(blockIdx.x); + int n_coord = int(blockIdx.y); + int l_coord = int(blockIdx.z); + + // Define layouts (mixed). + auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) + auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL) + auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) + auto S_layout = make_scales_layout(n, k, l); + + // Handle broadcasting. + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(stride(S_layout)) = 0; } -} -} // namespace cutlass_gemm - -// clang-format on - -namespace mlx::core { - -template -inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { - if (dtype == float16) { - f.template operator()(); - } else if (dtype == bfloat16) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); - } -} - -template -inline void dispatch_groups(int group_size, const char* tag, F&& f) { - if (group_size == 32) { - f.template operator()<32>(); - } else if (group_size == 64) { - f.template operator()<64>(); - } else if (group_size == 128) { - f.template operator()<128>(); - } else { - throw std::invalid_argument( - fmt::format("{} Group size {} is not supported.", tag, group_size)); - } -} - -template -inline void dispatch_quant_types( - int bits, - int group_size, - QuantizationMode mode, - const char* tag, - F&& f) { - if (mode == QuantizationMode::Mxfp4) { - f.template operator()(); - } else if (mode == QuantizationMode::Mxfp8) { - f.template operator()(); - } else if (mode == QuantizationMode::Nvfp4) { - f.template operator()(); - } else { - dispatch_groups(group_size, tag, [&]() { - if (bits == 4) { - f.template operator()(); - } else if (bits == 8) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} {}-bit quantization is not supported.", tag, bits)); - } - }); - } + // Represent the full tensors. + Tensor mA_mkl = make_tensor(make_gmem_ptr(A), make_shape(m, k, l), dA); // (M,K,L) + Tensor mB_nkl = make_tensor(make_gmem_ptr(B), make_shape(n, k, l), dB); // (N,K,L) + Tensor mC_mnl = make_tensor(make_gmem_ptr(C), make_shape(m, n, l), dC); // (M,N,L) + + Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) + Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) + + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + + // Get batch slice. + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) + Tensor mC = mC_mnl(_,_,l_coord); // (M,N) + + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + + // Get the appropriate blocks for this thread block. + auto cta_tiler = CtaTiler{}; + auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + + // Compute tile residues for predication. + auto m_max_coord = m - size<0>(gA) * m_coord; // M - BLK_M * m_coord + + qmm_sm80_mainloop( + cta_tiler, + gA, + gB, + gS, + gZ, + gC, + m_max_coord, + thread_idx); } -} // namespace mlx::core +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/device/qmm_sm90.cuh b/mlx/backend/cuda/device/qmm_sm90.cuh new file mode 100644 index 0000000000..cd115bcf6f --- /dev/null +++ b/mlx/backend/cuda/device/qmm_sm90.cuh @@ -0,0 +1,91 @@ +// Copyright © 2026 Apple Inc. + +#if defined(__CUDACC_RTC__) + +#include +#include + +// Some CUTLASS headers use following std functions but we can't use STL when +// compiling with NVRTC. +namespace std { +using cuda::std::is_pointer_v; +using cuda::std::max; +using cuda::std::void_t; +} // namespace std + +// The cutlass/floating_point_nvrtc.h file assumes following constants not +// being defined but they are in the CUDA std headers. +#undef FP_NAN +#undef FP_INFINITE +#undef FP_ZERO +#undef FP_SUBNORMAL +#undef FP_NORMAL + +#endif // defined(__CUDACC_RTC__) + +#include +#include +#include +#include +#include +#include + +namespace mlx::core::cu { + +using namespace cute; + +template +CUTLASS_HOST_DEVICE auto make_qmm_sm90_kernel() { + constexpr int AlignmentA = 128 / sizeof_bits::value; + constexpr int AlignmentB = 128 / sizeof_bits::value; + + using Arch = cutlass::arch::Sm90; + using Accumulator = float; + using ClusterShape = Shape<_1, _1, _1>; + + using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + Arch, + cutlass::arch::OpClassTensorOp, + CtaTiler, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + Accumulator, + Accumulator, + // ElementC: + void, + cutlass::layout::ColumnMajor, + AlignmentA, + // ElementD: + Element, + cutlass::layout::ColumnMajor, + AlignmentA, + cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp; + + // Note that A/B are swapped and transposed to use TMA epilogue. + using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder< + Arch, + cutlass::arch::OpClassTensorOp, + // ElementA: + tuple, + cutlass::layout::RowMajor, + AlignmentB, + // ElementB: + Element, + cutlass::layout::ColumnMajor, + AlignmentA, + Accumulator, + CtaTiler, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename Epilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; + + return cutlass::gemm::kernel:: + GemmUniversal, Mainloop, Epilogue>{}; +} + +template +using qmm_sm90_kernel_t = + decltype(make_qmm_sm90_kernel()); + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/gemms/gather_gemm.cu b/mlx/backend/cuda/gemms/gather_gemm.cu index 478680b744..2bb345852e 100644 --- a/mlx/backend/cuda/gemms/gather_gemm.cu +++ b/mlx/backend/cuda/gemms/gather_gemm.cu @@ -1,362 +1,26 @@ // Copyright © 2026 Apple Inc. -#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/gemm_sm70.cuh" + +#include "mlx/backend/cuda/cutlass_utils.cuh" +#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/dtype_utils.h" -#include - -// clang-format off - -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { - -using namespace cute; - -template -struct SharedStorage { - ArrayEngine> A; - ArrayEngine> B; -}; - -template -inline constexpr auto make_smem_layout(auto bM, auto bK) { - // TODO: Calculate swizzle based on tile shape. - if constexpr (KMajor) { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout>, - Stride<_8,Stride<_1,_64>>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } else { - auto swizzle = composition(Swizzle<3,3,3>{}, - Layout, Stride<_8,_1>>{}); - return tile_to_shape(swizzle, make_shape(bM, bK)); - } -} - -template -inline constexpr auto make_smem_layouts(auto cta_tiler) { - auto [bM, bN, bK] = cta_tiler; - auto sA_layout = make_smem_layout(bM, bK); - auto sB_layout = make_smem_layout(bN, bK); - return std::make_tuple(sA_layout, sB_layout); -} - -template -inline constexpr auto make_tiled_copy(auto num_threads, auto bM, auto bK) { - auto n_read = Int{}; - auto atom = Copy_Atom>>, T>{}; - if constexpr (KMajor) { - auto k_threads = bK / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(Int{}, k_threads), LayoutRight{}), - make_layout(make_shape(Int<1>{}, n_read))); - } else { - auto m_threads = bM / n_read; - return make_tiled_copy( - atom, - make_layout(make_shape(m_threads, Int{}), LayoutLeft{}), - make_layout(make_shape(n_read, Int<1>{}))); - } -} - -template -CUTE_DEVICE void gemm_sm70_mainloop( - CtaTiler cta_tiler, - TensorA gA, - TensorB gB, - TensorC gC, - TiledMma mma, - int m_max_coord, - int n_max_coord, - int k_residue, - int thread_idx) { - // Get the types of operands. - using Element = decltype(gA)::value_type; - - // Shift tensor so we handle residue of K in the 0th tile. - gA = domain_offset(make_coord(0, k_residue, 0), gA); - gB = domain_offset(make_coord(0, k_residue, 0), gB); - - // Define smem layouts. - auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); +#include "cuda_jit_sources.h" - // Shared memory buffer. - extern __shared__ char smem_buf[]; - using SharedStorage = SharedStorage; - SharedStorage& smem = *reinterpret_cast(smem_buf); - Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) - Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) - - // Define copy atoms. - auto num_threads = size(mma); - auto [bM, bN, bK] = cta_tiler; - TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); - TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); - - // Partition the copying of A/B/C tiles across the threads. - ThrCopy thr_copy_a = copy_a.get_slice(thread_idx); - Tensor tAgA = thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) - Tensor tAsA = thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K) - Tensor tArA = make_fragment_like(tAsA); // (ACPY,ACPY_M,ACPY_K) - - ThrCopy thr_copy_b = copy_b.get_slice(thread_idx); - Tensor tBgB = thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBsB = thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K) - Tensor tBrB = make_fragment_like(tBsB); // (BCPY,BCPY_M,BCPY_K) - - // MMA. - ThrMMA thr_mma = mma.get_slice(thread_idx); - Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) - Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) - Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) - Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) - - // Predicates for m/n bounds. - Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); // (CPY_M,CPY_K) - Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); // (CPY_N,CPY_K) - Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) - Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) - Tensor cC = make_identity_tensor(make_shape(size<0>(gC), size<1>(gC))); // (M,N) - Tensor tAcA = thr_copy_a.partition_S(cA); // (CPY,CPY_M,CPY_K) - Tensor tBcB = thr_copy_b.partition_S(cB); // (CPY,CPY_N,CPY_K) - Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) - CUTE_UNROLL - for (int m = 0; m < size<0>(tApA); ++m) { - tApA(m,0) = get<0>(tAcA(0,m,0)) < m_max_coord; - } - CUTE_UNROLL - for (int n = 0; n < size<0>(tBpB); ++n) { - tBpB(n,0) = get<0>(tBcB(0,n,0)) < n_max_coord; - } - - // GMEM => RMEM. - auto fetch_gmem = [&](int tile) { - copy_if(copy_a, tApA, tAgA(_,_,_,tile), tArA); - copy_if(copy_b, tBpB, tBgB(_,_,_,tile), tBrB); - }; - // RMEM => SMEM. - auto store_smem = [&]() { - __syncthreads(); - copy(tArA, tAsA); - copy(tBrB, tBsB); - __syncthreads(); - }; - - // Clear the rmem tiles to account for predicated off loads. - clear(tArA); - clear(tBrB); - - // Prefetch first tile. - Tensor tAgA_k = tAgA(_,_,_,0); - Tensor tBgB_k = tBgB(_,_,_,0); - CUTE_UNROLL - for (int k = 0; k < size<2>(tArA); ++k) { - if (get<1>(tAcA(0,0,k)) >= -k_residue) { - copy_if(copy_a, tApA(_,k), tAgA_k(_,_,k), tArA(_,_,k)); - } - } - CUTE_UNROLL - for (int k = 0; k < size<2>(tBrB); ++k) { - if (get<1>(tBcB(0,0,k)) >= -k_residue) { - copy_if(copy_b, tBpB(_,k), tBgB_k(_,_,k), tBrB(_,_,k)); - } - } - - // Clear accumulators. - clear(tCrC); - - // Loop over CTA tiles. - auto K_TILE_MAX = size<3>(tAgA); - for (int tile = 0; tile < K_TILE_MAX; ++tile) { - store_smem(); - // Avoid fetching full 0th-tile when there is residue. - if (K_TILE_MAX > 1) { - fetch_gmem((tile + 1 < K_TILE_MAX) ? tile + 1 : tile); - } - gemm(mma, tCsA, tCsB, tCrC); - } - - // Epilogue. - CUTE_UNROLL - for (int i = 0; i < size(tCrC); ++i) { - if ((get<0>(tCcC(i)) < m_max_coord) && (get<1>(tCcC(i)) < n_max_coord)) { - tCgC(i) = Element(tCrC(i)); - } - } -} - -template -inline constexpr auto make_matrix_stride(auto m, auto k) { - if constexpr (KMajor) { - return cute::make_stride(k, cute::Int<1>{}, m * k); - } else { - return cute::make_stride(cute::Int<1>{}, m, m * k); - } -} - -template -inline constexpr auto make_tiled_mma(auto cta_tiler) { - using Atom = std::conditional_t< - SM80, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32F16F16F32_TN, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32BF16BF16F32_TN, - UniversalFMA - > - >, - UniversalFMA>; - if constexpr (!SM80 || std::is_same_v) { - return make_tiled_mma(Atom{}, Layout>{}); - } else { - if constexpr (size<0>(cta_tiler) >= 32) { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); - } else { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); - } - } -} - -template -__global__ -__launch_bounds__(decltype(size(TiledMma{}))::value) -void gather_mm_kernel( - ProblemShape shape_MNKL, - CtaTiler cta_tiler, - const Element* A, StrideA dA, - const Element* B, StrideB dB, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices, - Element* C, StrideC dC, - TiledMma mma) { - CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); - CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); - CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); - - int thread_idx = int(threadIdx.x); - int m_coord = int(blockIdx.x); - int n_coord = int(blockIdx.y); - int l_coord = int(blockIdx.z); - - // Represent the full tensors. - Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) - Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) - Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L) - - // For gather, use index lookup for input batch slicing. - uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; - uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; - - // Get batch slice. - Tensor mA = mA_mkl(_,_,a_batch); // (M,K) - Tensor mB = mB_nkl(_,_,b_batch); // (N,K) - Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - - // Get the appropriate blocks for this thread block. - auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) - - // Compute tile residues for predication. - int m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord - int n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord - int k_residue = size<2>(shape_MNKL) - size<1>(gA) * size<2>(gA); - - if (k_residue % 8 == 0) { - gemm_sm70_mainloop( - cta_tiler, - gA, - gB, - gC, - mma, - m_max_coord, n_max_coord, k_residue, - thread_idx); - } else { - gemm_sm70_mainloop( - cta_tiler, - gA, - gB, - gC, - mma, - m_max_coord, n_max_coord, k_residue, - thread_idx); - } -} - -template -void gather_mm( - const Element* A, - const Element* B, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices, - Element* C, - int m, int n, int k, int l, - auto&& launch_kernel) { - // Define shapes (dynamic). - auto shape_MNKL = make_shape(m, n, k, l); // (M,N,K,L) - - // Define layouts (mixed). - auto dA = make_matrix_stride(m, k); // (dM,dK,dL) - auto dB = make_matrix_stride(n, k); // (dN,dK,dL) - auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) - - // Define CTA tile size (static). - auto cta_tiler = make_shape(Int<16>{}, Int<128>{}, Int<64>{}); - - // Define MMA. - auto mma = make_tiled_mma(cta_tiler); - auto num_threads = size(mma); - - // Shared memory size. - auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); - size_t smem_bytes = sizeof(SharedStorage); +namespace mlx::core { - auto* kernel = &gather_mm_kernel< - KMajorA, KMajorB, SM80, Element, - decltype(shape_MNKL), - decltype(cta_tiler), - decltype(dA), - decltype(dB), - decltype(dC), - decltype(mma)>; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); +namespace { - dim3 num_blocks{uint32_t(ceil_div(m, size<0>(cta_tiler))), - uint32_t(ceil_div(n, size<1>(cta_tiler))), - uint32_t(l)}; - dim3 block_dims{uint32_t(num_threads)}; - void* args[] = { - &shape_MNKL, - &cta_tiler, - &A, &dA, - &B, &dB, - &lhs_indices, &rhs_indices, - &C, &dC, - &mma}; - launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); +inline auto make_cta_tiler(int m) { + int tile_m = std::max(16, std::min(256, next_power_of_2(m))); + int tile_n = 128; + int tile_k = 64; + return cute::make_shape(tile_m, tile_n, tile_k); } -} // namespace cutlass_gemm - -// clang-format on - -namespace mlx::core { +} // namespace template inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { @@ -381,45 +45,70 @@ void gather_mm( const array& rhs_indices, array& out, cu::CommandEncoder& encoder) { - bool sm80 = encoder.device().compute_capability_major() >= 8; - int m = out.shape(-2); int n = out.shape(-1); int k = a.shape(-1); int l = out.size() / (m * n); + bool aligned = (k % 8 == 0); + bool sm80 = encoder.device().compute_capability_major() >= 8; + auto cta_tiler = make_cta_tiler(m); + + std::string module_name = fmt::format( + "gemm_sm70_{}{}_{}_{}_m{}_n{}_k{}", + a_transposed ? "n" : "t", + b_transposed ? "n" : "t", + dtype_to_string(out.dtype()), + aligned ? "aligned" : "unaligned", + cute::size<0>(cta_tiler), + cute::size<1>(cta_tiler), + cute::size<2>(cta_tiler)); + + std::string kernel_name = fmt::format( + "mlx::core::cu::gemm_sm70_kernel<{}, {}, {}, {}, {}, {}>", + !a_transposed, + b_transposed, + aligned, + sm80, + dtype_to_cutlass_type(out.dtype()), + cta_tiler_to_string(cta_tiler)); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { + return std::make_tuple( + false, jit_source_gemm_sm70, std::vector{kernel_name}); + }); + encoder.set_input_array(a); encoder.set_input_array(b); encoder.set_input_array(lhs_indices); encoder.set_input_array(rhs_indices); encoder.set_output_array(out); - dispatch_element_types(out.dtype(), "[gather_mm]", [&]() { - dispatch_bool(!a_transposed, [&](auto k_major_a) { - dispatch_bool(b_transposed, [&](auto k_major_b) { - dispatch_bool(sm80, [&](auto sm80) { - cutlass_gemm::gather_mm( - gpu_ptr(a), - gpu_ptr(b), - gpu_ptr(lhs_indices), - gpu_ptr(rhs_indices), - gpu_ptr(out), - m, - n, - k, - l, - [&](auto* kernel, - dim3 num_blocks, - dim3 block_dims, - uint32_t smem_bytes, - void** args) { - encoder.add_kernel_node_raw( - kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); - }); - }); - }); - }); + dim3 num_blocks{ + uint32_t(cute::ceil_div(m, cute::size<0>(cta_tiler))), + uint32_t(cute::ceil_div(n, cute::size<1>(cta_tiler))), + uint32_t(l)}; + dim3 block_dims{uint32_t(cute::size(cu::make_tiled_mma(cta_tiler)))}; + + auto [sA_layout, sB_layout] = cu::make_smem_layouts(cta_tiler); + size_t smem_bytes = + out.itemsize() * (cute::cosize(sA_layout) + cute::cosize(sB_layout)); + + encoder.add_kernel_node_ex( + mod.get_kernel(kernel_name), + num_blocks, + block_dims, + {}, + smem_bytes, + gpu_ptr(a), + gpu_ptr(b), + gpu_ptr(lhs_indices), + gpu_ptr(rhs_indices), + gpu_ptr(out), + m, + n, + k, + l); } } // namespace mlx::core diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 0246eff4ca..3a493fd14e 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -253,9 +253,12 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "complex.cuh", INCLUDE_PREFIX "cute_dequant.cuh", INCLUDE_PREFIX "fp16_math.cuh", + INCLUDE_PREFIX "gemm_sm70.cuh", INCLUDE_PREFIX "hadamard.cuh", INCLUDE_PREFIX "indexing.cuh", INCLUDE_PREFIX "qmm_naive.cuh", + INCLUDE_PREFIX "qmm_sm80.cuh", + INCLUDE_PREFIX "qmm_sm90.cuh", INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "unary_ops.cuh", INCLUDE_PREFIX "ternary_ops.cuh", @@ -272,9 +275,12 @@ constexpr const char* g_headers[] = { jit_source_complex, jit_source_cute_dequant, jit_source_fp16_math, + jit_source_gemm_sm70, jit_source_hadamard, jit_source_indexing, jit_source_qmm_naive, + jit_source_qmm_sm80, + jit_source_qmm_sm90, jit_source_scatter_ops, jit_source_unary_ops, jit_source_ternary_ops, diff --git a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt b/mlx/backend/cuda/quantized/qmm/CMakeLists.txt deleted file mode 100644 index 2fc2ece0b3..0000000000 --- a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -target_sources( - mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu - ${CMAKE_CURRENT_SOURCE_DIR}/qmm_naive.cu - ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu - ${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu) - -foreach(TileN 16 32 64 128 256) - set(OUTPUT_FILE "qmm_sm90_impl_n${TileN}.cu") - configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_sm90.cu" - "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY) - target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}) -endforeach() - -foreach(TileM 16 32 64) - set(OUTPUT_FILE "qmm_sm80_impl_m${TileM}.cu") - configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_sm80.cu" - "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY) - target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}) -endforeach() diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index 403f9f189d..5c1d4f76b4 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -2,6 +2,7 @@ #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" +#include "mlx/backend/cuda/quantized/qmm/qmm_utils.h" #include @@ -16,21 +17,6 @@ inline bool is_last_2_dims_row_contiguous(const array& x) { } // namespace -#if defined(MLX_CUDA_SM90A_ENABLED) -// Defined in qmm_sm90.cu. -template -void qmm_sm90_impl( - const array& x, - const array& w, - const array& scales, - const array& biases, - array& out, - int bits, - int group_size, - cu::CommandEncoder& encoder, - Stream s); -#endif // defined(MLX_CUDA_SM90A_ENABLED) - bool supports_qmm_sm90( const array& x, const array& w, @@ -45,7 +31,10 @@ bool supports_qmm_sm90( if (device.compute_capability_major() != 9) { return false; } - int k = x.shape(-1); + auto [m, n, k, l, broadcast_b] = make_problem_shape(x, w, out); + if ((n * w.itemsize()) % 16 != 0) { // TMA alignment + return false; + } if (k % 64 != 0) { return false; } @@ -60,7 +49,7 @@ bool supports_qmm_sm90( if (!transpose) { return false; } - if (bits % 2 != 0) { + if (bits != 4 && bits != 8) { return false; } if (group_size < k) { @@ -72,54 +61,6 @@ bool supports_qmm_sm90( return true; } -void qmm_sm90( - const array& x, - const array& w, - const array& scales, - const array& biases, - array& out, - int bits, - int group_size, - cu::CommandEncoder& encoder, - Stream s) { -#if defined(MLX_CUDA_SM90A_ENABLED) - auto dispatch = [&]() { - qmm_sm90_impl( - x, w, scales, biases, out, bits, group_size, encoder, s); - }; - int m = out.ndim() > 1 ? out.shape(-2) : 1; - if (m <= 16) { - dispatch.template operator()<16>(); - } else if (m <= 32) { - dispatch.template operator()<32>(); - } else if (m <= 64) { - dispatch.template operator()<64>(); - } else if (m <= 128) { - dispatch.template operator()<128>(); - } else { - dispatch.template operator()<256>(); - } -#else - throw std::runtime_error( - "[quantized_matmul] Hopper-only kernel is not available."); -#endif // defined(MLX_CUDA_SM90A_ENABLED) -} - -// Defined in qmm_sm80.cu. -template -void qmm_sm80_impl( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder); - bool supports_qmm_sm80( const array& x, const array& w, @@ -158,42 +99,6 @@ bool supports_qmm_sm80( return true; } -void qmm_sm80( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder) { - auto dispatch = [&]() { - qmm_sm80_impl( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - out, - bits, - group_size, - mode, - encoder); - }; - int m = out.ndim() > 1 ? out.shape(-2) : 1; - if (m <= 16) { - dispatch.template operator()<16>(); - } else if (m <= 32) { - dispatch.template operator()<32>(); - } else { - dispatch.template operator()<64>(); - } -} - bool supports_qmm_naive( const array& x, const array& w, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu index c01fd2a639..cb47d7f1aa 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu @@ -1,10 +1,12 @@ // Copyright © 2026 Apple Inc. #include "mlx/backend/cuda/device/qmm_naive.cuh" + +#include "mlx/backend/cuda/cutlass_utils.cuh" #include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/dtype_utils.h" +#include "mlx/backend/cuda/quantized/qmm/qmm_utils.h" #include "cuda_jit_sources.h" @@ -20,60 +22,6 @@ inline auto make_cta_tiler(int itemsize, int m, int group_size, bool sm80) { return cute::make_shape(tile_m, tile_n, tile_k); } -inline auto cta_tiler_to_string(auto cta_tiler) { - return fmt::format( - "cute::Shape, cute::Int<{}>, cute::Int<{}>>", - cute::size<0>(cta_tiler), - cute::size<1>(cta_tiler), - cute::size<2>(cta_tiler)); -} - -const char* get_weight_cutlass_type(const Dtype& dtype) { - switch (dtype) { - case float16: - return "cutlass::half_t"; - case bfloat16: - return "cutlass::bfloat16_t"; - case float32: - return "float"; - default: - throw std::invalid_argument( - fmt::format( - "[quantized_matmul] Unsupported dtype: {}.", - dtype_to_string(dtype))); - } -} - -inline std::tuple -get_quant_cutlass_types(const char* ctype_x, int bits, QuantizationMode mode) { - if (mode == QuantizationMode::Mxfp4) { - return {"cutlass::float_e2m1_t", "cutlass::float_ue8m0_t"}; - } else if (mode == QuantizationMode::Mxfp8) { - return {"cutlass::float_e4m3_t", "cutlass::float_ue8m0_t"}; - } else if (mode == QuantizationMode::Nvfp4) { - return {"cutlass::float_e2m1_t", "cutlass::float_e4m3_t"}; - } else { - if (bits == 2) { - return {"cutlass::uint2b_t", ctype_x}; - } else if (bits == 3) { - return {"cutlass::uint3b_t", ctype_x}; - } else if (bits == 4) { - return {"cutlass::uint4b_t", ctype_x}; - } else if (bits == 5) { - return {"cutlass::uint5b_t", ctype_x}; - } else if (bits == 6) { - return {"cutlass::uint6b_t", ctype_x}; - } else if (bits == 8) { - return {"uint8_t", ctype_x}; - } else { - throw std::invalid_argument( - fmt::format( - "[quantized_matmul] {}-bit quantization is not supported.", - bits)); - } - } -} - } // namespace void qmm_naive( @@ -89,29 +37,22 @@ void qmm_naive( int group_size, QuantizationMode mode, cu::CommandEncoder& encoder) { - int m = out.ndim() > 1 ? out.shape(-2) : 1; - int n = out.shape(-1); - int k = x.shape(-1); - int l = out.size() / (m * n); - bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); - + auto [m, n, k, l, broadcast_b] = make_problem_shape(x, w, out); bool sm80 = encoder.device().compute_capability_major() >= 8; auto cta_tiler = make_cta_tiler(x.itemsize(), m, group_size, sm80); bool has_k_residue = (k % cute::size<2>(cta_tiler)) != 0; std::string module_name = fmt::format( - "qmm_naive_{}_{}_{}_m{}_b{}_g{}_{}", - dtype_to_string(x.dtype()), - transpose ? "k" : "n", + "qmm_naive_t{}_{}_{}_m{}_b{}_g{}_{}", + transpose ? "n" : "t", has_k_residue ? "residue" : "aligned", + dtype_to_string(x.dtype()), cute::size<0>(cta_tiler), bits, group_size, quantization_mode_to_string(mode)); - auto ctype_x = get_weight_cutlass_type(x.dtype()); - auto [ctype_q, ctype_s] = get_quant_cutlass_types(ctype_x, bits, mode); - + auto [ctype_x, ctype_q, ctype_s] = get_qmm_cutlass_types(x, bits, mode); std::string kernel_name = fmt::format( "mlx::core::cu::qmm_naive_kernel<{}, {}, {}, {}, {}, {}, {}, {}>", group_size, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu index 7f76c8cc04..9818239561 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cu @@ -1,163 +1,29 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/device/qmm_sm80.cuh" + +#include "mlx/backend/cuda/cutlass_utils.cuh" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh" - -// clang-format off - -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { - -using namespace cute; - -template -__global__ -__launch_bounds__(decltype(size(TiledMma{}))::value) -void qmm_sm80_kernel( - ProblemShape shape_MNKL, CtaTiler cta_tiler, - const Element* A, StrideA dA, - const Quant* B, StrideB dB, - const Scale* S, const Element* Z, LayoutS S_layout, - const uint32_t* lhs_indices, const uint32_t* rhs_indices, - Element* C, StrideC dC, - TiledMma mma) { - CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); - CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); - CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); - - int thread_idx = int(threadIdx.x); - int m_coord = int(blockIdx.x); - int n_coord = int(blockIdx.y); - int l_coord = int(blockIdx.z); - - // For gather, use index lookup for input batch slicing. - uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; - uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; - - // Represent the full tensors. - Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) - Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) - Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L) - - Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) - Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) - - // Get batch slice. - Tensor mA = mA_mkl(_,_,a_batch); // (M,K) - Tensor mB = mB_nkl(_,_,b_batch); // (N,K) - Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - - Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) - Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) - - // Get the appropriate blocks for this thread block. - auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) - - Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - Tensor gZ = local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - - // Compute tile residues for predication. - auto m_max_coord = size<0>(shape_MNKL) - size<0>(gA) * m_coord; // M - BLK_M * m_coord - - qmm_sm80_mainloop( - cta_tiler, - gA, - gB, - gS, - gZ, - gC, - mma, - m_max_coord, - thread_idx); -} +#include "mlx/backend/cuda/quantized/qmm/qmm_utils.h" -template -void qmm_sm80( - const Element* A, - const Quant* B, - const Scale* S, - const Element* Z, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices, - Element* C, - int m, int n, int k, int l, - bool broadcast_b, - auto group_size, - auto&& launch_kernel) { - // Define shapes (dynamic). - auto shape_MNKL = make_shape(m, n, k, l); // (M,N,K,L) - - // Define layouts (mixed). - auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) - auto dB = make_stride(k, Int<1>{}, n * k); // (dN,dK,dL) - auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) - auto S_layout = make_scales_layout(n, k, l, group_size); - - // Handle broadcasting. - if (broadcast_b) { - get<2>(dB) = 0; - get<2>(stride(S_layout)) = 0; - } +#include "cuda_jit_sources.h" - // Define CTA tile sizes (static). - auto cta_tiler = make_cta_tiler(group_size); - - // Define MMA. - TiledMMA mma = make_tiled_mma(); - auto num_threads = size(mma); - - // Shared memory size. - auto [sA_layout, sB_layout, sC_layout] = make_smem_layouts(cta_tiler); - size_t smem_bytes = sizeof(SharedStorage); - - auto* kernel = &qmm_sm80_kernel< - Element, Quant, Scale, - decltype(shape_MNKL), - decltype(cta_tiler), - decltype(dA), - decltype(dB), - decltype(S_layout), - decltype(dC), - decltype(mma)>; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - - dim3 num_blocks{uint32_t(ceil_div(m, size<0>(cta_tiler))), - uint32_t(ceil_div(n, size<1>(cta_tiler))), - uint32_t(l)}; - dim3 block_dims{uint32_t(num_threads)}; - void* args[] = { - &shape_MNKL, &cta_tiler, - &A, &dA, - &B, &dB, - &S, &Z, &S_layout, - &lhs_indices, &rhs_indices, - &C, &dC, - &mma}; - launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); -} +namespace mlx::core { -} // namespace cutlass_gemm +namespace { -// clang-format on +inline auto make_cta_tiler(int m, int group_size) { + int tile_m = std::max(16, std::min(64, next_power_of_2(m))); + int tile_n = 128; + int tile_k = std::max(64, group_size); + return cute::make_shape(tile_m, tile_n, tile_k); +} -namespace mlx::core { +} // namespace -template -void qmm_sm80_impl( +void qmm_sm80( const array& x, const array& w, const array& scales, @@ -169,72 +35,82 @@ void qmm_sm80_impl( int group_size, QuantizationMode mode, cu::CommandEncoder& encoder) { - const char* tag = "[quantized_matmul]"; - int m = out.ndim() > 1 ? out.shape(-2) : 1; - int n = out.shape(-1); - int k = x.shape(-1); - int l = out.size() / (m * n); - bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); - - dispatch_element_types(out.dtype(), tag, [&]() { - dispatch_quant_types( - bits, - group_size, - mode, - tag, - [&]() { - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - if (biases) { - encoder.set_input_array(*biases); - } - if (lhs_indices) { - encoder.set_input_array(*lhs_indices); - } - if (rhs_indices) { - encoder.set_input_array(*rhs_indices); - } - encoder.set_output_array(out); - cutlass_gemm::qmm_sm80( - gpu_ptr(x), - gpu_ptr(w), - gpu_ptr(scales), - biases ? gpu_ptr(*biases) : nullptr, - lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, - rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, - gpu_ptr(out), - m, - n, - k, - l, - broadcast_b, - cute::Int{}, - [&](auto* kernel, - dim3 num_blocks, - dim3 block_dims, - size_t smem_bytes, - void** args) { - encoder.add_kernel_node_raw( - kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); - }); + auto [m, n, k, l, broadcast_b] = make_problem_shape(x, w, out); + auto cta_tiler = make_cta_tiler(m, group_size); + + std::string module_name = fmt::format( + "qmm_sm80_tn_{}_m{}_b{}_g{}_{}", + dtype_to_string(x.dtype()), + cute::size<0>(cta_tiler), + bits, + group_size, + quantization_mode_to_string(mode)); + + auto [ctype_x, ctype_q, ctype_s] = get_qmm_cutlass_types(x, bits, mode); + std::string kernel_name = fmt::format( + "mlx::core::cu::qmm_sm80_kernel<{}, {}, {}, {}, {}>", + group_size, + ctype_x, + ctype_q, + ctype_s, + cta_tiler_to_string(cta_tiler)); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { + return std::make_tuple( + false, jit_source_qmm_sm80, std::vector{kernel_name}); }); -} -// clang-format off -template void qmm_sm80_impl<@TileM@>( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder); -// clang-format on + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } + encoder.set_output_array(out); + + dim3 num_blocks{ + uint32_t(cute::ceil_div(m, cute::size<0>(cta_tiler))), + uint32_t(cute::ceil_div(n, cute::size<1>(cta_tiler))), + uint32_t(l)}; + dim3 block_dims{uint32_t(cute::size(cu::make_tiled_mma()))}; + + auto [sA_layout, sB_layout, sC_layout] = cu::make_smem_layouts(cta_tiler); + size_t smem_bytes = std::max( + cute::cosize(sA_layout) * x.itemsize() + + cute::cosize(sB_layout) * bits / 8, + cute::cosize(sC_layout) * x.itemsize()); + + auto kernel = mod.get_kernel(kernel_name, [&](CUfunction kernel) { + if (smem_bytes > 48000) { + cuFuncSetAttribute( + kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_bytes); + } + }); + + encoder.add_kernel_node_ex( + kernel, + num_blocks, + block_dims, + {}, + smem_bytes, + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, + gpu_ptr(out), + m, + n, + k, + l, + broadcast_b); +} } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm90.cu b/mlx/backend/cuda/quantized/qmm/qmm_sm90.cu index e9425d8392..7bc5d83437 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_sm90.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_sm90.cu @@ -1,131 +1,20 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/device/qmm_sm90.cuh" + #include "mlx/backend/cuda/cutlass_utils.cuh" -#include "mlx/backend/cuda/quantized/quantized_utils.h" +#include "mlx/backend/cuda/jit_module.h" +#include "mlx/backend/cuda/quantized/qmm/qmm.h" +#include "mlx/backend/cuda/quantized/qmm/qmm_utils.h" #include "mlx/backend/gpu/copy.h" -#include "mlx/dtype_utils.h" - -#include -#include -#include -#include -#include -#include -#if defined(MLX_CUDA_SM90A_ENABLED) +#include "cuda_jit_sources.h" -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { +namespace mlx::core { using namespace cute; -template < - int TileN = 16, - typename Element, - typename Quant, - typename GroupSize, - typename F> -void qmm_sm90( - const Element* A, - const Quant* B, - const Element* S, - const Element* Z, - Element* D, - int64_t m, - int64_t n, - int64_t k, - int64_t l, - bool broadcast_b, - GroupSize group_size, - F&& launch_kernel) { - constexpr int kAlignmentA = 128 / sizeof_bits::value; - constexpr int kAlignmentB = 128 / sizeof_bits::value; - constexpr int kTileShapeK = - std::max(64, 128 * 8 / sizeof_bits::value); - static_assert(group_size % kTileShapeK == 0); - - using Arch = cutlass::arch::Sm90; - using Accumulator = float; - using TileShape = Shape<_128, Int, Int>; - using ClusterShape = Shape, _1, _1>; - - using Epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - Arch, - cutlass::arch::OpClassTensorOp, - TileShape, - ClusterShape, - cutlass::epilogue::collective::EpilogueTileAuto, - Accumulator, - Accumulator, - // ElementC: - void, - cutlass::layout::ColumnMajor, - kAlignmentA, - // ElementD: - Element, - cutlass::layout::ColumnMajor, - kAlignmentA, - cutlass::epilogue::TmaWarpSpecializedCooperative>::CollectiveOp; - - // Note that A/B are swapped and transposed to use TMA epilogue. - using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder< - Arch, - cutlass::arch::OpClassTensorOp, - // ElementA: - tuple, - cutlass::layout::RowMajor, - kAlignmentB, - // ElementB: - Element, - cutlass::layout::ColumnMajor, - kAlignmentA, - Accumulator, - TileShape, - ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename Epilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedCooperative>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel:: - GemmUniversal, Mainloop, Epilogue>; - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - - auto dA = make_stride(k, Int<1>{}, m * k); - auto dB = make_stride(k, Int<1>{}, n * k); - auto dS = make_stride(Int<1>{}, n, n * k / group_size); - auto dD = make_stride(Int<1>{}, n, m * n); - if (broadcast_b) { - get<2>(dB) = 0; - get<2>(dS) = 0; - } - - Gemm gemm; - typename Gemm::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGemm, - {int(n), int(m), int(k), int(l)}, - {B, dB, A, dA, S, dS, group_size, Z}, - {{1.f, 0.f}, D, dD, D, dD}}; - - CHECK_CUTLASS_ERROR(gemm.can_implement(args)); - CHECK_CUTLASS_ERROR(gemm.initialize(args, nullptr)); - - auto* kernel = &cutlass::device_kernel; - void* kernel_params[] = {const_cast(&gemm.params())}; - auto cluster = ClusterShape{}; - launch_kernel( - reinterpret_cast(kernel), - gemm.get_grid_shape(gemm.params()), - GemmKernel::get_block_shape(), - {static_cast(get<0>(cluster)), - static_cast(get<1>(cluster)), - static_cast(get<2>(cluster))}, - GemmKernel::SharedStorageSize, - kernel_params); -} - -} // namespace cutlass_gemm - -namespace mlx::core { +namespace { inline array transpose_last_2_dims( const array& x, @@ -168,17 +57,61 @@ inline void dispatch_quant_types(int bits, const char* tag, F&& f) { template inline void dispatch_groups(int group_size, const char* tag, F&& f) { if (group_size == 64) { - f(cute::Int<64>{}); + f.template operator()<64>(); } else if (group_size == 128) { - f(cute::Int<128>{}); + f.template operator()<128>(); } else { throw std::invalid_argument( fmt::format("{} Group size {} is not supported.", tag, group_size)); } } -template -void qmm_sm90_impl( +template +inline void dispatch_tile(int m, F&& f) { + if (m <= 16) { + f.template operator()<16>(); + } else if (m <= 32) { + f.template operator()<32>(); + } else if (m <= 64) { + f.template operator()<64>(); + } else if (m <= 128) { + f.template operator()<128>(); + } else { + f.template operator()<256>(); + } +} + +template +inline void dispatch_gemm( + const array& x, + int n, + int bits, + int group_size, + const char* tag, + F&& f) { + dispatch_element_types(x.dtype(), tag, [&]() { + dispatch_tile(n, [&]() { + dispatch_quant_types(bits, tag, [&]() { + dispatch_groups(group_size, tag, [&]() { + auto cta_tiler = make_shape( + Int<128>{}, + Int{}, + Int)>{}); + auto gemm = cu::make_qmm_sm90_kernel< + GroupSize, + Element, + Quant, + decltype(cta_tiler)>(); + f(cta_tiler, gemm); + }); + }); + }); + }); +} + +} // namespace + +void qmm_sm90( const array& x, const array& w, const array& scales_, @@ -189,68 +122,86 @@ void qmm_sm90_impl( cu::CommandEncoder& encoder, Stream s) { const char* tag = "[quantized_matmul]"; - int m = out.ndim() > 1 ? out.shape(-2) : 1; - int n = out.shape(-1); - int k = x.shape(-1); - int l = out.size() / (m * n); - bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); + auto [m, n, k, l, broadcast_b] = make_problem_shape(x, w, out); + + auto dA = make_stride(int64_t(k), Int<1>{}, int64_t(m * k)); + auto dB = make_stride(int64_t(k), Int<1>{}, int64_t(n * k)); + auto dS = make_stride(Int<1>{}, int64_t(n), int64_t(n * k / group_size)); + auto dD = make_stride(Int<1>{}, int64_t(n), int64_t(m * n)); + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(dS) = 0; + } // FIXME: Copy happens for every call. array scales = transpose_last_2_dims(scales_, encoder, s); array biases = transpose_last_2_dims(biases_, encoder, s); - dispatch_element_types(out.dtype(), tag, [&]() { - dispatch_quant_types(bits, tag, [&]() { - dispatch_groups(group_size, tag, [&](auto group_size) { - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - encoder.set_input_array(biases); - encoder.set_output_array(out); - cutlass_gemm::qmm_sm90( - gpu_ptr(x), - gpu_ptr(w), - gpu_ptr(scales), - gpu_ptr(biases), - gpu_ptr(out), - m, - n, - k, - l, - broadcast_b, - group_size, - [&](auto* kernel, - dim3 num_blocks, - dim3 block_dims, - dim3 cluster_shape, - uint32_t smem_bytes, - void** args) { - encoder.add_kernel_node_raw( - kernel, - num_blocks, - block_dims, - cluster_shape, - smem_bytes, - args); - }); - }); + dispatch_gemm(x, n, bits, group_size, tag, [&](auto cta_tiler, auto gemm) { + // JIT compilation. + std::string module_name = fmt::format( + "qmm_sm90_tn_{}_n{}_b{}_g{}_affine", + dtype_to_string(x.dtype()), + int(size<1>(cta_tiler)), + bits, + group_size); + + auto [ctype_x, ctype_q, ctype_s] = get_qmm_cutlass_types(x, bits); + std::string kernel_name = fmt::format( + "cutlass::device_kernel>", + group_size, + ctype_x, + ctype_q, + cta_tiler_to_string(cta_tiler)); + + cu::JitModule& mod = + cu::get_jit_module(encoder.device(), module_name, [&]() { + return std::make_tuple( + false, jit_source_qmm_sm90, std::vector{kernel_name}); + }); + + // Prepare kernel args. + using Gemm = decltype(gemm); + + const auto* A = gpu_ptr(x); + const auto* B = gpu_ptr(w); + const auto* S = + gpu_ptr(scales); + const auto* Z = + gpu_ptr(biases); + auto* D = gpu_ptr(out); + + auto params = Gemm::to_underlying_arguments( + {cutlass::gemm::GemmUniversalMode::kGemm, + {n, m, k, l}, + {B, dB, A, dA, S, dS, group_size, Z}, + {{1.f, 0.f}, D, dD, D, dD}}, + nullptr); + + size_t smem_bytes = Gemm::SharedStorageSize; + auto kernel = mod.get_kernel(kernel_name, [&](CUfunction kernel) { + if (smem_bytes > 48000) { + cuFuncSetAttribute( + kernel, + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + smem_bytes); + } }); + + // Append to CUDA graph. + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + encoder.set_input_array(biases); + encoder.set_output_array(out); + encoder.add_kernel_node_ex( + kernel, + Gemm::get_grid_shape(params), + Gemm::get_block_shape(), + {}, + smem_bytes, + params); }); } -// clang-format off -template void qmm_sm90_impl<@TileN@>( - const array& x, - const array& w, - const array& scales, - const array& biases, - array& out, - int bits, - int group_size, - cu::CommandEncoder& encoder, - Stream s); -// clang-format on - } // namespace mlx::core - -#endif // defined(MLX_CUDA_SM90A_ENABLED) diff --git a/mlx/backend/cuda/quantized/qmm/qmm_utils.h b/mlx/backend/cuda/quantized/qmm/qmm_utils.h new file mode 100644 index 0000000000..8e54e2e053 --- /dev/null +++ b/mlx/backend/cuda/quantized/qmm/qmm_utils.h @@ -0,0 +1,72 @@ +// Copyright © 2026 Apple Inc. + +#include "mlx/dtype_utils.h" + +namespace mlx::core { + +inline auto +make_problem_shape(const array& x, const array& w, const array& out) { + int m = out.ndim() > 1 ? out.shape(-2) : 1; + int n = out.shape(-1); + int k = x.shape(-1); + int l = out.size() / (m * n); + bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); + return std::make_tuple(m, n, k, l, broadcast_b); +} + +inline const char* get_weight_cutlass_type(const Dtype& dtype) { + switch (dtype) { + case float16: + return "cutlass::half_t"; + case bfloat16: + return "cutlass::bfloat16_t"; + case float32: + return "float"; + default: + throw std::invalid_argument( + fmt::format( + "[quantized_matmul] Unsupported dtype: {}.", + dtype_to_string(dtype))); + } +} + +inline std::tuple +get_quant_cutlass_types(const char* ctype_x, int bits, QuantizationMode mode) { + if (mode == QuantizationMode::Mxfp4) { + return {"cutlass::float_e2m1_t", "cutlass::float_ue8m0_t"}; + } else if (mode == QuantizationMode::Mxfp8) { + return {"cutlass::float_e4m3_t", "cutlass::float_ue8m0_t"}; + } else if (mode == QuantizationMode::Nvfp4) { + return {"cutlass::float_e2m1_t", "cutlass::float_e4m3_t"}; + } else { + if (bits == 2) { + return {"cutlass::uint2b_t", ctype_x}; + } else if (bits == 3) { + return {"cutlass::uint3b_t", ctype_x}; + } else if (bits == 4) { + return {"cutlass::uint4b_t", ctype_x}; + } else if (bits == 5) { + return {"cutlass::uint5b_t", ctype_x}; + } else if (bits == 6) { + return {"cutlass::uint6b_t", ctype_x}; + } else if (bits == 8) { + return {"uint8_t", ctype_x}; + } else { + throw std::invalid_argument( + fmt::format( + "[quantized_matmul] {}-bit quantization is not supported.", + bits)); + } + } +} + +inline std::tuple get_qmm_cutlass_types( + const array& x, + int bits, + QuantizationMode mode = QuantizationMode::Affine) { + auto ctype_x = get_weight_cutlass_type(x.dtype()); + auto [ctype_q, ctype_s] = get_quant_cutlass_types(ctype_x, bits, mode); + return {ctype_x, ctype_q, ctype_s}; +} + +} // namespace mlx::core diff --git a/python/tests/test_quantized.py b/python/tests/test_quantized.py index 8f199dea14..70bae250e5 100644 --- a/python/tests/test_quantized.py +++ b/python/tests/test_quantized.py @@ -369,8 +369,8 @@ def test_qmv(self): with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): x_shape = (3, 1, N) if B == 0 else (B, 1, N) w_shape = (M, N) if B == 0 else (B, M, N) - x = mx.random.normal(shape=x_shape, key=k1) - w = mx.random.normal(shape=w_shape, key=k2) + x = mx.random.normal(shape=x_shape, key=k1) / N**0.5 + w = mx.random.normal(shape=w_shape, key=k2) / N**0.5 w_q, scales, biases = mx.quantize(w, group_size, bits) w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) y_q = mx.quantized_matmul(