diff --git a/cpp/src/neighbors/detail/cagra/compute_distance.hpp b/cpp/src/neighbors/detail/cagra/compute_distance.hpp index 75b56860bb..94b38c70e4 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/src/neighbors/detail/cagra/compute_distance.hpp @@ -217,6 +217,7 @@ struct dataset_descriptor_host { std::mutex mutex; std::atomic ready; // Not sure if std::holds_alternative is thread-safe std::variant value; + cudaEvent_t init_event{nullptr}; template state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)} @@ -229,6 +230,7 @@ struct dataset_descriptor_host { auto& [ptr, stream] = std::get(value); RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream)); } + if (init_event != nullptr) { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(init_event)); } } void eval(rmm::cuda_stream_view stream) @@ -237,8 +239,16 @@ struct dataset_descriptor_host { if (std::holds_alternative(value)) { auto& [fun, size] = std::get(value); dev_descriptor_t* ptr = nullptr; + // A previous attempt may have created the event and then thrown; reuse that + // handle instead of overwriting (and leaking) it. + if (init_event == nullptr) { + RAFT_CUDA_TRY(cudaEventCreateWithFlags(&init_event, cudaEventDisableTiming)); + } RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream)); fun(ptr, stream); + // Record an event after initialization so that other streams can establish + // a GPU-side dependency without expensive host synchronization. + RAFT_CUDA_TRY(cudaEventRecord(init_event, stream)); value = std::make_tuple(ptr, stream); ready.store(true, std::memory_order_release); } @@ -247,6 +257,11 @@ struct dataset_descriptor_host { auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t* { if (!ready.load(std::memory_order_acquire)) { eval(stream); } + // Make the caller's stream wait for the init to complete. This is a + // lightweight GPU-side dependency with no host blocking. On the same + // stream that performed the init (or after the event has already + // completed) this is essentially a no-op. + if (init_event != nullptr) { RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, init_event)); } return std::get<0>(std::get(value)); } };