Skip to content
2 changes: 1 addition & 1 deletion cpp/.clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ BreakConstructorInitializers: BeforeColon
BreakInheritanceList: BeforeColon
BreakStringLiterals: true
ColumnLimit: 100
CommentPragmas: '^ IWYU pragma:'
CommentPragmas: '^ (IWYU pragma:)|(SPDX-)'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
# Kept the below 2 to be the same as `IndentWidth` to keep everything uniform
Expand Down
11 changes: 9 additions & 2 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
#pragma once
Expand Down Expand Up @@ -217,10 +217,12 @@ struct dataset_descriptor_host {
// Mutex owned by this state object. Its lock sites are not visible in this excerpt
// (presumably it serializes eval() — TODO confirm).
std::mutex mutex;
// Publication flag: eval() stores `true` with release ordering after assigning `value`;
// get() loads it with acquire ordering to decide whether eval() still has to run.
std::atomic<bool> ready; // Not sure if std::holds_alternative is thread-safe
// Initialized in the constructor from (init, size); eval() overwrites it with
// (device pointer, stream), which is what the destructor and get() read back as `ready_t`.
std::variant<ready_t, init_f> value;
// CUDA event created in the constructor and destroyed in the destructor. eval() records it
// on the initializing stream right after the init function is launched; get() makes other
// streams wait on it.
cudaEvent_t ready_event;

/**
 * Construct a not-yet-initialized descriptor state.
 *
 * `ready` starts as false and `value` is initialized from the tuple (init, size).
 * `ready_event` is created here so it can later be recorded by eval().
 *
 * @param init initializer callable, stored in `value` together with `size`
 * @param size stored next to `init` (presumably the byte size later passed to
 *             cudaMallocAsync in eval() — TODO confirm, the extraction is not visible here)
 */
template <typename InitF>
state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)}
{
// The event is only used for stream ordering, so it is created with timing disabled.
RAFT_CUDA_TRY(cudaEventCreateWithFlags(&ready_event, cudaEventDisableTiming));
}

~state() noexcept
Expand All @@ -229,6 +231,7 @@ struct dataset_descriptor_host {
auto& [ptr, stream] = std::get<ready_t>(value);
RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream));
}
RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(ready_event));
}

void eval(rmm::cuda_stream_view stream)
Expand All @@ -239,6 +242,7 @@ struct dataset_descriptor_host {
dev_descriptor_t* ptr = nullptr;
RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream));
fun(ptr, stream);
RAFT_CUDA_TRY(cudaEventRecord(ready_event, stream));
value = std::make_tuple(ptr, stream);
ready.store(true, std::memory_order_release);
}
Expand All @@ -247,7 +251,10 @@ struct dataset_descriptor_host {
/**
 * Return the device descriptor pointer, running the deferred initialization on first use.
 *
 * If the `ready` flag is not yet set, eval(stream) is invoked. Afterwards the stored
 * (pointer, stream) pair is read from `value`. When the descriptor was initialized on a
 * stream other than `stream`, `stream` is made to wait on `ready_event` (recorded by eval()
 * after the initializer was launched), so work submitted to `stream` after this call is
 * ordered behind the initialization.
 *
 * @param stream stream passed to eval() on first use, and synchronized with the
 *               initializing stream otherwise
 * @return the device pointer stored in `value`
 */
auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t*
{
  if (!ready.load(std::memory_order_acquire)) { eval(stream); }
  // value is immutable at this point.
  auto& [ptr, ready_stream] = std::get<ready_t>(value);
  // No early return before this point: the wait below must be reachable, otherwise a caller
  // on a different stream could use `ptr` before the initializer has finished.
  if (ready_stream != stream) { RAFT_CUDA_TRY(cudaStreamWaitEvent(stream, ready_event, 0)); }
  return ptr;
}
};

Expand Down
Loading