Skip to content

Implement cache for pjrt client streams#861

Draft
mfrancepillois wants to merge 2 commits into
mainfrom
maxime_pjrt_client_cache_rocm
Draft

Implement cache for pjrt client streams#861
mfrancepillois wants to merge 2 commits into
mainfrom
maxime_pjrt_client_cache_rocm

Conversation

@mfrancepillois
Copy link
Copy Markdown

Motivation

The iota_test was very slow on AMD targets (compared to NVDIA) because the pjrt client was destroyed and recreated for each of the 4500 tests that make up the iota_test. This task in ROCm is ~40× slower than with CUDA (see table below).

Phase H100 AMD MI300X
Previous client teardown + new client init (pre-BFC log) ~35ms total ~963ms total
BFC allocator re-setup (8 GPUs) ~0.3ms ~0.1ms
Per-test GPU lifecycle cost 35ms 1009ms

The main cause of slowdowns when creating and destroying a pjrt client lies in the creation and destruction of streams.

Per-test overhead (8 GPU ROCm, iota_test):

CREATION (~406ms):
  Phase1 GetGpuXlaClient:      0.2ms    (negligible, singleton)
  Phase2 hipStreamCreate ×112: 385ms    ← dominant creation cost
  Phase3 EnablePeerAccess:     0.4ms    (negligible, cached)
  Phase4 BFC Allocator:        0.2ms    (negligible, no prealloc)
  Phase5 BuildDistributed:      20ms    (RCCL topology)

DESTRUCTION (~513ms):
  dtor body:                   0.05ms
  thread pool shutdown:       138ms 
  hipStreamDestroy ×112:      375ms    ← dominant destruction cost
  SyncAllActivity:              1.5ms   (device 0 only)

TOTAL OVERHEAD:                ~919ms per test
ACTUAL COMPUTATION:             ~90ms  (IotaReshapeExtraDims = 1012ms total)

This PR therefore implements a cache to avoid creating and destroying the stream for each client/test, and to allow existing streams to be reused (after they have been cleared) for the next test in the process.
Note that the stream and associated meta data are cleared before being placed in the cache to ensure reusing the stream can be reused safely. If a previous test failed due to a hardware failure preventing the stream from being cleared safely, it is destroyed and recreated for the next client/test.

On MI350:

Iota test main (seconds) with cache (seconds) Improvement
Shard=1 2337 264 88.7%
Shard=50 610.8 61.1 90.0%

@i-chaochen i-chaochen added the claude-review Request a Claude AI code review for this PR label May 19, 2026
Comment thread xla/pjrt/local_device_state.cc
Comment thread xla/pjrt/gpu/se_gpu_pjrt_client.cc
Comment thread xla/pjrt/gpu/se_gpu_pjrt_client.cc Outdated
Comment thread xla/pjrt/pjrt_stream_executor_client.h Outdated
Comment thread xla/pjrt/gpu/se_gpu_pjrt_client.cc
Comment thread xla/pjrt/local_device_state.cc
Comment thread xla/pjrt/gpu/se_gpu_pjrt_client.cc
@claude
Copy link
Copy Markdown

claude Bot commented May 19, 2026

Review Summary

This PR adds a process-level cache for LocalDeviceState objects to avoid the expensive GPU stream creation/destruction cycle when PjRt clients are repeatedly created and destroyed (common in test suites). The optimization is well-motivated by ROCm's ~40x slower stream lifecycle compared to CUDA.

Key findings (details in inline comments):

  • Race condition in Reset() — host-side callbacks on callback_thread_ may still be in-flight when compute_events_ is cleared, risking undefined behavior.
  • Config mismatch on cache reuseschedule_async and max_inflight_computations are silently ignored when reusing cached states, which could cause null pointer dereferences.
  • Encapsulationrelease_local_device_state() should be restricted to friend/private access to prevent accidental misuse.
  • No tests included — given the complexity and potential for subtle race conditions, unit tests covering reuse, Reset() failure fallback, and concurrent client lifecycle would be valuable.

🤖 Generated with Claude Code

@github-actions github-actions Bot removed the claude-review Request a Claude AI code review for this PR label May 19, 2026
Copy link
Copy Markdown
Collaborator

@i-chaochen i-chaochen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this work!

Just to make sure IIUC, from your table, does it mean only ROCm will suffer this high cost by large number of destoryed/recreated on pjrt client? because seems this irrational pjrt lifecycle changes doesn't hurt NV at all.... if it's this case, I'm not sure it's enough to convince upstream to make this change, or we could just have this changes on ROCm only?

Phase H100 AMD MI300X
Previous client teardown + new client init (pre-BFC log) ~35ms total ~963ms total
BFC allocator re-setup (8 GPUs) ~0.3ms ~0.1ms
Per-test GPU lifecycle cost 35ms 1009ms

const LocalDeviceId local_device_id_;
const LocalChipId local_hardware_id_;
const std::unique_ptr<LocalDeviceState> local_device_state_;
std::unique_ptr<LocalDeviceState> local_device_state_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can only know whether this is ok by upstream review.

@mfrancepillois
Copy link
Copy Markdown
Author

Thanks for this work!

Just to make sure IIUC, from your table, does it mean only ROCm will suffer this high cost by large number of destoryed/recreated on pjrt client? because seems this irrational pjrt lifecycle changes doesn't hurt NV at all.... if it's this case, I'm not sure it's enough to convince upstream to make this change, or we could just have this changes on ROCm only?

Phase H100 AMD MI300X
Previous client teardown + new client init (pre-BFC log) ~35ms total ~963ms total
BFC allocator re-setup (8 GPUs) ~0.3ms ~0.1ms
Per-test GPU lifecycle cost 35ms 1009ms

That’s true. The significant cost associated with creating and destroying streams only applies to the HIP side (I assume that stream creation and destruction have already been optimised within the CUDA driver). So, this optimisation really only benefits us...

@i-chaochen
Copy link
Copy Markdown
Collaborator

Thanks for this work!
Just to make sure IIUC, from your table, does it mean only ROCm will suffer this high cost by large number of destoryed/recreated on pjrt client? because seems this irrational pjrt lifecycle changes doesn't hurt NV at all.... if it's this case, I'm not sure it's enough to convince upstream to make this change, or we could just have this changes on ROCm only?
Phase H100 AMD MI300X
Previous client teardown + new client init (pre-BFC log) ~35ms total ~963ms total
BFC allocator re-setup (8 GPUs) ~0.3ms ~0.1ms
Per-test GPU lifecycle cost 35ms 1009ms

That’s true. The significant cost associated with creating and destroying streams only applies to the HIP side (I assume that stream creation and destruction have already been optimised within the CUDA driver). So, this optimisation really only benefits us...

Then I guess the best is to have this on rocm-only?

@mfrancepillois mfrancepillois marked this pull request as draft May 20, 2026 14:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants