[XLA:GPU] Fix num_warps in ROCm default Triton config for 16x16 tile #853

Open
phambinhfin wants to merge 1 commit into main from phambinh/fix-rocm-triton-cost-model-default-configs

@phambinhfin

Fixes the following three tests on ROCm (verified on MI350X / gfx950, applicable to MI200/MI300 as well):

TritonBackendTest.CostModelOptions_TopFromDefault
TritonBackendTest.CostModelOptions_Filter
TritonBackendTest.CostModelOptions_Combination

Background

TritonDotFusionSearchSpace::OptimizeConfigSet filters the exhaustive search space against the hints from rocm.txtpb using exact TritonGemmConfig equality (only block_m / block_n / block_k are clamped; num_warps / num_stages / waves_per_eu / ... are not). Therefore each hint must agree with what the search space actually generates for the declared tile shape, or it is silently dropped.
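
To make the exact-equality filtering concrete, here is a minimal sketch. `Config`, `Contains`, and `CountSurvivingHints` are hypothetical stand-ins, not the real `TritonGemmConfig` or `OptimizeConfigSet` (block-size clamping is omitted and the real type has more fields), and every numeric value other than the 16x16 tile and its `num_warps` is illustrative only.

```cpp
#include <array>
#include <cstddef>

// Hypothetical, simplified stand-in for TritonGemmConfig.
struct Config {
  int block_m;
  int block_n;
  int block_k;
  int num_warps;
  int num_stages;
};

// Exact equality: every field must match for a hint to be kept.
constexpr bool operator==(const Config& a, const Config& b) {
  return a.block_m == b.block_m && a.block_n == b.block_n &&
         a.block_k == b.block_k && a.num_warps == b.num_warps &&
         a.num_stages == b.num_stages;
}

// Returns true if `c` appears verbatim in `set`.
template <std::size_t N>
constexpr bool Contains(const std::array<Config, N>& set, const Config& c) {
  for (std::size_t i = 0; i < N; ++i) {
    if (set[i] == c) return true;
  }
  return false;
}

// Counts the hints that have an exact match in the search space.
template <std::size_t H, std::size_t S>
constexpr int CountSurvivingHints(const std::array<Config, H>& hints,
                                  const std::array<Config, S>& space) {
  int kept = 0;
  for (std::size_t i = 0; i < H; ++i) {
    if (Contains(space, hints[i])) ++kept;
  }
  return kept;
}

// Assumed values: the search space only offers num_warps=2 for the 16x16
// tile, while the first hint asks for num_warps=4.
constexpr std::array<Config, 2> kSearchSpace = {
    {{16, 16, 64, 2, 1}, {64, 64, 32, 4, 1}}};
constexpr std::array<Config, 2> kHints = {
    {{16, 16, 64, 4, 1}, {64, 64, 32, 4, 1}}};

// Only the second hint survives; the 16x16 hint is dropped without an error.
static_assert(CountSurvivingHints(kHints, kSearchSpace) == 1,
              "the mismatched 16x16 hint is filtered out");
```

In this sketch a hint that differs from the generated config in a single field, here `num_warps`, is indistinguishable from a hint for a tile that was never generated, which is why the drop is silent.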

According to xla/backends/gpu/autotuner/triton/dot_search_space.cc:

int TritonDotFusionSearchSpace::GetMaxWarpsPerCta(OutputTile tile) const {
  // A single mma instruction is of output shape at least 16x8 (the same
  // also holds for wgmma: the warp-group level instruction is at least
  // 64x8, and split 4-ways across the 4 warps in the group).
  constexpr OutputTile kMmaSubTile = {16, 8};
  const int max_warps =
      device_description_.threads_per_block_limit() /
      std::max<int>(device_description_.threads_per_warp(), 1);
  const int lhs_warps = CeilOfRatio(tile.lhs_dim, kMmaSubTile.lhs_dim);
  const int rhs_warps = CeilOfRatio(tile.rhs_dim, kMmaSubTile.rhs_dim);
  return std::max(min_warps_per_cta_,
                  std::min(max_warps, lhs_warps * rhs_warps));
}

For a 16x16 output tile this evaluates (on every AMD GPU, since threads_per_block_limit=1024 and threads_per_warp=64 are identical across gfx90a / gfx942 / gfx950) to:

lhs_warps        = ceil(16 / 16) = 1
rhs_warps        = ceil(16 / 8)  = 2
max_warps        = 1024 / 64     = 16
GetMaxWarpsPerCta = max(min_warps_per_cta_=2,
                        min(16, 1*2))      = 2

AddCtaSizeParameter therefore only emits num_warps=2 for a 16x16 tile. The previous hint asked for num_warps=4, so it had no matching entry in the exhaustive set and was filtered out by OptimizeConfigSet, leaving the optimized default set with a single config. The three CostModelOptions_* tests assume the default set yields >=2 estimable configs and so they all fail:
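
The arithmetic above can be checked with a standalone restatement of the quoted formula. `MaxWarpsPerCta` below is a hypothetical free function that takes the device limits as plain arguments instead of reading them from `device_description_`; it is a sketch for verification, not the XLA implementation.

```cpp
#include <algorithm>

// Integer ceiling division for positive operands.
constexpr int CeilOfRatio(int a, int b) { return (a + b - 1) / b; }

// Restates the GetMaxWarpsPerCta formula with the {16, 8} mma sub-tile
// hard-coded and all device limits passed in explicitly.
constexpr int MaxWarpsPerCta(int tile_m, int tile_n,
                             int threads_per_block_limit,
                             int threads_per_warp, int min_warps_per_cta) {
  const int max_warps =
      threads_per_block_limit / std::max(threads_per_warp, 1);
  const int lhs_warps = CeilOfRatio(tile_m, 16);
  const int rhs_warps = CeilOfRatio(tile_n, 8);
  return std::max(min_warps_per_cta,
                  std::min(max_warps, lhs_warps * rhs_warps));
}

// AMD limits from the description: 1024 threads per block, 64 per warp,
// min_warps_per_cta_ = 2.
static_assert(MaxWarpsPerCta(16, 16, 1024, 64, 2) == 2,
              "a 16x16 tile is capped at 2 warps");
static_assert(MaxWarpsPerCta(32, 32, 1024, 64, 2) == 8,
              "a 32x32 tile is capped at 2 * 4 = 8 warps");
```

With these limits, any hint for a 16x16 tile with `num_warps` above 2 falls outside what the formula allows.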

  • CostModelOptions_TopFromDefault expects size 2 -> gets 1.
  • CostModelOptions_Filter expects new < default -> gets 1 vs 1.
  • CostModelOptions_Combination expects size 7 -> gets 6 (1 top +
    5 mixin).

Fix

Align the 16x16 hint with the search space by setting num_warps=2. Both hints now survive OptimizeConfigSet, the default set has two distinct configs again, and the three tests pass.

This is a regression introduced by commit bda66ea ("[XLA:GPU] Remove split_k > 1 from default configs") which reduced rocm.txtpb from 6 hints to 2 and so unmasked the latent mismatch in the 16x16 hint.

Verification

Executed inside the same docker image the ROCm/xla CI uses:
rocm/tensorflow-build:latest-jammy-pythonall-rocm7.2.1-ci_official
on AMD Instinct MI350X (gfx950) with ROCm 7.2.1:

bazel test --config=rocm --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx950 \
  --test_filter='TritonBackendTest.*' \
  //xla/backends/gpu/autotuner:triton_test

-> 19 PASSED, 7 SKIPPED (CUDA-only), 0 FAILED.

The 7 SKIPPED tests all hit GTEST_SKIP() << "Not supported on ROCm."; they cover Ampere / Hopper / Blackwell / TMA / warp-specialization features. triton_configs_test also passes.

@draganmladjenovic

Are we sure GetMaxWarpsPerCta is right for us anyway? I think we have MFMA instructions that go as small as 4x4. I'm not sure about WMMA, and I'm not sure which instructions Triton prefers to use on amdgpu.

@phambinhfin
Author

@draganmladjenovic Makes sense. I just realized that we could also modify dot_search_space.cc to take advantage of the AMD hardware limits. Let me check.

@phambinhfin phambinhfin force-pushed the phambinh/fix-rocm-triton-cost-model-default-configs branch from 92e9c6c to 709897d Compare May 11, 2026 14:23
…-aware on ROCm

GetMaxWarpsPerCta returns `min(threads_per_block_limit/threads_per_warp,
ceil(M/16) * ceil(N/8))`, where the `{16, 8}` sub-tile is the smallest
NVIDIA mma instruction (and the comment in the function makes that
intent explicit, also covering wgmma). This is correct on NVIDIA, but
two things break on AMD:

  1. The sub-tile shape doesn't match what Triton emits on AMD:
     - CDNA (MI200/MI300/MI350): Triton's amd_mfma encoding only ever
       uses instrShape = [16,16,*] or [32,32,*]; see
       chooseMfmaInstruction in triton-lang/triton third_party/amd
       (PRs openxla#5937 reworking the MFMA intrinsic map and openxla#8213 unifying
       mDim/nDim into instrShape). The smaller V_MFMA_*_4x4x*
       instructions exist in hardware (AMD "matrix cores" lab note:
       https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme/)
       but Triton does not select them.
     - RDNA3+: WMMA only supports 16x16x16 at all (AMD GPUOpen
       "WMMA on RDNA3":
       https://gpuopen.com/learn/wmma_on_rdna3, and the RDNA3 ISA
       reference). So the smallest per-wave sub-tile is 16x16, not
       16x8.
  2. More importantly, the underlying assumption "one warp per output
     sub-tile" doesn't hold on AMD. Extra warps per CTA are used by
     Triton's stream pipeliner to software-pipeline the K loop (see
     triton-lang/triton PR openxla#4148 "Introduce stream pipeliner v2", and
     ROCm "Optimizing Triton kernels" docs), not to tile the M*N
     output spatially. A spatial cap is therefore conceptually the
     wrong bound on AMD.

Concretely this was silently filtering valid ROCm hints. For a 16x16
output tile the formula yields `min(16, ceil(16/16) * ceil(16/8)) = 2`,
so the default ROCm hint `{block_m=16, block_n=16, num_warps=4}` was
clamped, failed the exact-match in OptimizeConfigSet, and disappeared
from the default config set -- which is what made the following three
tests fail on ROCm:

  TritonBackendTest.CostModelOptions_Combination
  TritonBackendTest.CostModelOptions_Filter
  TritonBackendTest.CostModelOptions_TopFromDefault

Fix: branch on `gpu_compute_capability().IsRocm()` (same pattern this
file already uses at line 145 for AddWavesPerEuParameter) and on ROCm
drop the spatial term, keeping only the hardware bound
`threads_per_block_limit / threads_per_warp`. The NVIDIA branch is
byte-identical to before. No changes needed to rocm.txtpb -- the
existing hints become valid under the corrected cap.

Verified on MI350 (gfx950), inside the rocm_ci docker image
`rocm/tensorflow-build:latest-jammy-pythonall-rocm7.2.1-ci_official`:
  * Full TritonBackendTest suite: 19 passed, 7 skipped (Hopper/Ampere/
    TMA/WarpSpecialization, correctly skipped on ROCm). The three
    CostModelOptions_* tests above all pass.
  * Full dot_search_space_test: 25/25 pass, including
    ConsidersFewWarpsPerCtaAndMmaForSmallProblem and
    EnsuresWgmmaShapeForLargeProblem (NVIDIA-side coverage of
    GetMaxWarpsPerCta).
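
The branch proposed in this commit message can be sketched as follows. `PlatformAwareMaxWarps` is a hypothetical free function with the platform passed as a `bool`; the real change branches on `gpu_compute_capability().IsRocm()` inside `GetMaxWarpsPerCta`. The NVIDIA-side inputs (32 threads per warp and a minimum of 2 warps) are assumptions chosen for illustration.

```cpp
#include <algorithm>

// Integer ceiling division for positive operands.
constexpr int CeilDiv(int a, int b) { return (a + b - 1) / b; }

// Sketch of the proposed cap: on ROCm only the hardware bound applies,
// elsewhere the spatial {16, 8} sub-tile term is kept.
constexpr int PlatformAwareMaxWarps(bool is_rocm, int tile_m, int tile_n,
                                    int threads_per_block_limit,
                                    int threads_per_warp,
                                    int min_warps_per_cta) {
  const int max_warps =
      threads_per_block_limit / std::max(threads_per_warp, 1);
  if (is_rocm) {
    // No per-output-tile cap: extra warps are not assumed to tile M*N.
    return std::max(min_warps_per_cta, max_warps);
  }
  const int spatial_warps = CeilDiv(tile_m, 16) * CeilDiv(tile_n, 8);
  return std::max(min_warps_per_cta, std::min(max_warps, spatial_warps));
}

// ROCm, 16x16 tile: the cap rises from 2 to the hardware bound of 16, so a
// num_warps=4 hint is within range.
static_assert(PlatformAwareMaxWarps(true, 16, 16, 1024, 64, 2) == 16,
              "ROCm uses the hardware bound only");
// Non-ROCm path (assumed inputs): the spatial cap still yields 2.
static_assert(PlatformAwareMaxWarps(false, 16, 16, 1024, 32, 2) == 2,
              "the spatial cap is unchanged off ROCm");
```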
@phambinhfin phambinhfin force-pushed the phambinh/fix-rocm-triton-cost-model-default-configs branch from 709897d to c628673 Compare May 11, 2026 14:30
@i-chaochen i-chaochen requested a review from nurmukhametov May 12, 2026 08:34
Collaborator

@i-chaochen i-chaochen left a comment

LGTM

// On AMD, Triton lowers tl.dot to MFMA on CDNA or WMMA on RDNA3+. In
// Triton's amd_mfma encoding only 16x16x* and 32x32x* instrShapes are
// emitted (see triton-lang/triton third_party/amd's chooseMfmaInstruction
// and PRs #5937 / #8213); the smaller V_MFMA_*_4x4x* exists in hardware
Collaborator

It seems your commit message is incorrect:

chooseMfmaInstruction in triton-lang/triton third_party/amd
(PRs openxla#5937 reworking the MFMA intrinsic map and openxla#8213

The references should be triton-lang/triton#5937 and triton-lang/triton#8213.

// "Optimizing Triton kernels"), not to tile the M*N output spatially.
// So we don't impose a per-output-tile cap here -- only the hardware
// bound (threads_per_block_limit / threads_per_warp) applies.
return std::max(min_warps_per_cta_, max_warps);

Wouldn't this make exhaustive tuning unusable? Are we sure there isn't some other upper cap, such as num_stages * warp size?

NVM. It caps it at 16.

@draganmladjenovic

I just noticed that this whole file is NVIDIA-specific. Someone should look through it.

@phambinhfin
Author

@nurmukhametov can you take a look?

@nurmukhametov
Member

> @nurmukhametov can you take a look?

These tests pass after openxla#42248, so I believe it is better to keep the code here as-is.

The cost model in question is indeed NVIDIA-specific, so we should test and adjust it as needed before it is enabled by default.

Independently of the unit-test fix, I wonder whether the reasoning is correct, given that Triton's AMD backend can also lower tl.dot via FMA (as a fallback when no MFMA/WMMA intrinsic matches). That path has no 16×16 instruction tile, and warps tile M×N spatially. So the assumptions here seem to apply to MFMA/WMMA-eligible dots rather than to every ROCm dot.

6 participants