[XLA:GPU] Fix num_warps in ROCm default Triton config for 16x16 tile #853
phambinhfin wants to merge 1 commit into
Are we sure GetMaxWarpsPerCta is right for us anyway? I think we have MFMA instructions that go as small as 4x4. I'm not sure about WMMA, and I'm not sure what Triton for AMDGPU prefers to use.
@draganmladjenovic makes sense. I just realized that we could also modify dot_search_space.cc to exploit the AMD capacity spaces; let me check.
92e9c6c to 709897d
…-aware on ROCm
GetMaxWarpsPerCta returns `min(threads_per_block_limit/threads_per_warp,
ceil(M/16) * ceil(N/8))`, where the `{16, 8}` sub-tile is the smallest
NVIDIA mma instruction (and the comment in the function makes that
intent explicit, also covering wgmma). This is correct on NVIDIA, but
two things break on AMD:
1. The sub-tile shape doesn't match what Triton emits on AMD:
- CDNA (MI200/MI300/MI350): Triton's amd_mfma encoding only ever
uses instrShape = [16,16,*] or [32,32,*]; see
chooseMfmaInstruction in triton-lang/triton third_party/amd
(PRs openxla#5937 reworking the MFMA intrinsic map and openxla#8213 unifying
mDim/nDim into instrShape). The smaller V_MFMA_*_4x4x*
instructions exist in hardware (AMD "matrix cores" lab note:
https://gpuopen.com/learn/amd-lab-notes/amd-lab-notes-matrix-cores-readme/)
but Triton does not select them.
- RDNA3+: WMMA only supports 16x16x16 at all (AMD GPUOpen
"WMMA on RDNA3":
https://gpuopen.com/learn/wmma_on_rdna3, and the RDNA3 ISA
reference). So the smallest per-wave sub-tile is 16x16, not
16x8.
2. More importantly, the underlying assumption "one warp per output
sub-tile" doesn't hold on AMD. Extra warps per CTA are used by
Triton's stream pipeliner to software-pipeline the K loop (see
triton-lang/triton PR openxla#4148 "Introduce stream pipeliner v2", and
ROCm "Optimizing Triton kernels" docs), not to tile the M*N
output spatially. A spatial cap is therefore conceptually the
wrong bound on AMD.
Concretely this was silently filtering valid ROCm hints. For a 16x16
output tile the formula yields `min(16, ceil(16/16) * ceil(16/8)) = 2`,
so the default ROCm hint `{block_m=16, block_n=16, num_warps=4}` was
clamped, failed the exact-match in OptimizeConfigSet, and disappeared
from the default config set -- which is what made the following three
tests fail on ROCm:
TritonBackendTest.CostModelOptions_Combination
TritonBackendTest.CostModelOptions_Filter
TritonBackendTest.CostModelOptions_TopFromDefault
Fix: branch on `gpu_compute_capability().IsRocm()` (same pattern this
file already uses at line 145 for AddWavesPerEuParameter) and on ROCm
drop the spatial term, keeping only the hardware bound
`threads_per_block_limit / threads_per_warp`. The NVIDIA branch is
byte-identical to before. No changes needed to rocm.txtpb -- the
existing hints become valid under the corrected cap.
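
The cap described above can be sketched in a few lines. This is a minimal illustration, not the code in dot_search_space.cc: `DeviceLimits`, `HardwareCap`, `SpatialCap`, and the free function `MaxWarpsPerCta` are hypothetical names, the 1024/64 limits are the AMD values quoted in this PR, the 1024/32 limits are assumed for a CUDA device, and the lower clamp to `min_warps_per_cta_` is omitted.

```cpp
#include <algorithm>

// Hypothetical stand-in for the device limits the real code reads from the
// device description; field names are illustrative only.
struct DeviceLimits {
  int threads_per_block_limit;
  int threads_per_warp;
  bool is_rocm;
};

constexpr int CeilDiv(int a, int b) { return (a + b - 1) / b; }

// Hardware bound: how many warps fit into one CTA at all.
constexpr int HardwareCap(const DeviceLimits& d) {
  return d.threads_per_block_limit / d.threads_per_warp;
}

// Spatial bound: one warp per {16, 8} output sub-tile.
constexpr int SpatialCap(int block_m, int block_n) {
  return CeilDiv(block_m, 16) * CeilDiv(block_n, 8);
}

// Behaviour proposed in this commit: ROCm keeps only the hardware bound,
// every other platform keeps min(hardware bound, spatial bound).
constexpr int MaxWarpsPerCta(const DeviceLimits& d, int block_m, int block_n) {
  if (d.is_rocm) return HardwareCap(d);
  return std::min(HardwareCap(d), SpatialCap(block_m, block_n));
}

constexpr DeviceLimits kRocmLike{1024, 64, true};   // values quoted in this PR
constexpr DeviceLimits kCudaLike{1024, 32, false};  // assumed CUDA limits

// Compile-time checks for a 16x16 output tile.
static_assert(HardwareCap(kRocmLike) == 16, "1024 / 64");
static_assert(SpatialCap(16, 16) == 2, "ceil(16/16) * ceil(16/8)");
static_assert(MaxWarpsPerCta(kRocmLike, 16, 16) == 16, "ROCm: hardware bound");
static_assert(MaxWarpsPerCta(kCudaLike, 16, 16) == 2, "CUDA: min(32, 2)");
```

With the spatial term applied on ROCm, the same tile is capped at `min(16, 2) = 2`, which is why a hint with num_warps=4 could not match.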
Verified on MI350 (gfx950), inside the rocm_ci docker image
`rocm/tensorflow-build:latest-jammy-pythonall-rocm7.2.1-ci_official`:
* Full TritonBackendTest suite: 19 passed, 7 skipped (Hopper/Ampere/
TMA/WarpSpecialization, correctly skipped on ROCm). The three
CostModelOptions_* tests above all pass.
* Full dot_search_space_test: 25/25 pass, including
ConsidersFewWarpsPerCtaAndMmaForSmallProblem and
EnsuresWgmmaShapeForLargeProblem (NVIDIA-side coverage of
GetMaxWarpsPerCta).
709897d to c628673
// On AMD, Triton lowers tl.dot to MFMA on CDNA or WMMA on RDNA3+. In
// Triton's amd_mfma encoding only 16x16x* and 32x32x* instrShapes are
// emitted (see triton-lang/triton third_party/amd's chooseMfmaInstruction
// and PRs #5937 / #8213); the smaller V_MFMA_*_4x4x* exists in hardware
It seems your commit message is incorrect:
chooseMfmaInstruction in triton-lang/triton third_party/amd
(PRs openxla#5937 reworking the MFMA intrinsic map and openxla#8213
It should be triton-lang/triton#5937 and triton-lang/triton#8213.
// "Optimizing Triton kernels"), not to tile the M*N output spatially.
// So we don't impose a per-output-tile cap here -- only the hardware
// bound (threads_per_block_limit / threads_per_warp) applies.
return std::max(min_warps_per_cta_, max_warps);
Wouldn't this make exhaustive tuning unusable? Are we sure there isn't some other upper cap, such as num_stages * warp size?
I just noticed this whole file is NVIDIA-specific. Someone should look through it.
@nurmukhametov can you take a look?
These tests pass after openxla#42248, so I believe it is better to keep the code here as-is. The mentioned CostModel is indeed NVIDIA-specific, so we should test and adjust it as needed before it is enabled by default. Independently of the UT fix, I wonder whether the reasoning is correct, given that Triton's AMD backend can also lower tl.dot via FMA (as a fallback when no MFMA/WMMA intrinsic matches). That path has no 16×16 instruction tile, and warps tile M×N spatially. So the assumptions here seem to apply to MFMA/WMMA-eligible dots rather than to every ROCm dot.
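Fixes the following three tests on ROCm (verified on MI350X / gfx950, applicable to MI200/MI300 as well):
TritonBackendTest.CostModelOptions_Combination
TritonBackendTest.CostModelOptions_Filter
TritonBackendTest.CostModelOptions_TopFromDefault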
Background
TritonDotFusionSearchSpace::OptimizeConfigSet filters the exhaustive search space against the hints from rocm.txtpb using exact TritonGemmConfig equality (only block_m / block_n / block_k are clamped; num_warps / num_stages / waves_per_eu / ... are not). Therefore each hint must agree with what the search space actually generates for the declared tile shape, or it is silently dropped.
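
The exact-equality filtering described above can be sketched as follows. This is an illustration under assumptions, not the OptimizeConfigSet implementation: `Config` is a hypothetical, trimmed-down stand-in for TritonGemmConfig, `KeepMatchingHints` is an invented helper name, and the clamping of block_m / block_n / block_k is left out.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical, trimmed-down stand-in for TritonGemmConfig.
struct Config {
  int block_m;
  int block_n;
  int block_k;
  int num_warps;
  int num_stages;

  // Every field must agree for two configs to compare equal.
  bool operator==(const Config& other) const {
    return block_m == other.block_m && block_n == other.block_n &&
           block_k == other.block_k && num_warps == other.num_warps &&
           num_stages == other.num_stages;
  }
};

// Keeps only the hints that appear verbatim in the generated search space.
// A hint that differs in a single field (for example num_warps) is dropped
// without any diagnostic.
std::vector<Config> KeepMatchingHints(const std::vector<Config>& space,
                                      const std::vector<Config>& hints) {
  std::vector<Config> kept;
  for (const Config& hint : hints) {
    if (std::find(space.begin(), space.end(), hint) != space.end()) {
      kept.push_back(hint);
    }
  }
  return kept;
}
```

If the space contains only `{16, 16, 32, 2, 1}` for a 16x16 tile (the block_k and num_stages values are made up for the example), a hint `{16, 16, 32, 4, 1}` is filtered out while `{16, 16, 32, 2, 1}` survives.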
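According to xla/backends/gpu/autotuner/triton/dot_search_space.cc, GetMaxWarpsPerCta returns `min(threads_per_block_limit / threads_per_warp, ceil(M/16) * ceil(N/8))`.
For a 16x16 output tile this evaluates (on every AMD GPU, since threads_per_block_limit=1024 and threads_per_warp=64 are identical across gfx90a / gfx942 / gfx950) to `min(1024/64, ceil(16/16) * ceil(16/8)) = min(16, 2) = 2`.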
AddCtaSizeParameter therefore only emits num_warps=2 for a 16x16 tile. The previous hint asked for num_warps=4, so it had no matching entry in the exhaustive set and was filtered out by OptimizeConfigSet, leaving the optimized default set with a single config. The three CostModelOptions_* tests assume the default set yields >=2 estimable configs and so they all fail:
5 mixin).
Fix
Align the 16x16 hint with the search space by setting num_warps=2. Both hints now survive OptimizeConfigSet, the default set has two distinct configs again, and the three tests pass.
This is a regression introduced by commit bda66ea ("[XLA:GPU] Remove split_k > 1 from default configs") which reduced rocm.txtpb from 6 hints to 2 and so unmasked the latent mismatch in the 16x16 hint.
Verification
Executed inside the same docker image the ROCm/xla CI uses:
rocm/tensorflow-build:latest-jammy-pythonall-rocm7.2.1-ci_official
on AMD Instinct MI350X (gfx950) with ROCm 7.2.1:
bazel test --config=rocm --repo_env=TF_ROCM_AMDGPU_TARGETS=gfx950 \
  --test_filter='TritonBackendTest.*' \
  //xla/backends/gpu/autotuner:triton_test
-> 19 PASSED, 7 SKIPPED (CUDA-only), 0 FAILED.
The 7 SKIPPED tests all hit GTEST_SKIP() << "Not supported on ROCm."; they cover Ampere / Hopper / Blackwell / TMA / warp-specialization features. triton_configs_test also passes.