diff --git a/mlx/backend/cpu/indexing.cpp b/mlx/backend/cpu/indexing.cpp index ec4090172f..4537980d4b 100644 --- a/mlx/backend/cpu/indexing.cpp +++ b/mlx/backend/cpu/indexing.cpp @@ -751,7 +751,41 @@ void ScatterAxis::eval_cpu(const std::vector& inputs, array& out) { } template -void masked_scatter_impl(const array& mask, const array& src, array& out) { +void masked_scatter_contiguous_impl( + const array& mask, + const array& src, + array& out) { + const bool* mask_ptr = mask.data(); + const T* src_ptr = src.data(); + T* dst_ptr = out.data(); + + const size_t batch_count = mask.shape(0); + const size_t mask_batch_size = mask.size() / batch_count; + const size_t src_batch_size = src.size() / batch_count; + + for (size_t b = 0; b < batch_count; ++b) { + const size_t batch_offset = b * mask_batch_size; + const size_t src_offset = b * src_batch_size; + size_t src_consumed = 0; + + for (size_t i = 0; i < mask_batch_size; ++i) { + if (mask_ptr[batch_offset + i]) { + if (src_consumed >= src_batch_size) { + throw std::runtime_error( + "[MaskedScatter::eval_cpu] Source does not have enough elements for mask."); + } + dst_ptr[batch_offset + i] = src_ptr[src_offset + src_consumed]; + ++src_consumed; + } + } + } +} + +template +void masked_scatter_general_impl( + const array& mask, + const array& src, + array& out) { ContiguousIterator mask_it(mask); ContiguousIterator src_it(src); ContiguousIterator out_it(out); @@ -784,6 +818,16 @@ void masked_scatter_impl(const array& mask, const array& src, array& out) { } } +template +void masked_scatter_impl(const array& mask, const array& src, array& out) { + if (mask.flags().row_contiguous && src.flags().row_contiguous && + out.flags().row_contiguous) { + masked_scatter_contiguous_impl(mask, src, out); + } else { + masked_scatter_general_impl(mask, src, out); + } +} + void MaskedScatter::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 3);