diff --git a/cpp/include/raft/core/detail/nvtx_range_stack.hpp b/cpp/include/raft/core/detail/nvtx_range_stack.hpp index b0145a6904..8d0ee947cf 100644 --- a/cpp/include/raft/core/detail/nvtx_range_stack.hpp +++ b/cpp/include/raft/core/detail/nvtx_range_stack.hpp @@ -6,12 +6,14 @@ #include +#include #include +#include #include #include -#include #include #include +#include namespace raft { namespace common::nvtx { @@ -35,6 +37,17 @@ class current_range { return {value_, depth_}; } + /** + * Read the full root->leaf range path with instance ids, formatted as + * "name#id > name#id > ..." (empty when no range is active). + * This identifies the exact nvtx range stack responsible for an allocation. + */ + auto get_path() const -> std::string + { + std::lock_guard lock(mu_); + return path_; + } + operator std::string() const { std::lock_guard lock(mu_); @@ -45,38 +58,68 @@ class current_range { mutable std::mutex mu_; std::string value_; std::size_t depth_{0}; + std::string path_; - void set(const char* name, std::size_t depth) + void set(const char* name, std::size_t depth, std::string path) { std::lock_guard lock(mu_); value_ = name ? name : ""; depth_ = depth; + path_ = std::move(path); } }; namespace detail { +RAFT_EXPORT inline std::atomic range_instance_counter{0}; + struct nvtx_range_name_stack { void push(const char* name) { - stack_.emplace(name); - current_->set(name, stack_.size()); + ensure_current(); + auto id = range_instance_counter.fetch_add(1, std::memory_order_relaxed) + 1; + stack_.emplace_back(id, name ? name : ""); + current_->set(stack_.back().second.c_str(), stack_.size(), build_path()); } void pop() { - if (!stack_.empty()) { stack_.pop(); } - current_->set(stack_.empty() ? nullptr : stack_.top().c_str(), stack_.size()); + ensure_current(); + if (!stack_.empty()) { stack_.pop_back(); } + current_->set( + stack_.empty() ? 
nullptr : stack_.back().second.c_str(), stack_.size(), build_path()); } - auto current() const -> std::shared_ptr { return current_; } + [[nodiscard]] auto current() const -> std::shared_ptr + { + ensure_current(); + return current_; + } private: - std::stack stack_{}; - std::shared_ptr current_{std::make_shared()}; + void ensure_current() const + { + if (!current_) { current_ = std::make_shared(); } + } + + // Serialize the active stack as "name#id > name#id > ..." (outer -> inner). + [[nodiscard]] auto build_path() const -> std::string + { + std::string path; + for (auto const& [id, name] : stack_) { + if (!path.empty()) { path += " > "; } + path += name; + path += '#'; + path += std::to_string(id); + } + return path; + } + + std::vector> stack_{}; + mutable std::shared_ptr current_{std::make_shared()}; }; -inline thread_local nvtx_range_name_stack range_name_stack_instance{}; +RAFT_EXPORT inline thread_local nvtx_range_name_stack range_name_stack_instance{}; } // namespace detail @@ -85,7 +128,7 @@ inline thread_local nvtx_range_name_stack range_name_stack_instance{}; * Pass the returned shared_ptr to another thread to read this thread's current NVTX range name at * any time. */ -inline auto thread_local_current_range() -> std::shared_ptr +RAFT_EXPORT inline auto thread_local_current_range() -> std::shared_ptr { return detail::range_name_stack_instance.current(); } diff --git a/cpp/include/raft/mr/allocation_event_monitor.hpp b/cpp/include/raft/mr/allocation_event_monitor.hpp new file mode 100644 index 0000000000..6300ce9d84 --- /dev/null +++ b/cpp/include/raft/mr/allocation_event_monitor.hpp @@ -0,0 +1,193 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace mr { + +/** + * @brief A single allocation or deallocation event, captured on the allocating thread. + */ +struct allocation_event { + int source_id{0}; //< which registered source this belongs to + std::int64_t current{0}; //< source's live bytes after this event + std::int64_t total_alloc{0}; //< cumulative bytes allocated (this source) + std::int64_t total_freed{0}; //< cumulative bytes freed (this source) + std::size_t nvtx_depth{0}; //< NVTX stack depth at event time + std::string nvtx_range; //< NVTX range name active at event time + std::int64_t event_bytes{0}; //< signed bytes for THIS event (+alloc / -free) + std::string alloc_range; //< responsible range path "name#id > ..." + // captured at ALLOCATION time (empty if unknown) + std::chrono::steady_clock::time_point timestamp{}; //< when the event happened +}; + +/** + * @brief Thread-safe multi-producer / single-consumer queue of allocation_events. + */ +class allocation_event_queue { + public: + /** @brief Append an event (any thread). */ + void push(allocation_event event) + { + { + std::lock_guard lock(mtx_); + events_.push_back(std::move(event)); + } + cv_.notify_one(); + } + + /** + * @brief Block until events are available or the queue is stopped, then move + * all pending events into `out`. + * + * @return false once the queue is stopped AND drained (consumer should exit), + * true otherwise. + */ + bool wait_and_take(std::vector& out) + { + std::unique_lock lock(mtx_); + cv_.wait(lock, [this] { return stopped_ || !events_.empty(); }); + out.clear(); + out.swap(events_); + return !(stopped_ && out.empty()); + } + + /** @brief Signal the consumer to drain and exit. 
*/ + void stop() + { + { + std::lock_guard lock(mtx_); + stopped_ = true; + } + cv_.notify_all(); + } + + private: + std::mutex mtx_; + std::condition_variable cv_; + std::vector events_; + bool stopped_{false}; +}; + +/** + * @brief Consumes allocation_events from a queue and writes one CSV row per + * event from a background thread. + */ +class allocation_event_monitor { + public: + explicit allocation_event_monitor(std::ostream& out) : out_(out) {} + + ~allocation_event_monitor() { stop(); } + + allocation_event_monitor(allocation_event_monitor const&) = delete; + allocation_event_monitor& operator=(allocation_event_monitor const&) = delete; + + [[nodiscard]] auto get_queue() const noexcept -> std::shared_ptr + { + return queue_; + } + + /** + * @brief Register a named source and return its id (column-group index). + * Must be called before start(). + */ + auto register_source(std::string name) -> int + { + int id = static_cast(source_names_.size()); // TODO (huuanhhuyn) conflict id? 
+ source_names_.push_back(std::move(name)); + view_.emplace_back(); + return id; + } + + /** @brief Write the CSV header and launch the worker thread; no-op if the worker is already running. */ + void start() + { + if (worker_.joinable()) { return; } + write_header(); + worker_ = std::thread([this] { run(); }); + } + + /** @brief Stop the queue and join the worker thread; no-op if no worker is running. */ + void stop() + { + if (!worker_.joinable()) { return; } + queue_->stop(); // drains the queue and causes the worker to exit its loop + worker_.join(); + } + + private: + /** @brief Latest counters seen for one registered source (one CSV column group). */ + struct source_view { + std::int64_t current{0}; + std::int64_t total_alloc{0}; + std::int64_t total_freed{0}; + std::int64_t peak{0}; // highest `current` reported by any event of this source so far + }; + + /** @brief Emit the CSV header: timestamp, four columns per registered source, then the per-event columns. */ + void write_header() + { + out_ << "timestamp_us"; + for (auto const& name : source_names_) { + out_ << ',' << name << "_current," << name << "_peak," << name << "_total_alloc," << name + << "_total_freed"; + } + out_ << ",nvtx_depth,nvtx_range,event_source,event_bytes,alloc_range\n"; + out_.flush(); + } + + /** @brief Worker loop: take batches from the queue and write one row per event until the queue is stopped and drained. */ + void run() + { + std::vector batch; + for (;;) { + bool keep_going = queue_->wait_and_take(batch); + for (auto const& event : batch) { + write_row(event); + } + out_.flush(); + if (!keep_going) { break; } + } + } + + /** @brief Fold the event into its source's view, then write one CSV row holding every source's counters followed by the event details. */ + void write_row(allocation_event const& event) + { + if (event.source_id >= 0 && event.source_id < static_cast(view_.size())) { + // Update field by field (not by re-creating the view) so the running peak survives across events. + auto& slot = view_[event.source_id]; + slot.current = event.current; + slot.total_alloc = event.total_alloc; + slot.total_freed = event.total_freed; + if (event.current > slot.peak) { slot.peak = event.current; } + } + + auto us = + std::chrono::duration_cast(event.timestamp - start_).count(); + out_ << us; + for (auto const& v : view_) { + // Column order must match write_header(): current, peak, total_alloc, total_freed. + out_ << ',' << v.current << ',' << v.peak << ',' << v.total_alloc << ',' << v.total_freed; + } + out_ << ',' << event.nvtx_depth << ",\"" << event.nvtx_range << "\""; + + auto const* src_name = + (event.source_id >= 0 && event.source_id < static_cast(source_names_.size())) + ?
source_names_[event.source_id].c_str() + : ""; + out_ << ',' << src_name << ',' << event.event_bytes << ",\"" << event.alloc_range << "\"\n"; + } + + std::ostream& out_; + std::shared_ptr queue_{std::make_shared()}; + std::vector source_names_; + std::vector view_; + std::chrono::steady_clock::time_point start_{std::chrono::steady_clock::now()}; + std::thread worker_; +}; + +} // namespace mr +} // namespace raft diff --git a/cpp/include/raft/mr/host_memory_resource.hpp b/cpp/include/raft/mr/host_memory_resource.hpp index b4dcbf906d..d671715d44 100644 --- a/cpp/include/raft/mr/host_memory_resource.hpp +++ b/cpp/include/raft/mr/host_memory_resource.hpp @@ -46,7 +46,7 @@ struct default_host_resource_holder { } }; -inline default_host_resource_holder default_host_resource_holder_{}; +RAFT_EXPORT inline default_host_resource_holder default_host_resource_holder_{}; } // namespace detail @@ -56,7 +56,7 @@ inline default_host_resource_holder default_host_resource_holder_{}; * Returns raft::mr::host_resource_ref pointing to the resource installed * via set_default_host_resource(), or new_delete_resource() if none was set. */ -inline auto get_default_host_resource() -> raft::mr::host_resource_ref +RAFT_EXPORT inline auto get_default_host_resource() -> raft::mr::host_resource_ref { return detail::default_host_resource_holder_.get(); } @@ -70,7 +70,7 @@ inline auto get_default_host_resource() -> raft::mr::host_resource_ref * @param ref Non-owning reference to the resource to install. * @return The previous default host resource ref. 
*/ -inline auto set_default_host_resource(raft::mr::host_resource_ref ref) +RAFT_EXPORT inline auto set_default_host_resource(raft::mr::host_resource_ref ref) -> raft::mr::host_resource_ref { return detail::default_host_resource_holder_.set(ref); diff --git a/cpp/include/raft/mr/recording_adaptor.hpp b/cpp/include/raft/mr/recording_adaptor.hpp new file mode 100644 index 0000000000..6a75abccc7 --- /dev/null +++ b/cpp/include/raft/mr/recording_adaptor.hpp @@ -0,0 +1,171 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include // thread_local_current_range +#include // allocation_event, allocation_event_queue +#include // resource_stats (atomic counters, reused) + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace mr { + +/** + * @brief Resource adaptor that records each allocation/deallocation as an event, + * capturing the active NVTX range AT THE TIME OF THE EVENT. + * + * Details: + * - nvtx range is captured at the event time, preventing misassociation of a later + * range with an earlier allocation. + * - event is pushed to a thread-safe queue, preventing dropped events. This is desirable + * for a profiling use case, where all events and labels are required. + */ +template +class recording_adaptor : public cuda::forward_property, Upstream> { + // Map an allocated address to the nvtx stack range responsible for the allocation. + // It allows the deallocation event to be tagged with the same range, even if the responsible + // range has ended by the time of deallocation. 
+ struct address_range_map { + std::mutex mtx; + std::unordered_map paths; + }; + + Upstream upstream_; + std::shared_ptr stats_; + std::shared_ptr queue_; + std::shared_ptr alloc_map_; + int source_id_; + + auto record_allocation(void* ptr) noexcept -> std::string + { + std::string path; + try { + path = raft::common::nvtx::thread_local_current_range()->get_path(); + if (ptr != nullptr) { + std::lock_guard lock(alloc_map_->mtx); + alloc_map_->paths[ptr] = path; + } + } catch (...) { + } + return path; + } + + auto forget_allocation(void* ptr) noexcept -> std::string + { + std::string path; + try { + std::lock_guard lock(alloc_map_->mtx); + auto it = alloc_map_->paths.find(ptr); + if (it != alloc_map_->paths.end()) { + path = std::move(it->second); + alloc_map_->paths.erase(it); + } + } catch (...) { + // Safely returns "" on a miss + } + return path; + } + + // Build and enqueue an event from the current snapshot and nvtx range + void emit(std::string alloc_range, std::int64_t signed_bytes) noexcept + { + try { + allocation_event event; + event.source_id = source_id_; + event.current = stats_->bytes_current.load(std::memory_order_relaxed); + event.total_alloc = stats_->bytes_total_allocated.load(std::memory_order_relaxed); + event.total_freed = stats_->bytes_total_deallocated.load(std::memory_order_relaxed); + event.timestamp = std::chrono::steady_clock::now(); + auto range = raft::common::nvtx::thread_local_current_range()->get(); + event.nvtx_range = std::move(range.first); + event.nvtx_depth = range.second; + event.event_bytes = signed_bytes; + event.alloc_range = std::move(alloc_range); + queue_->push(std::move(event)); + } catch (...) { + // noexcept: profiling bookkeeping must not disrupt the allocation path, so + // any failure is swallowed. 
+ RAFT_LOG_WARN("Failed to emit an allocation event"); + } + } + + public: + /** @brief Wrap `upstream`, pushing events tagged with `source_id` onto `queue`; fresh counters and a fresh address->range map are created per adaptor. */ + recording_adaptor(Upstream upstream, + std::shared_ptr queue, + int source_id) + : upstream_(std::move(upstream)), + stats_(std::make_shared()), + queue_(std::move(queue)), + alloc_map_(std::make_shared()), + source_id_(source_id) + { + } + + /** @brief Access this source's shared counters. */ + [[nodiscard]] auto get_stats() const noexcept -> std::shared_ptr { return stats_; } + + /** @brief Allocate from upstream, update the counters, remember the responsible range for the address and emit a +bytes event. */ + void* allocate_sync(std::size_t bytes, std::size_t alignment = alignof(std::max_align_t)) + { + void* ptr = upstream_.allocate_sync(bytes, alignment); + stats_->record_allocate(static_cast(bytes)); + emit(record_allocation(ptr), static_cast(bytes)); + return ptr; + } + + /** @brief Release to upstream, update the counters and emit a -bytes event tagged with the range recorded when `ptr` was allocated. */ + void deallocate_sync(void* ptr, + std::size_t bytes, + std::size_t alignment = alignof(std::max_align_t)) noexcept + { + // Drop the address->range entry BEFORE releasing the memory: once upstream frees it, another thread may be handed the same address and register it again. + auto alloc_range = forget_allocation(ptr); + upstream_.deallocate_sync(ptr, bytes, alignment); + stats_->record_deallocate(static_cast(bytes)); + emit(std::move(alloc_range), -static_cast(bytes)); + } + + /** @brief Stream-ordered counterpart of allocate_sync(). */ + template , int> = 0> + void* allocate(cuda::stream_ref stream, + std::size_t bytes, + std::size_t alignment = alignof(std::max_align_t)) + { + void* ptr = upstream_.allocate(stream, bytes, alignment); + stats_->record_allocate(static_cast(bytes)); + emit(record_allocation(ptr), static_cast(bytes)); + return ptr; + } + + /** @brief Stream-ordered counterpart of deallocate_sync(). */ + template , int> = 0> + void deallocate(cuda::stream_ref stream, + void* ptr, + std::size_t bytes, + std::size_t alignment = alignof(std::max_align_t)) noexcept + { + // Same ordering as deallocate_sync(): forget the address first so a concurrent re-allocation of it cannot be erased or mislabeled. + auto alloc_range = forget_allocation(ptr); + upstream_.deallocate(stream, ptr, bytes, alignment); + stats_->record_deallocate(static_cast(bytes)); + emit(std::move(alloc_range), -static_cast(bytes)); + } + + /** @brief Two adaptors compare equal when their upstream resources compare equal. */ + [[nodiscard]] bool operator==(recording_adaptor const& other) const noexcept + { + return upstream_ == other.upstream_; + } + + [[nodiscard]] auto upstream_resource() noexcept -> Upstream& { return upstream_; } + [[nodiscard]] auto upstream_resource() const noexcept -> Upstream const& { return upstream_; } +}; + +} // namespace mr
+} // namespace raft diff --git a/cpp/include/raft/util/memory_tracking_resources.hpp b/cpp/include/raft/util/memory_tracking_resources.hpp index 1f96e48938..43e088e54a 100644 --- a/cpp/include/raft/util/memory_tracking_resources.hpp +++ b/cpp/include/raft/util/memory_tracking_resources.hpp @@ -9,11 +9,10 @@ #include #include #include +#include #include #include -#include -#include -#include +#include #include #include @@ -31,7 +30,7 @@ namespace raft { /** * @brief A resources handle that wraps all reachable memory resources with - * allocation-tracking adaptors and logs CSV statistics from a + * allocation-recording adaptors and logs CSV statistics from a * background thread. * * Inherits from raft::resources, so it can be passed anywhere a @@ -39,12 +38,16 @@ namespace raft { * - Materializes all tracked resource types (host, device, pinned, * managed, workspace, large_workspace). * - Takes a snapshot of the original resources to keep them alive. - * - Wraps each with statistics_adaptor + notifying_adaptor. + * - Wraps each with a recording_adaptor that pushes an allocation_event + * (carrying the NVTX range captured at allocation time) onto a shared queue. * - Replaces global host and device resources with tracked versions. - * - Starts a background CSV reporter. + * - Starts a background CSV writer that drains the queue. * - * On destruction the handle stops the reporter and restores the - * global host and device resources. + * On destruction the handle stops the writer (draining all pending events) and + * restores the global host and device resources. + * + * Unlike a sampling monitor, the NVTX range is captured on the allocating + * thread at event time, so range attribution in the CSV is always correct. */ class memory_tracking_resources : public resources { public: @@ -55,7 +58,8 @@ class memory_tracking_resources : public resources { * * @param existing Resources to shallow-copy and wrap with tracking. 
* @param out Output stream for CSV rows (must outlive this object). - * @param sample_interval Minimum time between successive CSV samples. + * @param sample_interval Accepted for API compatibility; unused by the + * event-driven monitor (every event is recorded). */ memory_tracking_resources(const resources& existing, std::ostream& out, @@ -69,7 +73,7 @@ class memory_tracking_resources : public resources { * * @param existing Resources to shallow-copy and wrap with tracking. * @param file_path Path to the output CSV file (created/truncated). - * @param sample_interval Minimum time between successive CSV samples. + * @param sample_interval Accepted for API compatibility; unused. */ memory_tracking_resources(const resources& existing, const std::string& file_path, @@ -83,7 +87,7 @@ class memory_tracking_resources : public resources { * @brief Construct from scratch (default resources), logging to an ostream. * * @param out Output stream for CSV rows (must outlive this object). - * @param sample_interval Minimum time between successive CSV samples. + * @param sample_interval Accepted for API compatibility; unused. */ explicit memory_tracking_resources(std::ostream& out, duration sample_interval = std::chrono::milliseconds{10}) @@ -95,7 +99,7 @@ class memory_tracking_resources : public resources { * @brief Construct from scratch (default resources), logging to a file. * * @param file_path Path to the output CSV file (created/truncated). - * @param sample_interval Minimum time between successive CSV samples. + * @param sample_interval Accepted for API compatibility; unused. 
*/ explicit memory_tracking_resources(const std::string& file_path, duration sample_interval = std::chrono::milliseconds{10}) @@ -116,17 +120,17 @@ class memory_tracking_resources : public resources { memory_tracking_resources& operator=(memory_tracking_resources const&) = delete; memory_tracking_resources& operator=(memory_tracking_resources&&) = delete; - /** @brief Access the underlying CSV reporter (e.g. to read stats). */ - [[nodiscard]] auto report() noexcept -> raft::mr::resource_monitor& { return report_; } + /** @brief Access the underlying CSV writer. */ + [[nodiscard]] auto report() noexcept -> raft::mr::allocation_event_monitor& { return report_; } private: memory_tracking_resources(const resources* existing, std::unique_ptr owned_stream, std::ostream* out_override, - duration sample_interval) + [[maybe_unused]] duration sample_interval) : resources(existing ? *existing : resources{}), owned_stream_(std::move(owned_stream)), - report_(out_override ? *out_override : *owned_stream_, sample_interval), + report_(out_override ? *out_override : *owned_stream_), old_host_ref_(raft::mr::get_default_host_resource()), old_device_ref_(rmm::mr::get_current_device_resource_ref()) { @@ -136,23 +140,21 @@ class memory_tracking_resources : public resources { // Declaration order determines initialization and destruction order. // snapshot_ is destroyed last (keeps original resource shared_ptrs alive). // owned_stream_ outlives report_ (report_ writes to it). - // report_ is destroyed first of the three (stops background thread). + // report_ is destroyed first of the three (stops the background thread). 
std::vector snapshot_; std::unique_ptr owned_stream_; - raft::mr::resource_monitor report_; + raft::mr::allocation_event_monitor report_; raft::mr::host_resource_ref old_host_ref_; rmm::device_async_resource_ref old_device_ref_; std::size_t saved_ws_limit_{}; - using host_stats_t = raft::mr::statistics_adaptor; - using host_notify_t = raft::mr::notifying_adaptor; - std::unique_ptr host_adaptor_; - - using device_stats_t = raft::mr::statistics_adaptor; - using device_notify_t = raft::mr::notifying_adaptor; - - std::unique_ptr device_adaptor_; + // Host and device adaptors are installed as the *global* resources, which + // hold them by reference, so they must outlive this object's use -> owned here. + using host_adaptor_t = raft::mr::recording_adaptor; + using device_adaptor_t = raft::mr::recording_adaptor; + std::unique_ptr host_adaptor_; + std::unique_ptr device_adaptor_; void init() { @@ -166,60 +168,54 @@ class memory_tracking_resources : public resources { // Keeps original resource objects alive while tracking refs point into them. snapshot_ = resources_; + auto queue = report_.get_queue(); + + // Source ids are assigned in registration order, which must match the CSV + // column-group order below. 
+ // --- Host (global) --- { - host_stats_t sa{old_host_ref_}; - report_.register_source("host", sa.get_stats()); - host_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); + int id = report_.register_source("host"); + host_adaptor_ = std::make_unique(old_host_ref_, queue, id); raft::mr::set_default_host_resource(*host_adaptor_); } // --- Pinned --- { - using stats_t = raft::mr::statistics_adaptor; - using notify_t = raft::mr::notifying_adaptor; - stats_t sa{pinned_ref}; - report_.register_source("pinned", sa.get_stats()); - raft::resource::set_pinned_memory_resource(*this, - notify_t{std::move(sa), report_.get_notifier()}); + int id = report_.register_source("pinned"); + raft::resource::set_pinned_memory_resource( + *this, raft::mr::recording_adaptor{pinned_ref, queue, id}); } // --- Managed --- { - using stats_t = raft::mr::statistics_adaptor; - using notify_t = raft::mr::notifying_adaptor; - stats_t sa{managed_ref}; - report_.register_source("managed", sa.get_stats()); - raft::resource::set_managed_memory_resource(*this, - notify_t{std::move(sa), report_.get_notifier()}); + int id = report_.register_source("managed"); + raft::resource::set_managed_memory_resource( + *this, + raft::mr::recording_adaptor{managed_ref, queue, id}); } // --- Device (global) --- { - device_stats_t sa{old_device_ref_}; - report_.register_source("device", sa.get_stats()); - device_adaptor_ = std::make_unique(std::move(sa), report_.get_notifier()); + int id = report_.register_source("device"); + device_adaptor_ = std::make_unique(old_device_ref_, queue, id); rmm::mr::set_current_device_resource(*device_adaptor_); } // --- Workspace (track upstream to preserve limiting_resource_adaptor) --- { - using ws_stats_t = raft::mr::statistics_adaptor; - using ws_notify_t = raft::mr::notifying_adaptor; - ws_stats_t sa{upstream_ref}; - report_.register_source("workspace", sa.get_stats()); + int id = report_.register_source("workspace"); raft::resource::set_workspace_resource( - 
*this, ws_notify_t{std::move(sa), report_.get_notifier()}, saved_ws_limit_); + *this, + raft::mr::recording_adaptor{upstream_ref, queue, id}, + saved_ws_limit_); } // --- Large workspace --- { - using lws_stats_t = raft::mr::statistics_adaptor; - using lws_notify_t = raft::mr::notifying_adaptor; - lws_stats_t sa{lws_ref}; - report_.register_source("large_workspace", sa.get_stats()); + int id = report_.register_source("large_workspace"); raft::resource::set_large_workspace_resource( - *this, lws_notify_t{std::move(sa), report_.get_notifier()}); + *this, raft::mr::recording_adaptor{lws_ref, queue, id}); } report_.start(); diff --git a/cpp/tests/core/monitor_resources.cu b/cpp/tests/core/monitor_resources.cu index fabd1e33ec..71a6fe52f0 100644 --- a/cpp/tests/core/monitor_resources.cu +++ b/cpp/tests/core/monitor_resources.cu @@ -4,6 +4,10 @@ */ #include +#include +#include +#include +#include #include #include @@ -11,20 +15,25 @@ #include #include +#include +#include +#include #include #include #include namespace { +namespace nvtx = raft::common::nvtx; +using namespace std::chrono_literals; +constexpr std::size_t MiB = std::size_t{1024} * 1024; + TEST(MemoryTrackingResources, TracksDeviceAllocations) { - using namespace std::chrono_literals; - std::ostringstream oss; { raft::resources res; - raft::resource::set_workspace_to_pool_resource(res, 1024 * 1024); + raft::resource::set_workspace_to_pool_resource(res, 1 * MiB); raft::memory_tracking_resources tracked(res, oss, 1ms); @@ -49,4 +58,34 @@ TEST(MemoryTrackingResources, TracksDeviceAllocations) << output; } +TEST(MemoryTrackingResources, MismatchedRangeLabeling) +{ + const std::string csv_path = "mismatch_range_label.csv"; + + { + raft::resources res; + + raft::memory_tracking_resources tracked(res, csv_path, 1ms); + { + nvtx::range r{"1. 
expect 10 KB"}; + auto matrix = raft::make_host_vector(tracked, 10 * 1024); + } + { + // Deliberately huge & slow: allocating/freeing 10 GiB of host memory takes + // several ms. With the previous sampling monitor that lag let the sample + // land after this range had ended, so the peak was labeled with the NEXT + // range. The recording adaptor captures the NVTX range on the allocating + // thread at event time, so these rows are expected to carry this range's + // label. NOTE(review): the test makes no assertions on the CSV contents -- + // confirm whether the 10 GiB size is still needed without sampling lag. + nvtx::range r{"2. expect 10 GiB"}; + auto vector = raft::make_host_vector(tracked, 10 * 1024 * MiB); + } + { + nvtx::range r{"3. expect 4 MiB"}; + auto matrix = raft::make_host_vector(tracked, 4 * MiB); + } + } // tracked destroyed here: stops the CSV writer (draining pending events) and flushes the file + + std::cout << "Wrote allocation statistics to " << csv_path << "\n"; +} + } // namespace