[ROCm] Insert ReshapeDecomposer before post-GemmRewriter LayoutNormalization#817
Conversation
…malization
Restores the bitcast-only invariant of LayoutNormalization on the cuBLAS-LT
custom-call path. After GemmRewriter pins a "__cublas$lt$matmul" output
layout to {n-1,...,1,0}, a downstream reshape whose layout was
consumer-pulled to a non-canonical permutation is no longer
bitcast-compatible. LayoutNormalization::HandleReshape RET_CHECKs on this.
Re-decompose any non-bitcast reshape so the precondition holds, mirroring
the existing ReshapeDecomposer that already runs before
LayoutNormalization at the cuDNN-conv site.
Both insertion sites are protected:
- After AddGemmRewriterPasses + GemmBroadcastFoldingRewriter (covers dots
that GemmFusion declined to wrap in __triton_gemm).
- After AddConvAndGemmAutotuningPass + GemmBroadcastFoldingRewriter (covers
dots that GemmFusion wrapped, where the autotuner later fissions back to
__cublas$lt$matmul).
|
I was expecting you needed to make a change here https://github.com/openxla/xla/blob/main/xla/backends/gpu/autotuner/factory_rocm.cc#L58. Does test fail consistently with triton_gemm set to false https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc#L336 |
|
if you set triton_gemm set to false layout assignment takes into account make sure it works with hipBLASLT so no you wont see failuire. coming to fix in factory rocm, I didnt think of that before. I just tested it now: I just added it here in factory_rocm.cc std::unique_ptr<HloPassPipeline> GetGemmRewriterPipeline(
const stream_executor::DeviceDescription& device_description,
bool enable_cublaslt, absl::Span<const DType> dtypes) {
auto pipeline = std::make_unique<HloPassPipeline>(
enable_cublaslt ? "hipblaslt_rewriter_pipeline"
: "rocblas_rewriter_pipeline");
pipeline->AddPass(std::make_unique<DotAlgorithmRewriter>());
pipeline->AddPass(std::make_unique<ScaledDotRewriter>());
for (DType dtype : dtypes) {
GemmRewriterOptions options{dtype};
options.enable_cublaslt = enable_cublaslt;
auto gemm_rewriter = std::make_unique<GemmRewriter>(
device_description.gpu_compute_capability(),
device_description.runtime_version(), options);
pipeline->AddPass(std::move(gemm_rewriter));
}
pipeline->AddPass(std::make_unique<ReshapeDecomposer>()); ----- HERE
return pipeline;
} |
|
it will fail consistently with the hacky patch I made , I added that in PR description , other than that it is always flaky. |
|
One thing that concerns me here if I'm reading this right (please correct me if not). The hipBLASLt candidate is profiled as a __cublas$lt$matmul writing canonical {n-1,...,0}, and the layout-fixup transpose lives outside the profiled module, so the comparison is:
However, something similar could already happen elsewhere in autotuning where post-boundary passes add costs the profiler doesn't see. |
No I think it gets profiled as full gpu executable, so that is hipblast + anything else that triton might have fused but hipblaslt can't. |
| // GemmRewriter pins "__cublas$lt$matmul" output layouts to {n-1,...,1,0}, | ||
| // which can leave a downstream reshape no longer bitcast-compatible. | ||
| // Decompose any such reshape so LayoutNormalization's ReshapeIsBitcast | ||
| // precondition holds. | ||
| pipeline.AddPass<ReshapeDecomposer>(); |
There was a problem hiding this comment.
Missing regression test. The PR doesn't include an HLO-level test to guard against this crash regressing.
A test using HloTestBase could construct the post-GemmRewriter HLO directly — with a __cublas$lt$matmul custom call already in place and the problematic reshape (canonical output layout {n-1,...,1,0} feeding a reshape that is no longer bitcast-compatible) — then run only the ReshapeDecomposer + LayoutNormalization portion of the pipeline. This avoids dependence on the autotuner selecting hipBLASLt and makes the test deterministic.
Even a simple test that verifies the module compiles without the RET_CHECK crash would be valuable.
| // Rewrite GEMMs with broadcasted inputs as strided GEMMs. | ||
| pipeline.AddPass<GemmBroadcastFoldingRewriter>(); | ||
|
|
||
| // GemmRewriter pins "__cublas$lt$matmul" output layouts to {n-1,...,1,0}, |
There was a problem hiding this comment.
Nit: comment scope is slightly narrow. The comment mentions only __cublas$lt$matmul, but GemmRewriter also pins layouts for __cublas$gemm custom calls. Since ReshapeDecomposer is unconditional this doesn't affect correctness, but the comment could say "GEMM custom calls" to be more precise and avoid misleading a future reader into thinking this is cBLAS-LT-specific.
| // Rewrite GEMMs with broadcasted inputs as strided GEMMs. | ||
| pipeline.AddPass<GemmBroadcastFoldingRewriter>(); | ||
|
|
||
| // GemmRewriter pins "__cublas$lt$matmul" output layouts to {n-1,...,1,0}, |
There was a problem hiding this comment.
Same comment scope nit applies here. Same suggestion as above — consider broadening "__cublas$lt$matmul" to "GEMM custom calls" in this second comment instance as well.
Review SummaryThe fix is mechanically correct and well-placed. Inserting Key observations:
Suggestions: A regression test and a minor comment scope improvement — see inline comments. 🤖 Generated with Claude Code |
…ization Cherry-pick of #817 onto this debug branch so the run we use to gather XLA_HSACO_TRACE diagnostics doesn't hit the unrelated RET_CHECK at layout_normalization.cc:431 in testDotProductAttention / pallas/gpu_paged_attention_test.
…unImpl LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Move the precondition from a comment into LayoutNormalization::RunImpl itself: the pass now runs ReshapeDecomposer on its module as the first step. The three explicit pipeline.AddPass<ReshapeDecomposer>() calls in xla/service/gpu/gpu_compiler.cc that paired with LayoutNormalization (RunLayoutNormalizationPasses, OptimizeHloPostLayoutAssignment x2) are now redundant and removed; their adjacent comments are replaced with a brief note pointing at the encapsulation. The unrelated ReshapeDecomposer at the top of OptimizeHloPostLayoutAssignment (which serves a different purpose, ahead of ReduceDecomposer / GemmRewriter) is left in place. The HandleReshape TF_RET_CHECK gains a stream-style explanation so any future bypass debugs in seconds rather than days. A new backend-agnostic regression test (NonBitcastReshapeIsDecomposedAutomatically) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that LayoutNormalization handles a non-bitcast kReshape directly. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path: every Triton-eligible dot would get pre-constrained, possibly inserting kCopy operations on the (much more common) Triton-wins path to fix the rare cuBLAS-LT-wins crash. The encapsulation in this change protects all current and future LayoutNormalization invocations regardless of which architectural direction follow-up work takes. Co-authored-by: magaonka <magaonka@amd.com>
…unImpl LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Move the precondition from a comment into LayoutNormalization::RunImpl itself: the pass now runs ReshapeDecomposer on its module as the first step. The three explicit pipeline.AddPass<ReshapeDecomposer>() calls in xla/service/gpu/gpu_compiler.cc that paired with LayoutNormalization (RunLayoutNormalizationPasses, OptimizeHloPostLayoutAssignment x2) are now redundant and removed; their adjacent comments are replaced with a brief note pointing at the encapsulation. The unrelated ReshapeDecomposer at the top of OptimizeHloPostLayoutAssignment (which serves a different purpose, ahead of ReduceDecomposer / GemmRewriter) is left in place. The HandleReshape TF_RET_CHECK gains a stream-style explanation so any future bypass debugs in seconds rather than days. A new backend-agnostic regression test (NonBitcastReshapeIsDecomposedAutomatically) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that LayoutNormalization handles a non-bitcast kReshape directly. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path: every Triton-eligible dot would get pre-constrained, possibly inserting kCopy operations on the (much more common) Triton-wins path to fix the rare cuBLAS-LT-wins crash. The encapsulation in this change protects all current and future LayoutNormalization invocations regardless of which architectural direction follow-up work takes. Co-authored-by: magaonka <magaonka@amd.com>
…unImpl LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Move the precondition from a comment into LayoutNormalization::RunImpl itself: the pass now runs ReshapeDecomposer on its module as the first step. The three explicit pipeline.AddPass<ReshapeDecomposer>() calls in xla/service/gpu/gpu_compiler.cc that paired with LayoutNormalization (RunLayoutNormalizationPasses, OptimizeHloPostLayoutAssignment x2) are now redundant and removed; their adjacent comments are replaced with a brief note pointing at the encapsulation. The unrelated ReshapeDecomposer at the top of OptimizeHloPostLayoutAssignment (which serves a different purpose, ahead of ReduceDecomposer / GemmRewriter) is left in place. The HandleReshape TF_RET_CHECK gains a stream-style explanation so any future bypass debugs in seconds rather than days. A new backend-agnostic regression test (NonBitcastReshapeIsDecomposedAutomatically) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that LayoutNormalization handles a non-bitcast kReshape directly. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path: every Triton-eligible dot would get pre-constrained, possibly inserting kCopy operations on the (much more common) Triton-wins path to fix the rare cuBLAS-LT-wins crash. The encapsulation in this change protects all current and future LayoutNormalization invocations regardless of which architectural direction follow-up work takes. Co-authored-by: magaonka <magaonka@amd.com>
…youtNormalizationPasses helper LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Centralize the canonical "ReshapeDecomposer; LayoutNormalization" pair in a new file-local helper AddLayoutNormalizationPasses inside xla/service/gpu/gpu_compiler.cc, and route the three post-GemmRewriter LayoutNormalization invocations (AutotunerAndPostCleanup and OptimizeHloPostLayoutAssignment x2) through it. Both passes remain real HloPassPipeline-registered passes, so --xla_disable_hlo_passes, HLO dumps, and pass scheduling continue to work as before. The fourth LayoutNormalization invocation (RunLayoutNormalizationPasses) interleaves MoveCopyToUsers between ReshapeDecomposer and LayoutNormalization, so it preserves the explicit three-pass spelling instead of using the helper and gets the missing ReshapeDecomposer call restored explicitly. The HandleReshape TF_RET_CHECK gains a stream-style explanation pointing at AddLayoutNormalizationPasses for any future caller who bypasses the helper. A new backend-agnostic regression test (ReshapeDecomposerThenLayoutNormalizationHandlesNonBitcastReshape) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that the canonical "ReshapeDecomposer; LayoutNormalization" pipeline handles it. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path. Co-authored-by: magaonka <magaonka@amd.com>
… site LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Add the missing ReshapeDecomposer call at the new fourth site (AutotunerAndPostCleanup), mirroring the inline pattern at the three existing sites in the same file. The HandleReshape TF_RET_CHECK gains a stream-style explanation pointing at the canonical pattern so the next forgetter debugs in seconds rather than days. A new backend-agnostic regression test (ReshapeDecomposerThenLayoutNormalizationHandlesNonBitcastReshape) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that the canonical "ReshapeDecomposer; LayoutNormalization" pipeline handles it; if the pattern ever stops working, this test fails in CI before a regression can land. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path. Co-authored-by: magaonka <magaonka@amd.com>
… site LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Add the missing ReshapeDecomposer call at the new fourth site (AutotunerAndPostCleanup), mirroring the inline pattern at the three existing sites in the same file. The HandleReshape TF_RET_CHECK gains a stream-style explanation pointing at the canonical pattern so the next forgetter debugs in seconds rather than days. A new backend-agnostic regression test (ReshapeDecomposerThenLayoutNormalizationHandlesNonBitcastReshape) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that the canonical "ReshapeDecomposer; LayoutNormalization" pipeline handles it; if the pattern ever stops working, this test fails in CI before a regression can land. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path. Co-authored-by: magaonka <magaonka@amd.com>
… site LayoutNormalization::HandleReshape requires every kReshape it visits to be bitcast-equivalent (TF_RET_CHECK at xla/service/layout_normalization.cc:431). That precondition is normally satisfied by running ReshapeDecomposer immediately before LayoutNormalization, but the contract lived only in a comment, so callers could (and did) silently drop it: - PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two of the three LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc. - Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") added a third LayoutNormalization invocation inside the new GpuCompiler::AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the same bug whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Add the missing ReshapeDecomposer call at the new fourth site (AutotunerAndPostCleanup), mirroring the inline pattern at the three existing sites in the same file. The HandleReshape TF_RET_CHECK gains a stream-style explanation pointing at the canonical pattern so the next forgetter debugs in seconds rather than days. A new backend-agnostic regression test (ReshapeDecomposerThenLayoutNormalizationHandlesNonBitcastReshape) constructs the canonical failing HLO -- the minimum jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') -- and asserts that the canonical "ReshapeDecomposer; LayoutNormalization" pipeline handles it; if the pattern ever stops working, this test fails in CI before a regression can land. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See ret_check_repro.hlo + XLA_FLAGS=--xla_gpu_autotune_level=0 with the cuBLAS-LT-first hack from PR #817's discussion for a deterministic single-process reproducer; the new regression test exercises the same code path without any backend-specific instructions. A more architecturally complete follow-up would pre-pin GemmRewriter's canonical output layout in GpuLayoutAssignment so the non-bitcast reshape is never created in the first place (see PR #817 discussion question 2). That path is left out of scope here because of its potential perf impact on the Triton-wins path. Co-authored-by: magaonka <magaonka@amd.com>
…tunerAndPostCleanup Restores the bitcast-only invariant of LayoutNormalization on the GpuCompiler::AutotunerAndPostCleanup path. PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc to satisfy LayoutNormalization::HandleReshape's ShapeUtil::ReshapeIsBitcast precondition after GemmRewriter pins __cublas\$lt\$matmul output layouts to {n-1,...,1,0}. Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") introduced a third LayoutNormalization invocation inside the new AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the RET_CHECK whenever the autotuner fissions __triton_gemm back to __cublas\$lt\$matmul. The pattern is structurally identical on CUDA -- same passes, same canonical layout pin -- and is masked there only because cuDNN handles attention and Triton more reliably wins the autotune. Add the missing ReshapeDecomposer call, mirroring the inline pattern at the three pre-existing sites. The HandleReshape TF_RET_CHECK gains a stream-style explanation pointing at the canonical pattern, so any future caller that drops ReshapeDecomposer self-explains. Drop one unused <utility> include in layout_normalization.h flagged by misc-include-cleaner. A new backend-agnostic regression test (ReshapeDecomposerThenLayoutNormalizationHandlesNonBitcastReshape) constructs the minimal jit_dot_general extracted from jax.nn.dot_product_attention(impl='xla') and asserts the canonical ReshapeDecomposer; LayoutNormalization pipeline handles it. Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See PR #817's cuBLAS-LT-first patch + XLA_FLAGS=--xla_gpu_autotune_level=0 for a deterministic single-process reproducer. Co-authored-by: magaonka <magaonka@amd.com>
…tunerAndPostCleanup Restores LayoutNormalization::HandleReshape's ShapeUtil::ReshapeIsBitcast precondition on the GpuCompiler::AutotunerAndPostCleanup path. PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc to satisfy that precondition after GemmRewriter pins __cublas$lt$matmul output layouts to {n-1,...,1,0}. Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") introduced a third LayoutNormalization invocation inside the new AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the RET_CHECK whenever the autotuner fissions __triton_gemm back to __cublas$lt$matmul. Add the missing ReshapeDecomposer call, mirroring the inline pattern at the three pre-existing sites. AutotunerAndPostCleanup is the autotuner's boundary in the main pipeline; FissionBackend::ApplyConfig + InlineFissionedComputation can re-introduce a non-bitcast reshape there when the hipBLASLt candidate wins, and only that LayoutNormalization invocation is downstream of the layout shift. The fix preserves --xla_disable_hlo_passes=reshape-decomposer (no encapsulation) and stays in a single named pipeline (no HLO-dump pollution). Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See PR #817's cuBLAS-LT-first patch + XLA_FLAGS=--xla_gpu_autotune_level=0 for a deterministic single-process reproducer. Co-authored-by: magaonka <magaonka@amd.com>
…malization in AutotunerAndPostCleanup Imported from GitHub PR openxla#41980 📝 Summary of Changes Add the missing pipeline.AddPass<ReshapeDecomposer>() before LayoutNormalization in GpuCompiler::AutotunerAndPostCleanup — the third call site introduced by xla 6874dd2 without ReshapeDecomposer, mirroring the inline pattern at the three pre-existing sites. One-line change. 🎯 Justification openxla#41481 added ReshapeDecomposer ahead of two LayoutNormalization sites to prevent a RET_CHECK when GemmRewriter pins a __cublas$lt$matmul output and a downstream reshape becomes non-bitcast. xla 6874dd2 added a third LayoutNormalization site without the precondition and reopened the bug, most visibly on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py). AutotunerAndPostCleanup is the autotuner's boundary in the main pipeline: FissionBackend::ApplyConfig + InlineFissionedComputation can re-introduce a non-bitcast reshape there when the hipBLASLt candidate wins. Only that LayoutNormalization invocation is downstream of the layout shift, so the cleanup belongs at exactly that one site. The fix preserves --xla_disable_hlo_passes=reshape-decomposer (no encapsulation) and stays in a single named pipeline (no HLO-dump pollution). 🚀 Kind of Contribution 🐛 Bug Fix 🧪 Execution Test xla/tools/run_hlo_module on the jit_dot_general HLO from jax.nn.dot_product_attention(impl='xla'), using the cuBLAS-LT-first hack from #817 + XLA_FLAGS=--xla_gpu_autotune_level=0 to deterministically force the failing path on a single MI355X gfx950. Copybara import of the project: -- 94b1d0c by Ruturaj4 <Ruturaj.vaidya@amd.com>: [XLA:GPU] Insert ReshapeDecomposer before LayoutNormalization in AutotunerAndPostCleanup Restores LayoutNormalization::HandleReshape's ShapeUtil::ReshapeIsBitcast precondition on the GpuCompiler::AutotunerAndPostCleanup path. PR openxla#41481 (commit 788f269) added ReshapeDecomposer before two LayoutNormalization invocations in xla/service/gpu/gpu_compiler.cc to satisfy that precondition after GemmRewriter pins __cublas$lt$matmul output layouts to {n-1,...,1,0}. Commit 6874dd2 ("Move gemm-conv autotuner pass after fusion passes behind a flag.") introduced a third LayoutNormalization invocation inside the new AutotunerAndPostCleanup helper without a matching ReshapeDecomposer, reopening the RET_CHECK whenever the autotuner fissions __triton_gemm back to __cublas$lt$matmul. Add the missing ReshapeDecomposer call, mirroring the inline pattern at the three pre-existing sites. AutotunerAndPostCleanup is the autotuner's boundary in the main pipeline; FissionBackend::ApplyConfig + InlineFissionedComputation can re-introduce a non-bitcast reshape there when the hipBLASLt candidate wins, and only that LayoutNormalization invocation is downstream of the layout shift. The fix preserves `--xla_disable_hlo_passes=reshape-decomposer` (no encapsulation) and stays in a single named pipeline (no HLO-dump pollution). Reproduces the original 41481 failure mode on ROCm under multi-process pytest (tests/nn_test.py::testDotProductAttention*, tests/pallas/gpu_paged_attention_test.py) once the autotuner picks the hipBLASLt fission backend. See PR #817's cuBLAS-LT-first patch + `XLA_FLAGS=--xla_gpu_autotune_level=0` for a deterministic single-process reproducer. Co-authored-by: magaonka <magaonka@amd.com> Merging this change closes openxla#41980 COPYBARA_INTEGRATE_REVIEW=openxla#41980 from ROCm:fix-encapsulate-reshape-decomposer-in-layout-normalization 94b1d0c PiperOrigin-RevId: 911205496
Commit history for google/XNNPACK (bccfe733 -> ace56b61): - b3e7d5a1 mohammadmseet-hue: Fix stack buffer overflows in NCHW reduce rewrite and ynnpack shim functions - 1dfa3da5 mohammadmseet-hue: Address review: return error instead of clamping, use YNN_LOG_ERROR - 0cd97f2f Ken Unger: add rvv support for f16-vcmul - 1f8d093a Ken Unger: add rvv support for f16-vcmul - d4474df4 velonica0: rvv-f16-activation - e36eb021 Ken Unger: add fp16 rvv kernels for vsin,vcos,vexp - d3121ea9 velonica0: [RVV] add rvv f32 kernels for velu, vgelu, vapproxgelu - ae231328 velonica0: Alphabetize RVV elementwise entries in cmake/bzl lists - 0b6f61af velonica0: fix cmake bug - 803079cd Gregory Comer: Add AVX512 f32<->bf16 vcvt kernels - 80303ee9 Gregory Comer: Add native AVX512_BF16 f32->bf16 vcvt kernel - b0a078ab Volodymyr Kysenko: Add benchmarks for int4x2/int2x4 to int8_t conversions. - aca142ca Dillon Sharlet: Decide whether to constant fold heuristically - 172b7f4a Marie White: Add arm_neonbf16 binary kernels - b7ebadee MarkLee131: Reject qpint8 in xnn_define_dynamically_quantized_tensor_value - 524c06d2 MarkLee131: Detect size_t overflow in get_tensor_size and reject the tensor - 0d406ba6 Volodymyr Kysenko: Add 2-bit and 4-bit interleave kernels. - 3b773ce1 Quentin Khan: Don't produce an op when `Cast` is casting to the same type as the input type. - 16acef83 Volodymyr Kysenko: Fix typo in the comment. - c8ec3d63 Misha Gutman: Added dynamic_b support for qdu8-f32-qc8w operator. - 1c7554a8 Dillon Sharlet: Remove unnecessary assert - bf188a3d Nicolas Pitre: Add Zephyr RTOS (Generic) platform support - 55d4036f Frank Barchard: Increase tolerance for SUBGRAPH_FP16.fully_connected_qd8_f16_qc8w test to account for numerical deviation. - 8e06513a XNNPACK Team: Merge pull request openxla#10060 from npitre:zephyr-support-pr - 6a9a9c50 Dillon Sharlet: Disable AMX kernels if msan is enabled - a9cb6cb9 XNNPACK Team: Merge pull request openxla#10023 from GregoryComer:bf16-f32-vcvt-avx512 - cde7d935 Marie White: Improve tile sizes for arm_neonbf16 kernels. Tuned with AI agents. - fadd2dbe XNNPACK Team: Merge pull request openxla#9986 from velonica0:rvv-f16-elementwise - 4ca5fb8d Dillon: Merge branch 'master' into f16-unary-trig-rvv - b2985573 Quentin Khan: Add wrappers for storage type of 2/4 bit int and 16 bit floats. - 9c70eb91 Quentin Khan: Add reverse data type to native type mapping. - 2c52c9f Quentin Khan: Add a conversion function to be able to specialize buffer copy from a sequence. - c817561c Quentin Khan: Move declaration of `NativeStorage` and clarify comment of `StorageImpl`. - 7b375800 MarkLee131: Clarify qpint8 rejection wording - 1437d94b MarkLee131: Use xnn_safe_mul/xnn_safe_add in get_tensor_size - f074438f MarkLee131: Split xnn_safe_mul/xnn_safe_add into separate statements - 64081049 Dillon Sharlet: Resubmit openxla#10069 - 56496fd6 Dillon Sharlet: Add int32 sum kernels - bbc68d90 XNNPACK Team: Merge pull request openxla#9963 from velonica0:rvv-elementwise - 51759bd4 XNNPACK Team: Merge pull request openxla#10102 from MarkLee131:fix/integer-overflow-tensor-size - 562e5274 Dillon Sharlet: Refactor `make_schedule` to allow building just the loop splits, and not a whole `scheduling_info` - 8e4e9d5b Dillon Sharlet: Change reduce to make the identity buffer in slinky, instead of in the subgraph - 64d21ff8 Ken Unger: handle unconfigured f16-vcmul kernel - 834051a2 XNNPACK Team: Merge pull request openxla#10101 from MarkLee131:fix/qpint8-null-deref - b3a5d44f Ken Unger: merge master - 8b3bda45 Ken Unger: update-microkernels - e2da1edb Frank Barchard: Add f16_wasmrelaxedsimd SIMD headers - 5aa5d64e Quentin Khan: Add a parallel lib to `utils:matchers` for internal targets that are only compiled with OSS. - b2f46c0c Quentin Khan: Add a matcher to to check whether two graph are isomorphic. - f81e3eda Volodymyr Kysenko: Support channelwise zero points in YNNPACK quantized dot products. - 4a318ee8 Frank Barchard: Add portable SIMD template for f16-vsqrt - 4780ab70 Frank Barchard: Run generator to create rvv kernels - 3659dcf2 Ken Unger: merge master - f589c63c Jonathan Clohessy: Update CMakeLists.txt to match SME defaults from bazel - 6833e630 Dillon: Merge branch 'master' into f16-unary-trig-rvv - 26c61a7a XNNPACK Team: Merge pull request openxla#9989 from ken-unger:f16-unary-trig-rvv - b0328fc2 Frank Barchard: Fix WAsm typo in XNNPACK by renaming to Wasm - 04b67752 Dillon Sharlet: Refactor tolerance calculations - 8c2df4d5 Dillon Sharlet: Parallelize reductions in YNNPACK - e9de2685 Dillon Sharlet: Add reference kernels for fp64 elementwise ops - a9390e5a Dillon Sharlet: Fix hexagon build - a3da013b Dillon Sharlet: Add benchmark coverage of reference fp64 elementwise ops - 8a3902dd Dillon Sharlet: Add optimized kernels for fp64 elementwise ops - c8c86398 Dillon Sharlet: Add fp64 fma rules to elementwise compiler - 807d9f9c Alexander Shaposhnikov: Introduce XNN_NO_SANITIZE_FUNCTION macro. - 8e406b86 Dillon Sharlet: Loosen tolerances for dequantize_dot test - bb6c6a48 Misha Gutman: Added convert from qint8 to qcint8. - 689c5c60 Misha Gutman: Removed convert qint8 to qcint8 tests from ynnpack test set. - a3664b21 Dillon Sharlet: Avoid capturing kernel in reduce ops - 0fc9e7e7 Volodymyr Kysenko: Disable subgraph_matcher_test when use_ynnpack is enabled. - b5bc455b Dillon Sharlet: Enable adding and removing dimensions via static_transpose - 6dfbf304 Frank Barchard: Optimize xnn_round_f32 for Hexagon HVX. - 52d94589 XNNPACK Team: Merge pull request openxla#9851 from mohammadmseet-hue:fix/nchw-reduce-overflow-and-shim-bounds - 6bd50499 Misha Gutman: Fixed the crash due to unaligned read. - 94ce3bb6 Volodymyr Kysenko: Refactor extent handling in YNNPACK subgraph. - cceae52c Dillon Sharlet: Always constant fold pack_b ops - e0729a7c Dillon Sharlet: Add assert to catch infinite loop case - 50b01640 Frank Barchard: Fix Hexagon HVX build failure 'sf type used as qf32' on Clang 19 - b3daaef9 Dillon Sharlet: Enable sum(squared(x)) => sum_squared(x) for fp64 - 8f17e0c0 Dillon Sharlet: Relax tolerances of dequantize_dot more - a493bbeb Dillon Sharlet: Add missing benchmark - 95103d5b Frank Barchard: Enable f16 vsqrt wasmrelaxedsimd kernel and scalar fallbacks - d830cd16 Volodymyr Kysenko: Rewrite reduce(static_transpose(x)) into reduce(x) - 7b1bde34 Dillon Sharlet: Remove ternary multiply for purely float types - a571a74b Dillon Sharlet: Add tolerance for quantized int8 operations that may round differently - 58698bd6 Dillon Sharlet: Add `exp2_round` simd helper - ece55c6e Dillon Sharlet: Add rewrite for `sum(a*b)` => `dot(a, b)` where appropriate - fc7f8975 Frederic Rechtenstein: Fix alignment-related crash on AVX512 - 58a233a4 XNNPACK Team: Merge pull request openxla#10167 from JonathanC-ARM:jonclo01/sync_bazel_cmake_defaults - 778408a8 Dillon Sharlet: Add exp_fp64 kernels - 16c63a38 Volodymyr Kysenko: Add benchmarks for fully connected with QC4W and QC2W weights. - 62f1d600 Misha Gutman: Added rewrite `bmm(a:f32, dequant(b:qint8):f32) -> f32` into - f6cf463c Volodymyr Kysenko: Disable BatchMatrixMultiplyDequantBmmRewrite test under ynnpack. - 5660b4b0 Dillon Sharlet: Implement `static_expand_dims` using `static_transpose` - 11e206b8 XNNPACK Team: Implement `static_expand_dims` using `static_transpose` - 8e4e78fd Quentin Khan: Don't use `graph::Tensor` in the XNNPack lowering interface. - b12ed13b Quentin Khan: Fix memory outdated planning optimization invalidated by reshapes. - c3d8c276 Misha Gutman: Disabled bmm rewrite by default as gemma4 fails precision. - fb152529 Volodymyr Kysenko: Rename QD8F32QC8W benchmark to QD8F32QC8WFullyConnected for consistency. - d8f5abe9 Dillon Sharlet: Rename svcnt => svcnts - 445e613a Dillon Sharlet: Fix spurious debug messages about sum(a*b) -> dot(a, b) rewrites - 48e1d0f0 Dillon Sharlet: Add test coverage of static and dynamic shapes - be45bb35 Dillon Sharlet: Add more test coverage for reduce operators - d48bc34c Dillon Sharlet: Add support for rewriting `sum(a*b, init_c)` => `dot(a, b, init_c)` - 4908d191 Marie White: Fix get_dot_kernel type bug - 84aa6a95 Dillon Sharlet: Move gemm, conv shapes hardcoded in benchmarks to text files - c5c413de Richard Townsend: [gn] Update DEPS - d877e1a1 Dillon Sharlet: Fix warning "unexpected tokens following preprocessor directive - expected a newline" - dbf04022 Volodymyr Kysenko: Fix handling of sub-byte types in packer. - 5039d217 Dillon Sharlet: Fix unsimplified slice extents - 6c8ac561 Frank Barchard: F16-VTANH for avx512, wasm and scalar - 3245ce20 Frank Barchard: Enable f16 vsin and vcos wasmrelaxedsimd kernel and scalar fallbacks - 091b9be6 XNNPACK Team: Enable f16 vsin and vcos wasmrelaxedsimd kernel and scalar fallbacks - 99e4485d Dillon Sharlet: Add `horizontal_sum` for floating point types - f919d369 Quentin Khan: Don't call optimize in fp16 rewrite tests. - c723a993 Quentin Khan: Prepare static_reduce test for upcoming fp16 to fp32 rewrite. - 0a27dcf1 Frank Barchard: Enable f16 vsin and vcos wasmrelaxedsimd kernel and scalar fallbacks - f43db489 Dillon Sharlet: Fix loss of precision for fp64 constants - 6e50ae9f Dillon Sharlet: Fix reshape -> slice pattern - 74daa88a Dillon Sharlet: Use internal define_static_expand_dims in define_dot - 28ef957f Dillon Sharlet: Disable sum(a*b) => dot(a, b) rewrite if there are no broadcast dimensions on either side - 016914cb Richard Townsend: [gn] Add pthreadpool for the Chromium config - 25d15607 Volodymyr Kysenko: Fix store in the tail of transpose kernels for sub-byte types. - 5d007c4c Volodymyr Kysenko: Make reference int2/int4 convert work with unaligned n. - 713c3b72 Dillon Sharlet: Require reshape strides to be the shape we need too - 2e6e343b Dillon Sharlet: Rewrite reduce kernels to optimize for numerical behavior - 73c5abb5 Marie White: Fix bug in `get_max_concurrency`. - 1dbb15fc Marie White: Fix fully-connected DynamicB tests to work with QP8. - cf96f77e Marie White: Fix fully-connected DynamicB tests to work with QP8. - 7829cd69 Quentin Khan: Move row sum rewrite to after other optimization rewrites. - 2d16035f Dillon Sharlet: Fix bugs with reduce fusion - 0b66c9f1 Dillon Sharlet: Fix slice bugs - 768003bd Marie White: Fix rank pollution in channelwise quantized scales for YNNPACK. - 26f5c9e1 Marie White: Fix logical extent calculation during constant folding for sub-byte types. - bb971d4 Dillon Sharlet: Refactor the implementation of `remove_static_broadcast_from_elementwise` - 860a6421 XNNPACK Team: Fix rank pollution in channelwise quantized scales for YNNPACK. - f3513194 Frank Barchard: Add rules for updating copyright for new files and removing trailing spaces on blank lines - 5dba5dad Dillon Sharlet: Improve static_slice test coverage - f569d17b Ken Unger: merge master - 12f71cd4 XNNPACK Team: Merge pull request openxla#9971 from ken-unger:f16-vcmul-rvv - 2466b8c2 Dillon Sharlet: Update deps to get bug fixes - cc278f5c Dillon Sharlet: Add support for strides to static_slice - 53007d69 Dillon Sharlet: Add YNN_FLAG_NO_EXCESS_PRECISION - fe166973 Dillon Sharlet: Disable static_slice test until slinky bug is fixed - 4fad5b39 Dillon Sharlet: Disable static_slice test until slinky bug is fixed - d72fa85c Dillon Sharlet: Improve log_fp32 kernels - 95ee916a Dillon Sharlet: Use a better unroll factor for log2_fp32_sse2 - 9ab80cd6 Volodymyr Kysenko: Allow adding function own loops even if some of its non-trivial loops has been already fused. - 11fb8859 Dillon Sharlet: Implement round to nearest even for float -> bf16 conversions - 49e266f7 Volodymyr Kysenko: Add optimized convert int2/int4 to int8 kernels. - ace56b61 Dillon Sharlet: Improve `exp` kernel accuracy and correctness - 34c80155 Volodymyr Kysenko: Make sure partial reduction splits match the loop step. - 7bf9c692 Frank Barchard: Fix ambiguous std::isfinite, std::abs, and std::fpclassify calls for _Float16 in test framework by explicitly casting to float. - c3ac56a5 Quentin Khan: Add subgraph matcher target to `BUILD.gn`. - 1c292bfc Richard Townsend: [gn] Test building AVX512 - 8da42ae2 Gerardo Carranza: Add support for log fp16 in XNNPACK. - 1052f90b Richard Townsend: [gn] Add support for building/testing AArch32 - 01db6e14 Dillon Sharlet: Fix possible infinite recursion in convert - f1fe9b5c Dillon Sharlet: Only rewrite reduce(convert(x)) if we have a kernel for that reduction type. - 98c8ded4 Dillon Sharlet: Polynomial approximation improvements for `exp` and `log` Commit history for dsharlet/slinky (1032be67 -> eb004cb3): - 63c773f3 Dillon: Simplify `make_buffer` with new broadcast dimensions to `transpose` (#802) - 66efc5ef Dillon: Fix `can_fuse` for broadcast dimensions (#803) - 6fcfed78 Dillon: Fix more instances of `fold_factor` that should have been changed to `stride` after #802 (#806) - 2af0a012 Dillon: Remove unnecessary branches for the rank of buffers when accessing dims (#807) - 70b443b7 Dillon: Add fast path to `for_each_element` for rank 0 buffers (#805) - dea32175 Dillon: Remove extent 1 dimensions in `optimize_dims` (#797) - 7e02995b Dillon: Fix out of bounds vector access when simplifying nested transpose ops (#808) - 9140d8ac Dillon: Change drop-loops to keep the loop but rewrite the extent (#809) - f3ab7b63 Dillon: Fix aliases that use buffer bounds before they are defined (#810) - 7bc45e1f Dillon: Add support for `slice_buffer`, `slice_dim`, and `transpose` in `alias_copies` (#811) - 0335d87e Dillon: Cast object instead of function pointer (#812) - 27f5d9d9 Dillon: Fix externally defined fold factors (#813) - c08ef409 Dillon: Fix copy aliasing for copies that remove dimensions (#814) - 284794e8 Dillon: Fix bugs uncovered by copying from a rank > 0 buffer to a scalar (#815) - 56f8638a Dillon: Fix crop simplification bug (#816) - 0fbea044 Dillon: Fix simplify of nested transposes (#817) - c01931be Dillon: Fix a straggler usage of `op->dims` => `dims` (#818) - eb004cb3 Dillon: Fix strided copies (#819) PiperOrigin-RevId: 917405521
================== NOT MEANT FOR MERGING YET — DISCUSSION PR ==================
TL;DR ( if you dont want to read the whole story )
jax.nn.dot_product_attention(impl='xla')intermittently dies with aRET_CHECKatlayout_normalization.cc:431on ROCm under multi-GPU pytest. Root cause: when the autotuner picks hipBLASLt over Triton for a dot,GemmRewriterpins the__cublas$lt$matmuloutput to canonical{n-1,...,1,0}, but layout assignment had already planned thedownstream reshape for a non-canonical layout, the reshape is no longer a bitcast, and
LayoutNormalizationaborts. The flap is rare because hipBLASLt is normally slower than Triton, but contention inflates the autotuner's single-shot Triton timing while hipBLASLt'smin(128 algos)stays stable.The fix is two
pipeline.AddPass<ReshapeDecomposer>()calls ingpu_compiler.cc.Full details
Hey folks, I'm currently debugging an interesting JAX unit test issue where
tests/nn_test.py::NNFunctionsTest::testDotProductAttention*andpallas/gpu_paged_attention_test.py::*test_quantized_paged_attention*intermittently fail in pytest with:
INTERNAL: RET_CHECK failure (external/xla/xla/service/layout_normalization.cc:431) ShapeUtil::ReshapeIsBitcast(s, operand->shape())This is more likely to appear if we make pytest to see multiple GPUs. But it can be reproduced in single gpu run aswell. But it is not very easy to reproduce it needs certain condition to be met during autotune.
Below is full details:
The triggering JAX call is the standard
jax.nn.dot_product_attention(impl='xla')path. JAX traces a Grouped Query Attention computation which fissions into a
jit_dot_generalHLO module containing a single batcheddot_general. JAX side it seems alright to me.Narrowing down to a minimal HLO
What I did is first tried to turn this flaky test into semi deterministic fail so that I can debug easily and I narrowed it down to below HLO
HloModule jit_dot_general, entry_computation_layout={(f16[2,1,32,4,128]{4,3,2,1,0}, f16[4,2,1,128,128]{4,3,2,1,0})->f16[2,1,32,128]{3,2,1,0}} ENTRY %main.1 (args_0_.1: f16[2,1,32,4,128], args_1_.1: f16[4,2,1,128,128]) -> f16[2,1,32,128] { %args_0_.1 = f16[2,1,32,4,128]{4,3,2,1,0} parameter(0) %args_1_.1 = f16[4,2,1,128,128]{4,3,2,1,0} parameter(1) ROOT %dot_general.1 = f16[2,1,32,128]{3,2,1,0} dot(%args_0_.1, %args_1_.1), lhs_batch_dims={0,1}, lhs_contracting_dims={3,4}, rhs_batch_dims={1,2}, rhs_contracting_dims={0,3} }just running this HLO once won't reproduce this issue you need to run this under parallel load only then in rare cases autotuner decides hipBLASLT is the winner based on profiling time and chooses it as backend. hipBLASLT has limitation.
Making the bug deterministic for debugging
I have hacked XLA to simulate same scenario and it can reproduce the bug reliably without needing to run it mutliple time.
you need to apply below patch :
Then run below command:
Why this layout mismatch is problematic
The catch with hipBLASLt is that it only knows how to write its output in one shape i,e canonical row-major.
When XLA's
GemmRewriterrewrites akDotinto a__cublas$lt$matmulcall, it has to pin the custom-call'soutput layout to
{n-1, n-2, ..., 1, 0}to match what the library expects. The compiler doesn't get a vote.The problem is that layout assignment ran much earlier, and at that point it was free to choose any layout for the dot's output. It picked one that matched what the rest of the graph wanted: a downstream
bitcastfurther along the chain expected to read its input in a non-canonical{2,3,1,0}layout, and layout assignment propagated that constraint backward through areshapeso that reshape would just be a free relabel.Now hipBLASLt swaps in and forces the dot's output to
{2,1,0}. The reshape suddenly has an operand laid out one way and an output expected the other way. It's no longer a free relabel to actually go from onelayout to the other you'd have to physically move bytes around. The next pass,
LayoutNormalization, walks everykReshapeand checks exactly this withShapeUtil::ReshapeIsBitcast(...). It getsfalseback, realises the invariant is broken, and fires theRET_CHECK.Concretely, here's the failing HLO snippet we captured (one pass before
the crash):
Why does the autotuner sometimes pick hipBLASLt over Triton?
Now we know the failure next part is to figure why we land in this scenario? why autotuner sometime picks hipBLASLT over Triton path:
For each dot, the autotuner asks every registered backend for its candidate list via
GetSupportedConfigs(instr). For our shape, the Triton backend returns 1 candidate and the HipblasLt backend returns 128 candidates (one per algorithm enumerated byhipblasLtMatmulAlgoGetHeuristic, capped at the XLA constantGemmConfig::kNumAlgorithms=128). The autotuner then profiles all 129 candidates with 1 warmup + 1 timed run each and picks the one with minimum measured duration.Generally this is how timing looks: we can clearly see Triton wins here
but under parallel workers sometime this flips like below: where hipBLASLT timing remain as it is but triton shoots up. Here comparison is unfair for triton where it gets single warmup and single timing while hipBLASLT gets 128 runs and we select min(128 runs time).
I'm taking extreme parallel case of 64 processes running the hlo to highlight the bug
Is the Triton kernel actually slower under load?
To measure this I fired up
rocprofv3and try to compare timing info in the autotunerFrom the table it is clear that:
The actual Triton kernel runs in 1.2-7.7 µs every single time, regardless of contention. The "inflation" the autotuner sees is entirely in the wall-clock window between
hipEventRecord(start)andhipEventRecord(stop)that XLA's profiler uses for timing. That window also catches whatever work the GPU did for OTHER processes between processing the start and stopevents, and
hipEventElapsedTimereports it all as if it were our kernel runtime. The kernel is queued and waiting at the front of its stream while neighbor work gets time-sliced ahead of it. But I'm unable to prove this because Tracing itself adds enough serialization that the 50-100 µs spikes only show up without a profiler attached. I think it is some sort of Schrodinger cat scenario.The fix
Inserts a
ReshapeDecomposerpass before each post-GemmRewriterLayoutNormalizationingpu_compiler.cc. After GemmRewriter pins a__cublas$lt$matmul(hipBLASLt on ROCm) output layout, any downstream non-bitcast reshape is decomposed intotranspose + bitcast, restoring theShapeUtil::ReshapeIsBitcastinvariant thatLayoutNormalizationasserts and was crashing on under autotune-driven backend flips.
Some questions from me:
fix I'm proposing is reasonable? my rationale is even though hipBLASLT is slow it is still valid backend ( but with limitation) so XLA should be able to handle it and it should not fail with RET_CHECK(). very minimum let the autotuner take slower path and complete execution.
Should
GpuLayoutAssignmentbe the real fix? A more thorough fix would be inxla/backends/gpu/transforms/layout_assignment.cc: whenAddBackendConstraintsprocesses akDotthat could be rewritten to__cublas$lt$matmul(any Triton-eligible dot, since the autotuner may later pick the FissionBackend(HipblasLt) candidate), pre-constrain the dot's output layout to{n-1,...,1,0}so the downstream consumer chain plans around it from the start. This is more invasive (it would affect every dot that could go cuBLAS-LT, not just the ones that actually do, potentially addingkCopyoperations in the Triton-wins path).Should the autotuner measurement methodology be improved?
The root flakiness is that
GpuProfiler::Profiledoes1 warmup + 1 timed runper candidate usinghipEventElapsedTime,which captures wait time from neighbor processes' work in multi-tenant GPU scenarios (the "exclusive lock" in
GpuProfiler::Executeis intra-process only,GetGpuMutexis astatic absl::Mutexper(Platform, device_ordinal), not anIPC lock). HipblasLt happens to be robust against this because its
min(128 algo)is order-statistics-stable; but triton suffers itbecause of lack of potential candidate configs