From 7ed686c05573148e473cb62c3ed6860e33fabafc Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 15 Jun 2026 07:45:13 -0700 Subject: [PATCH 1/3] rope with out copy for strided input --- benchmarks/python/rope_bench.py | 17 +++ mlx/backend/cuda/rope.cu | 222 +++++++++++++++++++------------- 2 files changed, 146 insertions(+), 93 deletions(-) diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py index 35479c0b1a..b35fe902fd 100644 --- a/benchmarks/python/rope_bench.py +++ b/benchmarks/python/rope_bench.py @@ -31,5 +31,22 @@ def rope_mat(x): time_fn(rope_mat, x) +def time_rope_hs_transposed(): + rope = nn.RoPE(128) + + # matrix + x = mx.random.uniform(shape=(8, 8192, 8, 128)).astype(mx.float16) + mx.eval(x) + x = x.transpose(0, 2, 1, 3) + + def rope_transposed(x): + for _ in range(32): + x = rope(x) + return x + + time_fn(rope_transposed, x) + + if __name__ == "__main__": time_rope() + time_rope_hs_transposed() diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 8606b23b2d..2c27ab1846 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -99,7 +99,12 @@ __global__ void rope_single_freqs( in, out, *offset, inv_freq, scale, stride, pos, dims); } -template +template < + typename T, + bool traditional, + bool forward, + bool hs_transpose, + int N = 4> __device__ void rope_impl( const T* in, T* out, @@ -124,6 +129,16 @@ __device__ void rope_impl( float costheta = cos(theta); float sintheta = sin(theta); + // For the swapaxes(1,2) layout the (batch, head) pair is not a single + // linear stride: batch_stride = T * seq_stride, head_stride = strides[0]. + size_t in_batch_head; + if (hs_transpose) { + int64_t batch_stride = static_cast(dims.y) * strides[1]; + in_batch_head = batch_idx * batch_stride + head_idx * strides[0]; + } else { + in_batch_head = mat_idx * strides[0]; + } + // Compute the input and output indices size_t in_index_1, in_index_2; size_t out_index_1, out_index_2; @@ -131,14 +146,13 @@ __device__ void rope_impl( out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + mat_idx * out_strides[0]; out_index_2 = out_index_1 + 1; - in_index_1 = - 2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + strides[2]; } else { out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + mat_idx * out_strides[0]; out_index_2 = out_index_1 + dims.x * out_strides[2]; - in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0]; + in_index_1 = pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + dims.x * strides[2]; } for (int i = 0; i < N && head_idx + i < n_head; ++i) { @@ -163,7 +177,7 @@ __device__ void rope_impl( } } -template +template __global__ void rope( const T* in, T* out, @@ -185,7 +199,7 @@ __global__ void rope( float d = static_cast(pos.x) / static_cast(dims.x); float inv_freq = exp2(-d * base); - rope_impl( + rope_impl( in, out, offset, @@ -199,7 +213,7 @@ __global__ void rope( dims); } -template +template __global__ void rope_freqs( const T* in, T* out, @@ -222,7 +236,7 @@ __global__ void rope_freqs( } float inv_freq = 1.0 / freqs[freq_stride * pos.x]; - rope_impl( + rope_impl( in, out, offset, @@ -258,6 +272,7 @@ void RoPE::eval_gpu( cuda::std::array strides; cuda::std::array out_strides; bool donated = false; + bool head_seq_transpose = false; int ndim = in.ndim(); int B = in.shape(0); @@ -303,6 +318,18 @@ void RoPE::eval_gpu( strides[0] = in.strides()[ndim - 3]; strides[1] = in.strides()[ndim - 2]; strides[2] = in.strides()[ndim - 1]; + } else if ( + ndim == 4 && + // batch dim is regularly strided + in.strides()[0] == static_cast(T) * N * D && + // sequence and head dimensions are transposed (x.swapaxes(1, 2)) + in.strides()[1] == D && in.strides()[2] == static_cast(N) * D && + in.strides()[3] == 1) { + head_seq_transpose = true; + out.set_data(cu::malloc_async(out.nbytes(), encoder)); + strides[0] = in.strides()[1]; + strides[1] = in.strides()[2]; + strides[2] = in.strides()[3]; } else { // Copy non-contiguous > 3D inputs into the output and treat // input as donated @@ -327,94 +354,103 @@ void RoPE::eval_gpu( } encoder.set_output_array(out); dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) { + using DataType = cuda_type_t; dispatch_bool(traditional_, [&](auto traditional) { dispatch_bool(forward_, [&](auto forward) { - using DataType = cuda_type_t; - if (single && !with_freqs) { - auto kernel = - cu::rope_single; - uint2 dims = make_uint2(dims_ / 2, N); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - scale_, - std::log2(base_), - mat_size, - dims); - } else if (single) { - auto kernel = - cu::rope_single_freqs; - uint2 dims = make_uint2(dims_ / 2, N); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - gpu_ptr(inputs[2]), - scale_, - mat_size, - dims, - inputs[2].strides(0)); - } else if (with_freqs) { - auto kernel = - cu::rope_freqs; - int n_per_thread = 4; - uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); - uint3 dims = make_uint3(dims_ / 2, T, dimz); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - int64_t offset_stride = 0; - if (inputs[1].ndim() > 0) { - offset_stride = inputs[1].strides()[0]; - } - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - gpu_ptr(inputs[2]), - scale_, - std::log2(base_), - strides, - out_strides, - offset_stride, - N, - dims, - inputs[2].strides(0)); - } else { - auto kernel = cu::rope; - int n_per_thread = 4; - uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); - uint3 dims = make_uint3(dims_ / 2, T, dimz); - auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); - int64_t offset_stride = 0; - if (inputs[1].ndim() > 0) { - offset_stride = inputs[1].strides()[0]; + dispatch_bool(head_seq_transpose, [&](auto hs_transpose) { + if (single && !with_freqs) { + auto kernel = + cu::rope_single; + uint2 dims = make_uint2(dims_ / 2, N); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + scale_, + std::log2(base_), + mat_size, + dims); + } else if (single) { + auto kernel = cu:: + rope_single_freqs; + uint2 dims = make_uint2(dims_ / 2, N); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1); + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), + scale_, + mat_size, + dims, + inputs[2].strides(0)); + } else if (with_freqs) { + auto kernel = cu::rope_freqs< + DataType, + traditional.value, + forward.value, + hs_transpose.value>; + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims = make_uint3(dims_ / 2, T, dimz); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + gpu_ptr(inputs[2]), + scale_, + std::log2(base_), + strides, + out_strides, + offset_stride, + N, + dims, + inputs[2].strides(0)); + } else { + auto kernel = cu::rope< + DataType, + traditional.value, + forward.value, + hs_transpose.value>; + int n_per_thread = 4; + uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread); + uint3 dims = make_uint3(dims_ / 2, T, dimz); + auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z); + int64_t offset_stride = 0; + if (inputs[1].ndim() > 0) { + offset_stride = inputs[1].strides()[0]; + } + encoder.add_kernel_node( + kernel, + grid, + block, + gpu_ptr(donated ? out : in), + gpu_ptr(out), + gpu_ptr(offset), + scale_, + std::log2(base_), + strides, + out_strides, + offset_stride, + N, + dims); } - encoder.add_kernel_node( - kernel, - grid, - block, - gpu_ptr(donated ? out : in), - gpu_ptr(out), - gpu_ptr(offset), - scale_, - std::log2(base_), - strides, - out_strides, - offset_stride, - N, - dims); - } + }); }); }); }); From fa698d8b101f6bab5df910484d0eba04c71d370b Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 16 Jun 2026 06:30:29 -0700 Subject: [PATCH 2/3] output has the same strides as an input --- mlx/backend/cuda/rope.cu | 48 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index 2c27ab1846..c19c99c48c 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -131,26 +131,32 @@ __device__ void rope_impl( // For the swapaxes(1,2) layout the (batch, head) pair is not a single // linear stride: batch_stride = T * seq_stride, head_stride = strides[0]. + // When hs_transpose is on the output mirrors the input layout, so the same + // decomposition applies to writes as well. size_t in_batch_head; + size_t out_batch_head; if (hs_transpose) { - int64_t batch_stride = static_cast(dims.y) * strides[1]; - in_batch_head = batch_idx * batch_stride + head_idx * strides[0]; + int64_t in_batch_stride = static_cast(dims.y) * strides[1]; + int64_t out_batch_stride = static_cast(dims.y) * out_strides[1]; + in_batch_head = batch_idx * in_batch_stride + head_idx * strides[0]; + out_batch_head = batch_idx * out_batch_stride + head_idx * out_strides[0]; } else { in_batch_head = mat_idx * strides[0]; + out_batch_head = mat_idx * out_strides[0]; } // Compute the input and output indices size_t in_index_1, in_index_2; size_t out_index_1, out_index_2; if (traditional) { - out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; - out_index_2 = out_index_1 + 1; + out_index_1 = + 2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head; + out_index_2 = out_index_1 + out_strides[2]; in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + strides[2]; } else { - out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] + - mat_idx * out_strides[0]; + out_index_1 = + pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head; out_index_2 = out_index_1 + dims.x * out_strides[2]; in_index_1 = pos.x * strides[2] + pos.y * strides[1] + in_batch_head; in_index_2 = in_index_1 + dims.x * strides[2]; @@ -326,7 +332,18 @@ void RoPE::eval_gpu( in.strides()[1] == D && in.strides()[2] == static_cast(N) * D && in.strides()[3] == 1) { head_seq_transpose = true; - out.set_data(cu::malloc_async(out.nbytes(), encoder)); + // Mirror the input layout in the output so we can donate when possible + // and otherwise write into a buffer that preserves the swapaxes view. + if (in.is_donatable()) { + donated = true; + out.copy_shared_buffer(in); + } else { + out.set_data( + cu::malloc_async(in.data_size() * in.itemsize(), encoder), + in.data_size(), + in.strides(), + in.flags()); + } strides[0] = in.strides()[1]; strides[1] = in.strides()[2]; strides[2] = in.strides()[3]; @@ -339,9 +356,18 @@ void RoPE::eval_gpu( strides[1] = out.strides()[ndim - 2]; strides[2] = out.strides()[ndim - 1]; } - out_strides[0] = mat_size; - out_strides[1] = out.strides()[ndim - 2]; - out_strides[2] = out.strides()[ndim - 1]; + if (head_seq_transpose) { + // Output mirrors the swapaxes(1, 2) input: head and seq strides are the + // input's strides, and the (batch, head) pair is reassembled by the + // hs_transpose path in the kernel. + out_strides[0] = strides[0]; + out_strides[1] = strides[1]; + out_strides[2] = strides[2]; + } else { + out_strides[0] = mat_size; + out_strides[1] = out.strides()[ndim - 2]; + out_strides[2] = out.strides()[ndim - 1]; + } // Some flags to help us dispatch below bool single = in.flags().row_contiguous && B == 1 && T == 1; From caa649870cb00b3e67e09efe271b8d030ce8a734 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 16 Jun 2026 06:39:53 -0700 Subject: [PATCH 3/3] drop from benchmarks --- benchmarks/python/rope_bench.py | 17 ----------------- mlx/backend/cuda/rope.cu | 9 --------- 2 files changed, 26 deletions(-) diff --git a/benchmarks/python/rope_bench.py b/benchmarks/python/rope_bench.py index b35fe902fd..35479c0b1a 100644 --- a/benchmarks/python/rope_bench.py +++ b/benchmarks/python/rope_bench.py @@ -31,22 +31,5 @@ def rope_mat(x): time_fn(rope_mat, x) -def time_rope_hs_transposed(): - rope = nn.RoPE(128) - - # matrix - x = mx.random.uniform(shape=(8, 8192, 8, 128)).astype(mx.float16) - mx.eval(x) - x = x.transpose(0, 2, 1, 3) - - def rope_transposed(x): - for _ in range(32): - x = rope(x) - return x - - time_fn(rope_transposed, x) - - if __name__ == "__main__": time_rope() - time_rope_hs_transposed() diff --git a/mlx/backend/cuda/rope.cu b/mlx/backend/cuda/rope.cu index c19c99c48c..1c29464d51 100644 --- a/mlx/backend/cuda/rope.cu +++ b/mlx/backend/cuda/rope.cu @@ -129,10 +129,6 @@ __device__ void rope_impl( float costheta = cos(theta); float sintheta = sin(theta); - // For the swapaxes(1,2) layout the (batch, head) pair is not a single - // linear stride: batch_stride = T * seq_stride, head_stride = strides[0]. - // When hs_transpose is on the output mirrors the input layout, so the same - // decomposition applies to writes as well. size_t in_batch_head; size_t out_batch_head; if (hs_transpose) { @@ -332,8 +328,6 @@ void RoPE::eval_gpu( in.strides()[1] == D && in.strides()[2] == static_cast(N) * D && in.strides()[3] == 1) { head_seq_transpose = true; - // Mirror the input layout in the output so we can donate when possible - // and otherwise write into a buffer that preserves the swapaxes view. if (in.is_donatable()) { donated = true; out.copy_shared_buffer(in); @@ -357,9 +351,6 @@ void RoPE::eval_gpu( strides[2] = out.strides()[ndim - 1]; } if (head_seq_transpose) { - // Output mirrors the swapaxes(1, 2) input: head and seq strides are the - // input's strides, and the (batch, head) pair is reassembled by the - // hs_transpose path in the kernel. out_strides[0] = strides[0]; out_strides[1] = strides[1]; out_strides[2] = strides[2];