Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 154 additions & 101 deletions mlx/backend/cuda/rope.cu
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,12 @@ __global__ void rope_single_freqs(
in, out, *offset, inv_freq, scale, stride, pos, dims);
}

template <typename T, bool traditional, bool forward, int N = 4>
template <
typename T,
bool traditional,
bool forward,
bool hs_transpose,
int N = 4>
__device__ void rope_impl(
const T* in,
T* out,
Expand All @@ -124,21 +129,32 @@ __device__ void rope_impl(
float costheta = cos(theta);
float sintheta = sin(theta);

size_t in_batch_head;
size_t out_batch_head;
if (hs_transpose) {
int64_t in_batch_stride = static_cast<int64_t>(dims.y) * strides[1];
int64_t out_batch_stride = static_cast<int64_t>(dims.y) * out_strides[1];
in_batch_head = batch_idx * in_batch_stride + head_idx * strides[0];
out_batch_head = batch_idx * out_batch_stride + head_idx * out_strides[0];
} else {
in_batch_head = mat_idx * strides[0];
out_batch_head = mat_idx * out_strides[0];
}

// Compute the input and output indices
size_t in_index_1, in_index_2;
size_t out_index_1, out_index_2;
if (traditional) {
out_index_1 = 2 * pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
out_index_2 = out_index_1 + 1;
in_index_1 =
2 * pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
out_index_1 =
2 * pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head;
out_index_2 = out_index_1 + out_strides[2];
in_index_1 = 2 * pos.x * strides[2] + pos.y * strides[1] + in_batch_head;
in_index_2 = in_index_1 + strides[2];
} else {
out_index_1 = pos.x * out_strides[2] + pos.y * out_strides[1] +
mat_idx * out_strides[0];
out_index_1 =
pos.x * out_strides[2] + pos.y * out_strides[1] + out_batch_head;
out_index_2 = out_index_1 + dims.x * out_strides[2];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + mat_idx * strides[0];
in_index_1 = pos.x * strides[2] + pos.y * strides[1] + in_batch_head;
in_index_2 = in_index_1 + dims.x * strides[2];
}
for (int i = 0; i < N && head_idx + i < n_head; ++i) {
Expand All @@ -163,7 +179,7 @@ __device__ void rope_impl(
}
}

template <typename T, bool traditional, bool forward>
template <typename T, bool traditional, bool forward, bool hs_transpose>
__global__ void rope(
const T* in,
T* out,
Expand All @@ -185,7 +201,7 @@ __global__ void rope(

float d = static_cast<float>(pos.x) / static_cast<float>(dims.x);
float inv_freq = exp2(-d * base);
rope_impl<T, traditional, forward>(
rope_impl<T, traditional, forward, hs_transpose>(
in,
out,
offset,
Expand All @@ -199,7 +215,7 @@ __global__ void rope(
dims);
}

template <typename T, bool traditional, bool forward>
template <typename T, bool traditional, bool forward, bool hs_transpose>
__global__ void rope_freqs(
const T* in,
T* out,
Expand All @@ -222,7 +238,7 @@ __global__ void rope_freqs(
}

float inv_freq = 1.0 / freqs[freq_stride * pos.x];
rope_impl<T, traditional, forward>(
rope_impl<T, traditional, forward, hs_transpose>(
in,
out,
offset,
Expand Down Expand Up @@ -258,6 +274,7 @@ void RoPE::eval_gpu(
cuda::std::array<int64_t, 3> strides;
cuda::std::array<int64_t, 3> out_strides;
bool donated = false;
bool head_seq_transpose = false;
int ndim = in.ndim();

int B = in.shape(0);
Expand Down Expand Up @@ -303,6 +320,27 @@ void RoPE::eval_gpu(
strides[0] = in.strides()[ndim - 3];
strides[1] = in.strides()[ndim - 2];
strides[2] = in.strides()[ndim - 1];
} else if (
ndim == 4 &&
// batch dim is regularly strided
in.strides()[0] == static_cast<int64_t>(T) * N * D &&
// sequence and head dimensions are transposed (x.swapaxes(1, 2))
in.strides()[1] == D && in.strides()[2] == static_cast<int64_t>(N) * D &&
in.strides()[3] == 1) {
head_seq_transpose = true;
if (in.is_donatable()) {
donated = true;
out.copy_shared_buffer(in);
} else {
out.set_data(
cu::malloc_async(in.data_size() * in.itemsize(), encoder),
in.data_size(),
in.strides(),
in.flags());
}
strides[0] = in.strides()[1];
strides[1] = in.strides()[2];
strides[2] = in.strides()[3];
} else {
// Copy non-contiguous > 3D inputs into the output and treat
// input as donated
Expand All @@ -312,9 +350,15 @@ void RoPE::eval_gpu(
strides[1] = out.strides()[ndim - 2];
strides[2] = out.strides()[ndim - 1];
}
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
if (head_seq_transpose) {
out_strides[0] = strides[0];
out_strides[1] = strides[1];
out_strides[2] = strides[2];
} else {
out_strides[0] = mat_size;
out_strides[1] = out.strides()[ndim - 2];
out_strides[2] = out.strides()[ndim - 1];
}

// Some flags to help us dispatch below
bool single = in.flags().row_contiguous && B == 1 && T == 1;
Expand All @@ -327,94 +371,103 @@ void RoPE::eval_gpu(
}
encoder.set_output_array(out);
dispatch_float_types(out.dtype(), "rope", [&](auto type_tag) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
dispatch_bool(traditional_, [&](auto traditional) {
dispatch_bool(forward_, [&](auto forward) {
using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel =
cu::rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel =
cu::rope_freqs<DataType, traditional.value, forward.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<DataType, traditional.value, forward.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
dispatch_bool(head_seq_transpose, [&](auto hs_transpose) {
if (single && !with_freqs) {
auto kernel =
cu::rope_single<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
mat_size,
dims);
} else if (single) {
auto kernel = cu::
rope_single_freqs<DataType, traditional.value, forward.value>;
uint2 dims = make_uint2(dims_ / 2, N);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, 1);
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
mat_size,
dims,
inputs[2].strides(0));
} else if (with_freqs) {
auto kernel = cu::rope_freqs<
DataType,
traditional.value,
forward.value,
hs_transpose.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
gpu_ptr<float>(inputs[2]),
scale_,
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
dims,
inputs[2].strides(0));
} else {
auto kernel = cu::rope<
DataType,
traditional.value,
forward.value,
hs_transpose.value>;
int n_per_thread = 4;
uint32_t dimz = B * ((N + n_per_thread - 1) / n_per_thread);
uint3 dims = make_uint3(dims_ / 2, T, dimz);
auto [grid, block] = get_grid_and_block(dims.x, dims.y, dims.z);
int64_t offset_stride = 0;
if (inputs[1].ndim() > 0) {
offset_stride = inputs[1].strides()[0];
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
dims);
}
encoder.add_kernel_node(
kernel,
grid,
block,
gpu_ptr<DataType>(donated ? out : in),
gpu_ptr<DataType>(out),
gpu_ptr<int32_t>(offset),
scale_,
std::log2(base_),
strides,
out_strides,
offset_stride,
N,
dims);
}
});
});
});
});
Expand Down
Loading