Implement cache for hipStream in ROCm executor #869

Open
mfrancepillois wants to merge 3 commits into main from ci_maxime_hip_stream_cache_rocm

@mfrancepillois

Motivation

The iota_test was very slow on AMD targets (compared to NVIDIA) because the PjRt client was destroyed and recreated for each of the 4500 tests that make up iota_test. On ROCm this teardown-and-recreate cycle is ~40× slower than on CUDA (see table below).

| Phase | H100 | AMD MI300X |
| --- | --- | --- |
| Previous client teardown + new client init (pre-BFC log) | ~35ms total | ~963ms total |
| BFC allocator re-setup (8 GPUs) | ~0.3ms | ~0.1ms |
| Per-test GPU lifecycle cost | 35ms | 1009ms |

Most of the time spent creating and destroying a PjRt client goes to creating and destroying streams.

Per-test overhead (8 GPU ROCm, iota_test):

CREATION (~406ms):
  Phase1 GetGpuXlaClient:      0.2ms    (negligible, singleton)
  Phase2 hipStreamCreate ×112: 385ms    ← dominant creation cost
  Phase3 EnablePeerAccess:     0.4ms    (negligible, cached)
  Phase4 BFC Allocator:        0.2ms    (negligible, no prealloc)
  Phase5 BuildDistributed:      20ms    (RCCL topology)

DESTRUCTION (~513ms):
  dtor body:                   0.05ms
  thread pool shutdown:       138ms 
  hipStreamDestroy ×112:      375ms    ← dominant destruction cost
  SyncAllActivity:              1.5ms   (device 0 only)

TOTAL OVERHEAD:                ~919ms per test
ACTUAL COMPUTATION:             ~90ms  (IotaReshapeExtraDims = 1012ms total)

This PR implements a process-level HipStreamHandleCache singleton directly in rocm_stream.cc. Cache key: (device_ordinal, creation_flags, creation_priority_int).

On destruction (RocmStream::~RocmStream):

  1. BlockHostUntilDone() already ran -- stream is idle.
  2. hipStreamQuery() confirms idleness; on error the handle is destroyed rather than cached (no poisoning).
  3. hipStreamGetFlags / hipStreamGetPriority are called to build the exact cache key, ensuring a retrieved handle always matches the flags and priority the new stream would have used -- even if XLA later switches to hipStreamNonBlocking.
  4. Idle handle is stored; hipStreamDestroy is skipped.

On creation (RocmStream::Create via CreateStream):
The cache is checked first; on hit the cached handle is returned
directly and hipStreamCreate is skipped. On miss the cold path
calls hipStreamCreate as before.

The LocalDeviceState and RocmStream wrapper objects are still created and destroyed normally on every client instantiation. DNN state is cleaned up via DeallocateStream as usual. Only the underlying HIP queue (hipStream_t) is reused.
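
The put-on-destroy / get-on-create flow described above can be sketched roughly as follows. This is a minimal illustration, not the code in rocm_stream.cc: `FakeStreamHandle` stands in for `hipStream_t` so the sketch compiles without HIP headers, and the names `StreamHandleCacheSketch`, `StreamKey`, `Put`, and `TryGet` are hypothetical. The idleness check via hipStreamQuery and the key construction via hipStreamGetFlags / hipStreamGetPriority are assumed to happen in the caller before `Put`.

```cpp
#include <cstdint>
#include <map>
#include <mutex>
#include <optional>
#include <tuple>
#include <vector>

// Stand-in for hipStream_t (assumption: lets the sketch build without HIP).
using FakeStreamHandle = std::uintptr_t;

// Hypothetical key: (device_ordinal, creation_flags, creation_priority).
using StreamKey = std::tuple<int, unsigned int, int>;

class StreamHandleCacheSketch {
 public:
  // Called on stream destruction: keep the idle handle instead of
  // destroying it. The caller is assumed to have verified idleness.
  void Put(const StreamKey& key, FakeStreamHandle handle) {
    std::lock_guard<std::mutex> lock(mu_);
    idle_[key].push_back(handle);
  }

  // Called on stream creation: returns a cached handle whose key matches
  // exactly, or std::nullopt so the caller falls back to the cold path.
  std::optional<FakeStreamHandle> TryGet(const StreamKey& key) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = idle_.find(key);
    if (it == idle_.end() || it->second.empty()) return std::nullopt;
    FakeStreamHandle handle = it->second.back();
    it->second.pop_back();
    return handle;
  }

 private:
  std::mutex mu_;  // The cache is process-wide, so access is serialized.
  std::map<StreamKey, std::vector<FakeStreamHandle>> idle_;
};
```

Keying on the full (device, flags, priority) tuple means a handle is only ever handed back to a request that would have created an identical stream, which is the property the description relies on.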

Also fix a latent use-after-free in LocalDeviceState::~LocalDeviceState:

C++ destroys members in reverse declaration order. compute_events_
(line 352 in local_device_state.h) is declared after callback_thread_
(line 342), so its destructor runs before callback_thread_'s
destructor joins the worker thread. If callback_thread_ still has
pending pop_front(compute_events_) closures when compute_events_ is
destroyed, those closures access freed memory.
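
The ordering rule can be checked with a small standalone program. The sketch below is illustrative only: `Tracer`, `Holder`, and `DestructionOrder` are hypothetical names, and the two members merely mirror the declaration order described above (thread first, event container after it).

```cpp
#include <string>
#include <vector>

// Records its name when destroyed so the order can be inspected.
struct Tracer {
  std::string name;
  std::vector<std::string>* log;
  ~Tracer() { log->push_back(name); }
};

// Same declaration order as described above: the thread-like member
// comes first, the event container is declared after it.
struct Holder {
  explicit Holder(std::vector<std::string>* log)
      : callback_thread{"callback_thread", log},
        compute_events{"compute_events", log} {}
  Tracer callback_thread;
  Tracer compute_events;
};

// Builds and destroys a Holder, returning the destructor call order.
std::vector<std::string> DestructionOrder() {
  std::vector<std::string> log;
  { Holder holder(&log); }
  return log;
}
```

`DestructionOrder()` returns `compute_events` first and `callback_thread` second, i.e. the later-declared member is torn down before the earlier one.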

The fix adds callback_thread_->Drain() between SynchronizeAllActivity()
and the explicit stream/event clears. After Drain() the callback thread
is idle and compute_events_ can be safely cleared.
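
One way a drain operation of this kind can work is sketched below. This is an assumption about the general technique, not the actual WorkerThread implementation in xla/pjrt/worker_thread.cc: `DrainableWorker` and its members are hypothetical names. `Drain()` blocks until the queue is empty and no closure is in flight, after which state touched by the closures can be released safely.

```cpp
#include <condition_variable>
#include <deque>
#include <functional>
#include <mutex>
#include <thread>
#include <utility>

class DrainableWorker {
 public:
  DrainableWorker() : thread_([this] { Loop(); }) {}

  ~DrainableWorker() {
    {
      std::lock_guard<std::mutex> lock(mu_);
      stop_ = true;
    }
    cv_.notify_all();
    thread_.join();  // Remaining closures run before the thread exits.
  }

  void Schedule(std::function<void()> fn) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(std::move(fn));
    }
    cv_.notify_all();
  }

  // Blocks until every scheduled closure has finished running.
  void Drain() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return queue_.empty() && !busy_; });
  }

 private:
  void Loop() {
    std::unique_lock<std::mutex> lock(mu_);
    while (true) {
      cv_.wait(lock, [this] { return stop_ || !queue_.empty(); });
      if (queue_.empty()) return;  // stop_ is set and nothing is left.
      std::function<void()> fn = std::move(queue_.front());
      queue_.pop_front();
      busy_ = true;
      lock.unlock();
      fn();  // Run the closure without holding the lock.
      lock.lock();
      busy_ = false;
      cv_.notify_all();  // Wake any thread blocked in Drain().
    }
  }

  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::function<void()>> queue_;
  bool busy_ = false;
  bool stop_ = false;
  std::thread thread_;  // Declared last so it starts after the state above.
};
```

In a destructor, calling `Drain()` before clearing any container the closures touch gives the ordering the fix depends on, regardless of member declaration order.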

@mfrancepillois mfrancepillois force-pushed the ci_maxime_hip_stream_cache_rocm branch from b11b315 to 14eee1d Compare May 20, 2026 13:42
@mfrancepillois mfrancepillois added the claude-review Request a Claude AI code review for this PR label May 20, 2026
Comment thread xla/stream_executor/rocm/rocm_stream.cc
Comment thread xla/stream_executor/rocm/rocm_stream.cc
Comment thread xla/stream_executor/rocm/rocm_stream.cc Outdated
Comment thread xla/pjrt/worker_thread.h Outdated
Comment thread xla/stream_executor/rocm/rocm_stream.cc
Comment thread xla/pjrt/worker_thread.cc
@claude

claude Bot commented May 20, 2026

Review Summary

This PR adds a process-level hipStream_t handle cache to avoid expensive hipStreamCreate/hipStreamDestroy calls (~100ms each) on ROCm, and fixes a latent use-after-free in LocalDeviceState::~LocalDeviceState where compute_events_ could be destroyed before the callback thread finishes draining.

Key finding: hipStreamQuery in DestroyStream and the cache-hit path in CreateStream both lack executor->Activate() calls, which could cause incorrect behavior on multi-GPU systems. The CUDA counterpart activates context before the equivalent cuStreamQuery.

The use-after-free fix via WorkerThread::Drain() is clean and correct — it's also platform-agnostic and benefits CUDA builds equally.

See inline comments for details.

@github-actions github-actions Bot removed the claude-review Request a Claude AI code review for this PR label May 20, 2026
@mfrancepillois mfrancepillois force-pushed the ci_maxime_hip_stream_cache_rocm branch 2 times, most recently from ebfaf2c to 91d265d Compare May 20, 2026 14:25
hipStreamCreate on ROCm is expensive (~100 ms per stream). When a
PjRtClient is destroyed and a new one is immediately created (common in
tests and interactive use), all ~18 streams per device are destroyed and
recreated, blocking for several seconds.

This commit implements a process-level HipStreamHandleCache singleton
directly in rocm_stream.cc (ROCm-only, touches no CUDA/SYCL code).
Cache key: (device_ordinal, creation_flags, creation_priority_int).

@mfrancepillois mfrancepillois force-pushed the ci_maxime_hip_stream_cache_rocm branch from 36c68c2 to c07395d Compare May 28, 2026 09:25