Implement cache for hipStream in ROCm executor #869
Open
mfrancepillois wants to merge 3 commits into
hipStreamCreate on ROCm is expensive (~100 ms per stream). When a
PjRtClient is destroyed and a new one is immediately created (common in
tests and interactive use), all ~18 streams per device are destroyed and
recreated, blocking for several seconds.
This commit implements a process-level HipStreamHandleCache singleton
directly in rocm_stream.cc (ROCm-only, touches no CUDA/SYCL code).
Cache key: (device_ordinal, creation_flags, creation_priority_int).
On destruction (RocmStream::~RocmStream):
1. BlockHostUntilDone() already ran -- stream is idle.
2. hipStreamQuery() confirms idleness; on error the handle is
destroyed rather than cached (no poisoning).
3. hipStreamGetFlags / hipStreamGetPriority are called to build the
exact cache key, ensuring a retrieved handle always matches the
flags and priority the new stream would have used -- even if XLA
later switches to hipStreamNonBlocking.
4. Idle handle is stored; hipStreamDestroy is skipped.
On creation (RocmStream::Create via CreateStream):
The cache is checked first; on hit the cached handle is returned
directly and hipStreamCreate is skipped. On miss the cold path
calls hipStreamCreate as before.
The LocalDeviceState and RocmStream wrapper objects are still created
and destroyed normally on every client instantiation. DNN state is
cleaned up via DeallocateStream as usual. Only the underlying HIP
queue (hipStream_t) is reused.
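
The release and acquire paths described above can be sketched as a small keyed pool. This is only an illustration under stated assumptions, not the code in rocm_stream.cc: `FakeStreamHandle` stands in for `hipStream_t` so the block compiles without the HIP headers, and the names `HandleCacheSketch`, `Release`, and `TryAcquire` are hypothetical. The idleness check (hipStreamQuery) and the flag/priority queries are assumed to have happened before `Release` is called.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <optional>
#include <tuple>
#include <vector>

// Hypothetical stand-in for hipStream_t (assumption, keeps the sketch HIP-free).
using FakeStreamHandle = int;

// (device_ordinal, creation_flags, creation_priority_int), as in the commit message.
using StreamKey = std::tuple<int, unsigned int, int>;

// Hypothetical process-level pool of idle stream handles.
class HandleCacheSketch {
 public:
  // Function-local static; intentionally leaked so it outlives all users.
  static HandleCacheSketch& Get() {
    static HandleCacheSketch* cache = new HandleCacheSketch;
    return *cache;
  }

  // Destruction path: store an idle handle instead of destroying it.
  void Release(const StreamKey& key, FakeStreamHandle handle) {
    std::lock_guard<std::mutex> lock(mu_);
    idle_[key].push_back(handle);
  }

  // Creation path: return a cached handle on hit, std::nullopt on miss.
  // On a miss the caller would fall back to creating a new stream.
  std::optional<FakeStreamHandle> TryAcquire(const StreamKey& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = idle_.find(key);
    if (it == idle_.end() || it->second.empty()) return std::nullopt;
    FakeStreamHandle handle = it->second.back();
    it->second.pop_back();
    return handle;
  }

 private:
  std::mutex mu_;
  std::map<StreamKey, std::vector<FakeStreamHandle>> idle_;
};
```

Keying on the full tuple means a handle released with one set of flags is never handed to a caller asking for different flags or a different priority; such a request simply misses and takes the cold path.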
Also fix a latent use-after-free in LocalDeviceState::~LocalDeviceState:
C++ destroys members in reverse declaration order. compute_events_
(line 352 in local_device_state.h) is declared after callback_thread_
(line 342), so its destructor runs *before* callback_thread_'s
destructor joins the worker thread. If callback_thread_ still has
pending pop_front(compute_events_) closures when compute_events_ is
destroyed, those closures access freed memory.
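
The ordering rule behind this bug can be checked in isolation. The following is a minimal sketch, not XLA code: `Tracer` and `Holder` are hypothetical types, and the member names only mirror the declaration order described above (the callback thread first, the event container later).

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper: records its name when destroyed.
struct Tracer {
  std::string name;
  std::vector<std::string>* log;
  ~Tracer() { log->push_back(name); }
};

// Mirrors the declaration order from the commit message:
// the thread member comes first, the events member comes later.
struct Holder {
  explicit Holder(std::vector<std::string>* log)
      : callback_thread{"callback_thread_", log},
        compute_events{"compute_events_", log} {}
  Tracer callback_thread;  // declared first, destroyed last
  Tracer compute_events;   // declared later, destroyed first
};

// Returns member names in the order their destructors ran.
std::vector<std::string> DestructionOrder() {
  std::vector<std::string> log;
  { Holder holder(&log); }  // holder goes out of scope here
  return log;
}
```

In this sketch the events member is torn down while the thread member still exists, which is the window in which a still-running closure could touch freed memory.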
The fix adds callback_thread_->Drain() between SynchronizeAllActivity()
and the explicit stream/event clears. After Drain() the callback thread
is idle and compute_events_ can be safely cleared.
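
One way to picture the fix is a worker with a blocking drain operation that the owner calls before its later-declared members are destroyed. The sketch below is an assumption-laden illustration: `SketchWorker`, `Owner`, and `DrainDemo` are hypothetical names, and this is not the worker thread class or the `Drain()` implementation used by XLA.

```cpp
#include <cassert>
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

// Hypothetical single-threaded worker with a blocking Drain().
class SketchWorker {
 public:
  SketchWorker() : thread_([this] { Run(); }) {}
  ~SketchWorker() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    thread_.join();  // remaining closures run before the join returns
  }
  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(fn));
    }
    cv_.notify_all();
  }
  // Blocks until the queue is empty and no closure is running.
  void Drain() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return queue_.empty() && !busy_; });
  }

 private:
  void Run() {
    std::unique_lock<std::mutex> lock(mu_);
    for (;;) {
      cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
      if (queue_.empty()) return;  // stop requested and nothing left to run
      std::function<void()> fn = std::move(queue_.front());
      queue_.pop_front();
      busy_ = true;
      lock.unlock();
      fn();  // run the closure without holding the lock
      lock.lock();
      busy_ = false;
      cv_.notify_all();  // wake any Drain() waiter
    }
  }
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> queue_;
  bool busy_ = false;
  bool stop_ = false;
  std::thread thread_;  // declared last so the state above exists when it starts
};

// Hypothetical owner: the worker is declared before the events it touches.
struct Owner {
  SketchWorker worker;
  std::deque<int> events{1, 2, 3};
  // Draining first guarantees no closure touches `events` during teardown.
  ~Owner() { worker.Drain(); }
};

// Schedules one pop per event, drains, and returns how many events remain.
int DrainDemo() {
  Owner owner;
  for (int i = 0; i < 3; ++i) {
    owner.worker.Schedule([&owner] { owner.events.pop_front(); });
  }
  owner.worker.Drain();
  return static_cast<int>(owner.events.size());
}
```

Without the `Drain()` call in `~Owner`, `events` would be destroyed before the worker's destructor joins the thread, and a pending closure could pop from a destroyed container.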
Motivation
The iota_test was very slow on AMD targets (compared to NVIDIA) because the PjRt client was destroyed and recreated for each of the 4500 tests that make up iota_test. On ROCm this operation is ~40× slower than with CUDA (see table below).
The main cause of the slowdown when creating and destroying a PjRt client is the creation and destruction of streams.
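This PR implements a process-level HipStreamHandleCache singleton directly in rocm_stream.cc. Cache key: (device_ordinal, creation_flags, creation_priority_int).
On destruction (RocmStream::~RocmStream):
1. BlockHostUntilDone() already ran -- stream is idle.
2. hipStreamQuery() confirms idleness; on error the handle is
destroyed rather than cached (no poisoning).
3. hipStreamGetFlags / hipStreamGetPriority are called to build the
exact cache key, ensuring a retrieved handle always matches the
flags and priority the new stream would have used.
4. Idle handle is stored; hipStreamDestroy is skipped.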
On creation (RocmStream::Create via CreateStream):
The cache is checked first; on hit the cached handle is returned
directly and hipStreamCreate is skipped. On miss the cold path
calls hipStreamCreate as before.
The LocalDeviceState and RocmStream wrapper objects are still created and destroyed normally on every client instantiation. DNN state is cleaned up via DeallocateStream as usual. Only the underlying HIP queue (hipStream_t) is reused.
Also fix a latent use-after-free in LocalDeviceState::~LocalDeviceState:
C++ destroys members in reverse declaration order. compute_events_ (line 352 in local_device_state.h) is declared after callback_thread_ (line 342), so its destructor runs before callback_thread_'s destructor joins the worker thread. If callback_thread_ still has pending pop_front(compute_events_) closures when compute_events_ is destroyed, those closures access freed memory.
The fix adds callback_thread_->Drain() between SynchronizeAllActivity() and the explicit stream/event clears. After Drain() the callback thread is idle and compute_events_ can be safely cleared.