diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 8606b23b2d..1c29464d51 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -99,7 +99,12 @@ __global__ void rope_single_freqs( in, out, *offset, inv_freq, scale, stride, pos, dims); } -template +template < + typename T, + bool traditional, + bool forward, + bool hs_transpose, + int N = 4> __device__ void rope_impl( const T* in, T* out, @@ -124,21 +129,32 @@ __device__ void rope_impl( float costheta = cos(theta); float sintheta = sin(theta); + size_t in_batch_head; + size_t out_batch_head; + if (hs_transpose) { + int64_t in_batch_stride = static_cast(dims.y) * strides[1]; + int64_t out_batch_stride = static_cast(dims.y) * out_strides[1]; + in_batch_head = batch_idx * in_batch_stride + head_idx * strides[0]; + out_batch_head = batch_idx * out_batch_stride + head_idx * out_strides[0]; + } else { + in_batch_head = mat_idx * strides[0]; + out_batch_head = mat_idx * out_strides[0]; + } + // Compute the input and output indices size_t in_index_1, in_index_2; size_t out_index_1, out_index_2; if (traditional) { - out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; - out_index_2 = out_index_1 + 1; - in_index_1 = - 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + out_index_1 = + 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head; + out_index_2 = out_index_1 + out_strides[2]; + in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + strides[2]; } else { - out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; + out_index_1 = + pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head; out_index_2 = out_index_1 + dims.x * out_strides[2]; - in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + dims.x * strides[2]; } for (int i = 0; i < N && head_idx + i < n_head; ++i) { @@ -163,7 +179,7 @@ __device__ void rope_impl( } } -template +template __global__ void rope( const T* in, T* out, @@ -185,7 +201,7 @@ __global__ void rope( float d = static_cast(pos.x) / static_cast(dims.x); float inv_freq = exp2(-d * base); - rope_impl( + rope_impl( in, out, offset, @@ -199,7 +215,7 @@ __global__ void rope( dims); } -template +template __global__ void rope_freqs( const T* in, T* out, @@ -222,7 +238,7 @@ __global__ void rope_freqs( } float inv_freq = 1.0 / freqs[freq_stride * pos.x]; - rope_impl( + rope_impl( in, out, offset, @@ -258,6 +274,7 @@ void RoPE::eval_gpu( cuda::std::array strides; cuda::std::array out_strides; bool donated = false; + bool head_seq_transpose = false; int ndim = in.ndim(); int B = in.shape(0); @@ -303,6 +320,27 @@ void RoPE::eval_gpu( strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; + } else if ( + ndim == 4 && + // batch dim is regularly strided + in.strides()[0] == static_cast(T) * N * D && + // sequence and head dimensions are transposed (x.swapaxes(1, 2)) + in.strides()[1] == D && in.strides()[2] == static_cast(N) * D && + in.strides()[3] == 1) { + head_seq_transpose = true; + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data( + cu::malloc_async(in.data_size() * in.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } + strides[0] = in.strides()[1]; + strides[1] = in.strides()[2]; + strides[2] = in.strides()[3]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated @@ -312,9 +350,15 @@ void RoPE::eval_gpu( strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } - out_strides[0] = mat_size; - out_strides[1] = out.strides()[ndim - 2]; - out_strides[2] = out.strides()[ndim - 1]; + if (head_seq_transpose) { + out_strides[0] = strides[0]; + out_strides[1] = strides[1]; + out_strides[2] = strides[2]; + } else { + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + } // Some flags to help us dispatch below bool single = in.flags().row_contiguous && B == 1 && T == 1; @@ -327,94 +371,103 @@ void RoPE::eval_gpu( } encoder.set_output_array(out); dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + using DataType = cuda_type_t; dispatch_bool(traditional_, [&](auto traditional) { dispatch_bool(forward_, [&](auto forward) { - using DataType = cuda_type_t; - if (single && !with_freqs) { - auto kernel = - cu::rope_single; - uint2 dims = make_uint2(dims_ / 2, N); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - scale_, - std::log2(base_), - mat_size, - dims); - } else if (single) { - auto kernel = - cu::rope_single_freqs; - uint2 dims = make_uint2(dims_ / 2, N); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - gpu_ptr(inputs[2]), - scale_, - mat_size, - dims, - inputs[2].strides(0)); - } else if (with_freqs) { - auto kernel = - cu::rope_freqs; - int n_per_thread = 4; - uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); - uint3 dims = make_uint3(dims_ / 2, T, dimz); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - int64_t offset_stride = 0; - if (inputs[1].ndim() > 0) { - offset_stride = inputs[1].strides()[0]; - } - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - gpu_ptr(inputs[2]), - scale_, - std::log2(base_), - strides, - out_strides, - offset_stride, - N, - dims, - inputs[2].strides(0)); - } else { - auto kernel = cu::rope; - int n_per_thread = 4; - uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); - uint3 dims = make_uint3(dims_ / 2, T, dimz); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - int64_t offset_stride = 0; - if (inputs[1].ndim() > 0) { - offset_stride = inputs[1].strides()[0]; + dispatch_bool(head_seq_transpose, [&](auto hs_transpose) { + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, N); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = cu:: + rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, N); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = cu::rope_freqs< + DataType, + traditional.value, + forward.value, + hs_transpose.value>; + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims = make_uint3(dims_ / 2, T, dimz); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), + scale_, + std::log2(base_), + strides, + out_strides, + offset_stride, + N, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope< + DataType, + traditional.value, + forward.value, + hs_transpose.value>; + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims = make_uint3(dims_ / 2, T, dimz); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + scale_, + std::log2(base_), + strides, + out_strides, + offset_stride, + N, + dims); } - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - scale_, - std::log2(base_), - strides, - out_strides, - offset_stride, - N, - dims); - } + }); }); }); });