[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897
Open
Ruturaj4 wants to merge 4 commits into
Open
[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897Ruturaj4 wants to merge 4 commits into
Ruturaj4 wants to merge 4 commits into
Conversation
The mixed-FP8 carve-out in IsTritonSupportedDot only listed the OCP pair (F8E5M2, F8E4M3FN). The ROCm-native FNUZ pair was rejected even though the rest of the file already accepts FNUZ FP8 inputs on ROCm. This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a __triton_nested_gemm_fusion. The gate then refuses it at codegen time with "INTERNAL: ... Dot operation only supports same types for lhs and rhs." Mirror the existing OCP allowance under gpu_version.IsRocm() so the FNUZ pair passes the same check.
| if (lhs_type != rhs_type && !types_are(F8E5M2, F8E4M3FN)) { | ||
| const bool mixed_fp8_ok = | ||
| types_are(F8E5M2, F8E4M3FN) || | ||
| (gpu_version.IsRocm() && types_are(F8E5M2FNUZ, F8E4M3FNUZ)); |
There was a problem hiding this comment.
You may unconditionally support it here and let it fail at AreDotAlgorithmInputAndOutputConversionsSupported?
Extends AllDevicesToTest() to include gfx942 (MI300) and gfx950 (MI355x) so the FNUZ pair gets exercised against hardware that actually supports FNUZ FP8. Adds the FNUZ pair to MixedF8DotTest's parameterization. Also fixes the existing F8-requires-Hopper skip to use cc.IsCuda() instead of the pointer-nonnull anti-pattern. A default-constructed GpuComputeCapability holds a default CudaComputeCapability whose IsAtLeastHopper() is false, which the old check would silently treat as "CUDA pre-Hopper" and skip the test even on what callers think is ROCm. Addresses TODO(b/393299275).
IsTritonSupportedDot's only caller (IsTritonSupportedInstructionImpl) already rejects FNUZ inputs on CUDA via IsTritonSupportedDataType, so the additional gpu_version.IsRocm() check inside mixed_fp8_ok was dead code. Drop it for consistency with the OCP pair (F8E5M2, F8E4M3FN) above, which similarly relies on the upstream data-type filter.
CI clang-format check rejected the manual line-wrap on this two-call OR. Apply the canonical formatting CI provided.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The mixed-FP8 in
IsTritonSupportedDotonly listed the OCP pair (F8E5M2,F8E4M3FN). The ROCm-native FNUZ pair was rejected even though the rest of the file already accepts FNUZ FP8 inputs on ROCm.This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a
__triton_nested_gemm_fusion. The gate then refuses it at codegen time with "INTERNAL: ... Dot operation only supports same types for lhs and rhs."Mirror the existing OCP allowance under gpu_version.IsRocm() so the FNUZ pair passes the same check.
Submission Checklist