Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion mlx/backend/cpu/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,41 @@ void ScatterAxis::eval_cpu(const std::vector<array>& inputs, array& out) {
}

template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
void masked_scatter_contiguous_impl(
const array& mask,
const array& src,
array& out) {
const bool* mask_ptr = mask.data<bool>();
const T* src_ptr = src.data<T>();
T* dst_ptr = out.data<T>();

const size_t batch_count = mask.shape(0);
const size_t mask_batch_size = mask.size() / batch_count;
const size_t src_batch_size = src.size() / batch_count;

for (size_t b = 0; b < batch_count; ++b) {
const size_t batch_offset = b * mask_batch_size;
const size_t src_offset = b * src_batch_size;
size_t src_consumed = 0;

for (size_t i = 0; i < mask_batch_size; ++i) {
if (mask_ptr[batch_offset + i]) {
if (src_consumed >= src_batch_size) {
throw std::runtime_error(
"[MaskedScatter::eval_cpu] Source does not have enough elements for mask.");
}
dst_ptr[batch_offset + i] = src_ptr[src_offset + src_consumed];
++src_consumed;
}
}
}
}

template <typename T>
void masked_scatter_general_impl(
const array& mask,
const array& src,
array& out) {
ContiguousIterator mask_it(mask);
ContiguousIterator src_it(src);
ContiguousIterator out_it(out);
Expand Down Expand Up @@ -784,6 +818,16 @@ void masked_scatter_impl(const array& mask, const array& src, array& out) {
}
}

template <typename T>
void masked_scatter_impl(const array& mask, const array& src, array& out) {
if (mask.flags().row_contiguous && src.flags().row_contiguous &&
out.flags().row_contiguous) {
masked_scatter_contiguous_impl<T>(mask, src, out);
} else {
masked_scatter_general_impl<T>(mask, src, out);
}
}

void MaskedScatter::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);

Expand Down