xla/pjrt: honor XLA_PYTHON_CLIENT_ALLOCATOR=platform#856

Draft
magaonka-amd wants to merge 5 commits into
ROCm:mainfrom
magaonka-amd:fix/pjrt-honor-kplatform-rocm-main

@magaonka-amd magaonka-amd commented May 11, 2026

📝 Summary of Changes

GpuAllocatorConfig::Kind::kPlatform in GetStreamExecutorGpuDeviceAllocator()
now returns an explicit StreamExecutorAddressAllocator (synchronous
passthrough) instead of nullptr. This restores the user-visible meaning of
XLA_PYTHON_CLIENT_ALLOCATOR=platform, which was being silently overridden by
the BFC allocator that Backend::Backend() builds unconditionally for
CUDA/ROCm platforms.

🎯 Justification

Setting XLA_PYTHON_CLIENT_ALLOCATOR=platform is supposed to deliver a
synchronous passthrough to cudaMalloc/hipMalloc. After commit 426087bc1d
(PR #41761, "Port XLA Backend to use BFC allocator"), Backend::Backend()
unconditionally builds a MultiDeviceAdapter-over-tsl::BFCAllocator for
CUDA/ROCm platforms and accepts no GpuAllocatorConfig. Combined with PJRT's
pre-existing fallback at pjrt_stream_executor_client.cc:311-315
(if (owned_allocator_ == nullptr) allocator_ = client_->backend().memory_allocator();),
users who set kPlatform get the BFC anyway: the env var is silently ignored.
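
The interaction described above can be modeled in a few lines. This is a toy Python sketch, not XLA code: PassthroughAllocator, BackendBfcAllocator, make_device_allocator and Client are hypothetical stand-ins for StreamExecutorAddressAllocator, the backend's BFC, GetStreamExecutorGpuDeviceAllocator() and the PJRT client. It only illustrates how a null return from the factory lets the client-side fallback pick the backend's BFC.

```python
from typing import Optional


class PassthroughAllocator:
    """Hypothetical stand-in for a synchronous platform allocator."""
    name = "platform"


class BackendBfcAllocator:
    """Hypothetical stand-in for the BFC allocator the backend always builds."""
    name = "bfc"


def make_device_allocator(kind: str, fixed: bool) -> Optional[object]:
    """Toy allocator factory; only the 'platform' kind is modeled.

    fixed=False mirrors the old `return nullptr;`.
    fixed=True mirrors returning an explicit passthrough allocator.
    """
    if kind == "platform":
        return PassthroughAllocator() if fixed else None
    raise ValueError(f"kind {kind!r} is not modeled in this sketch")


class Client:
    """Toy client that applies the null-allocator fallback quoted above."""

    def __init__(self, owned_allocator: Optional[object]):
        backend_allocator = BackendBfcAllocator()  # built unconditionally
        # Fallback: a null owned allocator means "use the backend's allocator".
        self.allocator = (
            owned_allocator if owned_allocator is not None else backend_allocator
        )


pre_fix = Client(make_device_allocator("platform", fixed=False))
post_fix = Client(make_device_allocator("platform", fixed=True))
print("pre-fix allocator:", pre_fix.allocator.name)
print("post-fix allocator:", post_fix.allocator.name)
```

In this model the pre-fix client ends up with the "bfc" allocator even though "platform" was requested, while the post-fix client keeps the passthrough.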

Steps to reproduce (pre-fix)

cat > /tmp/prove_alloc_override.py <<'EOF'
import os, jax, jax.numpy as jnp
print("[python] XLA_PYTHON_CLIENT_ALLOCATOR =",
      os.environ.get("XLA_PYTHON_CLIENT_ALLOCATOR"))
print("[python] devices:", jax.devices())

@jax.jit
def f(x): return jnp.sin(x) * jnp.cos(x) + jnp.tanh(x)
y = f(jnp.arange(64, dtype=jnp.float32)).block_until_ready()
print("[python] result OK, shape=", y.shape)
EOF

XLA_PYTHON_CLIENT_ALLOCATOR=platform \
XLA_PYTHON_CLIENT_PREALLOCATE=false \
TF_CPP_MIN_LOG_LEVEL=0 \
TF_CPP_VMODULE=bfc_allocator=1,se_gpu_pjrt_client=1 \
  python /tmp/prove_alloc_override.py 2>&1 \
  | grep -E "Using (platform|BFC) allocator|Creating new BFCAllocator|XLA_backend_[0-9]+_bfc"

Pre-fix output:

I0511 16:32:47.460748  Creating new BFCAllocator named: XLA_backend_0_bfc
... (one per device, all from Backend ctor) ...
I0511 16:32:50.078285  Using platform allocator.
I0511 16:32:50.312169  Extending allocation by 2.00MiB bytes for XLA_backend_0_bfc.

The order is the giveaway: the BFC named XLA_backend_<n>_bfc (constructed
in xla/service/backend.cc:164) is the one growing to serve user
allocations, even though PJRT logs "Using platform allocator." for the same
workload. The env var was a no-op on this code path.
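
The same check can be automated over a captured log. The sketch below is an illustration only: the function name backend_bfc_served_allocations is hypothetical, and the regular expression assumes the log wording shown in the pre-fix output above, which may differ in other XLA versions.

```python
import re

# Matches the "Extending allocation ..." line for a backend-owned BFC,
# using the wording from the pre-fix output quoted above.
EXTEND_RE = re.compile(r"Extending allocation by .* for (XLA_backend_\d+_bfc)")


def backend_bfc_served_allocations(log_text: str) -> bool:
    """Return True if the log claims the platform allocator while a backend BFC grew."""
    claims_platform = "Using platform allocator." in log_text
    bfc_grew = EXTEND_RE.search(log_text) is not None
    return claims_platform and bfc_grew


# Lines copied from the pre-fix output above (the elided per-device lines are omitted).
pre_fix_log = """\
I0511 16:32:47.460748  Creating new BFCAllocator named: XLA_backend_0_bfc
I0511 16:32:50.078285  Using platform allocator.
I0511 16:32:50.312169  Extending allocation by 2.00MiB bytes for XLA_backend_0_bfc.
"""

# The same log without the "Extending allocation" line, as described for the post-fix run.
post_fix_log = """\
I0511 16:32:47.460748  Creating new BFCAllocator named: XLA_backend_0_bfc
I0511 16:32:50.078285  Using platform allocator.
"""

print(backend_bfc_served_allocations(pre_fix_log))
print(backend_bfc_served_allocations(post_fix_log))
```

Piping the output of the reproduction command into such a check would flag the override without reading the log by eye.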

Post-fix output (same script, same env): the "Extending allocation by ... for
XLA_backend_<n>_bfc." line no longer appears under load. Only the synchronous
passthrough is exercised, which is what kPlatform advertises.

🚀 Kind of Contribution

🐛 Bug Fix

🧪 Unit Tests

Adds PlatformAllocatorIsSynchronousPassthrough in
xla/pjrt/gpu/se_gpu_pjrt_client_test.cc, mirroring the existing
VmmAllocatorCanBeSet pattern. Asserts:

  • The resulting client's allocator() IS a StreamExecutorAddressAllocator.
  • The resulting client's allocator() is NOT a MultiDeviceAdapter — the
    negative assertion catches a future regression where someone re-introduces
    the silent fallback to Backend's BFC.

@magaonka-amd magaonka-amd added the claude-review Request a Claude AI code review for this PR label May 11, 2026
Comment on lines +1536 to 1553
   case GpuAllocatorConfig::Kind::kPlatform: {
     LOG(INFO) << "Using platform allocator.";
     if (allocator_config.collective_memory_size != 0) {
       LOG(WARNING)
           << "collective_memory_size is non-zero, but allocator kind is set "
              "to \"platform\". Collective memory will not be allocated.";
     }
-    // Returning null will cause the client to use the default backend
-    // allocator.
-    return nullptr;
+    std::vector<se::StreamExecutor*> executors;
+    executors.reserve(addressable_devices.size());
+    for (const auto& [ordinal, device] : addressable_devices) {
+      executors.push_back(device->executor());
+    }
+    return std::make_unique<se::StreamExecutorAddressAllocator>(platform,
+                                                                executors);
   }

   case GpuAllocatorConfig::Kind::kVmm: {
 #if GOOGLE_CUDA

nit: Like the kVmm case, kPlatform now returns early, which means it skips the collective memory allocator and host allocator setup at lines 1570–1591. The existing warning about collective_memory_size is good, but consider adding a brief comment or LOG(WARNING) noting that kPlatform does not support alternate memory spaces (collective, host, temp buffer) through MultiDeviceAdapter. This would help users who hit failures when using collective operations or host memory allocations with kPlatform.

Author

A LOG(WARNING) on every kPlatform use... hmm, I don't think that's needed.

Comment on lines +3729 to 3747
TEST(StreamExecutorGpuClientTest, PlatformAllocatorIsSynchronousPassthrough) {
  GpuClientOptions options;
  options.allocator_config.kind = GpuAllocatorConfig::Kind::kPlatform;
  options.allowed_devices = {0};

  TF_ASSERT_OK_AND_ASSIGN(auto client, GetStreamExecutorGpuClient(options));

  auto* pjrt_se_client =
      tensorflow::down_cast<PjRtStreamExecutorClient*>(client.get());
  EXPECT_NE(dynamic_cast<se::StreamExecutorAddressAllocator*>(
                pjrt_se_client->allocator()),
            nullptr);
  EXPECT_EQ(dynamic_cast<se::MultiDeviceAdapter*>(pjrt_se_client->allocator()),
            nullptr);
}

#if GOOGLE_CUDA
TEST(StreamExecutorGpuClientTest, VmmAllocatorCanBeSet) {
  GpuClientOptions options;

nit: The test is well-structured and mirrors the existing VmmAllocatorCanBeSet pattern nicely. Two optional improvements:

  1. Interaction with command-buffer override: There's a pre-existing override at se_gpu_pjrt_client.cc:1493–1498 that silently converts any non-kVmm kind to kVmm when xla_gpu_command_buffer_update_mode is not ALWAYS_UPDATE. This test passes because the default is ALWAYS_UPDATE, but could break if that default ever changes. Consider explicitly setting xla_gpu_command_buffer_update_mode in the test's debug options, or adding a comment noting the dependency.

  2. E2E coverage: Unlike VmmAllocatorE2ETest which runs an HLO program, this test only checks the allocator type. An additional E2E test that allocates/deallocates a buffer through the platform allocator would provide stronger confidence that the passthrough semantics work end-to-end.
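
The dependency noted in the first point can be pictured with a toy model. The Python sketch below is based only on the reviewer's description of the override, not on the code at se_gpu_pjrt_client.cc:1493–1498; the function name effective_allocator_kind and the mode string "NOT_ALWAYS_UPDATE" are hypothetical placeholders.

```python
def effective_allocator_kind(requested_kind: str, update_mode: str) -> str:
    """Toy model of the override described in the review comment.

    Any kind other than "vmm" is converted to "vmm" unless the command-buffer
    update mode is "ALWAYS_UPDATE".
    """
    if requested_kind != "vmm" and update_mode != "ALWAYS_UPDATE":
        return "vmm"
    return requested_kind


# With the mode the reviewer names as the default, the requested kind survives.
print(effective_allocator_kind("platform", "ALWAYS_UPDATE"))

# "NOT_ALWAYS_UPDATE" is a made-up value standing in for any other mode.
print(effective_allocator_kind("platform", "NOT_ALWAYS_UPDATE"))
```

Under this model, a test that requests the platform kind only observes it while the update mode stays ALWAYS_UPDATE, which is why pinning the mode in the test, or documenting the dependency, is suggested.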

Author

Added a comment.

I don't think an HLO test is needed.

claude Bot commented May 11, 2026

Review Summary

Overall: Looks good. This is a well-motivated bug fix. The old kPlatform case returned nullptr, causing the PJRT client constructor to fall back to the backend's unconditional BFC allocator — silently ignoring XLA_PYTHON_CLIENT_ALLOCATOR=platform. The fix correctly returns an explicit StreamExecutorAddressAllocator, consistent with how kVmm handles a similar early-return pattern. The test follows existing patterns well.

Two minor suggestions posted inline:

  1. Consider adding a warning/comment about absent alternate memory space allocators (collective, host) when using kPlatform.
  2. Consider hardening the test against a pre-existing command-buffer-mode override and optionally adding an E2E allocation test.

No blocking issues found.

🤖 Generated with Claude Code

@github-actions github-actions Bot removed the claude-review Request a Claude AI code review for this PR label May 11, 2026
@magaonka-amd magaonka-amd force-pushed the fix/pjrt-honor-kplatform-rocm-main branch from 05c8d1b to 4d8398a Compare May 11, 2026 20:38
@i-chaochen i-chaochen requested a review from mfrancepillois May 12, 2026 09:23
Collaborator

@i-chaochen i-chaochen left a comment

IIUC, NV will also have this same issue, right? Or do they avoid it because they are using VMM?

      executors.push_back(device->executor());
    }
    return std::make_unique<se::StreamExecutorAddressAllocator>(platform,
                                                                executors);
Collaborator

Have you checked their previous fix in openxla#7963 (comment)?

Author

Let me check this and get back to you. Thank you.

Author

Re "IIUC NV will also have this same issue": yes, NV also has this problem.

Re openxla#7963 (comment): the concern there is about constructing a
MultiDeviceAdapter object, which we don't do here. I think we are okay on this.

Collaborator

Thanks for the clarification. I assume you found this through the upstream JAX pytest run? If it also fails on the NV side, why doesn't it appear in their NV CI?

Author

Yes, I came across this issue in a pytest run.

In pytest, ROCm started seeing flaky autotuner problems. After debugging, we realized the BFC allocator was being called, and there is a race condition when BFCAllocator does asynchronous deallocation.
All of these problems were exposed by openxla@426087bc1d, a recent commit from G.

So my argument here is that we have two problems:

  1. BFC is called when it shouldn't be, because the JAX pytest run explicitly asks for the platform allocator. This PR addresses that.
  2. There is a race condition when the BFC allocator is used. This is something we need to debug (ruturaj is on it, and I'm also looking into it a bit).

Now, as to how NV survives this: even when the platform allocator is not used, NV doesn't suffer from the race issues that we are having. Why exactly it doesn't race is not clear to me yet. I tried a CUDA reproducer on our H100 machine, but my results were not really conclusive.

In short:
Does issue 1 exist in CUDA? Yes.
Does it lead to UT failures in CUDA? No.
Does issue 1 lead to UT failures in ROCm? Yes.

Collaborator

@i-chaochen i-chaochen May 14, 2026

AFAIK, the platform allocator also depends on BFC for memory/fragmentation management. The only difference is that BFC pre-allocates most of the memory at the initial stage, while the platform allocator doesn't. So I still think we first need to figure out the root cause of the race condition in the BFC allocator (if we think the issue comes from there).

    for (const auto& [ordinal, device] : addressable_devices) {
      executors.push_back(device->executor());
    }
    return std::make_unique<se::StreamExecutorAddressAllocator>(platform,

Why is this early return safe, and why is GetGpuHostAllocator not required?

Author

Here I want to provide an option similar to kVmm: if the user explicitly requests the platform allocator, we intentionally want to bypass the BFC-wrapped alternative allocators.

IIUC, GetGpuHostAllocator provides a memory pool, and a user who asks for kPlatform wants to opt out of pooled memory.

Makes sense. Thanks for explaining.

@magaonka-amd
Author

@i-chaochen is it okay to open this PR upstream?

@i-chaochen
Collaborator

Thanks for the explanation. Yes, please.

@magaonka-amd
Author

Opened the PR upstream as openxla#42627. Thanks, everyone, for the feedback.

@magaonka-amd magaonka-amd marked this pull request as draft May 14, 2026 17:12
Google-ML-Automation and others added 4 commits May 27, 2026 09:24

Updates LLVM usage to match
[a225aafbd1a4](llvm/llvm-project@a225aafbd1a4)

PiperOrigin-RevId: 922174432

…to use a fully-parameterised SpmdPartitioningTest.

PiperOrigin-RevId: 922192214

Make GpuAllocatorConfig::Kind::kPlatform actually deliver a synchronous
passthrough allocator instead of returning nullptr. Previously the env
var XLA_PYTHON_CLIENT_ALLOCATOR=platform was silently overridden by
the BFC allocator that Backend::Backend() builds unconditionally for
CUDA/ROCm platforms.

Adds a regression test PlatformAllocatorIsSynchronousPassthrough that
asserts the resulting client's allocator is a StreamExecutorAddressAllocator
and not a MultiDeviceAdapter (which would indicate the silent BFC
fallback).

…iption_test

DWYU flagged //xla/runtime:process_id as unused by
se_gpu_topology_description_test. Confirmed: the test source contains
no reference to any process_id / ProcessId symbol and does not include
"xla/runtime/process_id.h", so the dep is dead.

@magaonka-amd magaonka-amd force-pushed the fix/pjrt-honor-kplatform-rocm-main branch from a8a2432 to 8536283 Compare May 27, 2026 17:46

clang-tidy-cuda's misc-include-cleaner flagged
xla/stream_executor/device_address.h as not used directly inside
PlatformAllocatorIsSynchronousPassthrough, and dwyu independently
flagged that both device_address and stream had no matching BUILD
dep on //xla/pjrt/gpu:se_gpu_pjrt_client_test.

These two includes were carried over from the original commit when
the monolithic se_gpu_pjrt_client_test.cc still owned the multi-GPU
tests that needed them. After the file split (openxla/xla 591da9d)
my new single-GPU test only uses StreamExecutorAddressAllocator and
MultiDeviceAdapter; nothing from stream.h or device_address.h.

Pruning the two includes silences both clang-tidy and dwyu without
changing any logic. Verified locally with --config=rocm on gfx950:

  bazel build //xla/pjrt/gpu:se_gpu_pjrt_client_test
  -> Build completed successfully

No functional change.

4 participants