Skip to content

[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897

Open
Ruturaj4 wants to merge 4 commits into
mainfrom
ruvaidya/fnuz-mixed-fp8-triton-dot
Open

[ROCm] Allow mixed (F8E4M3FNUZ, F8E5M2FNUZ) in Triton dot gate#897
Ruturaj4 wants to merge 4 commits into
mainfrom
ruvaidya/fnuz-mixed-fp8-triton-dot

Conversation

@Ruturaj4
Copy link
Copy Markdown

The mixed-FP8 in IsTritonSupportedDot only listed the OCP pair (F8E5M2, F8E4M3FN). The ROCm-native FNUZ pair was rejected even though the rest of the file already accepts FNUZ FP8 inputs on ROCm.

This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a __triton_nested_gemm_fusion. The gate then refuses it at codegen time with "INTERNAL: ... Dot operation only supports same types for lhs and rhs."

Mirror the existing OCP allowance under gpu_version.IsRocm() so the FNUZ pair passes the same check.

Submission Checklist

The mixed-FP8 carve-out in IsTritonSupportedDot only listed the OCP
pair (F8E5M2, F8E4M3FN). The ROCm-native FNUZ pair was rejected even
though the rest of the file already accepts FNUZ FP8 inputs on ROCm.

This blocks TransformerEngine FP8 GEMM on MI300 (gfx94X), which lowers
dgrad to dot_general(F8E4M3FNUZ, F8E5M2FNUZ) and gets routed to a
__triton_nested_gemm_fusion. The gate then refuses it at codegen time
with "INTERNAL: ... Dot operation only supports same types for lhs and
rhs."

Mirror the existing OCP allowance under gpu_version.IsRocm() so the
FNUZ pair passes the same check.
if (lhs_type != rhs_type && !types_are(F8E5M2, F8E4M3FN)) {
const bool mixed_fp8_ok =
types_are(F8E5M2, F8E4M3FN) ||
(gpu_version.IsRocm() && types_are(F8E5M2FNUZ, F8E4M3FNUZ));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may unconditionally support it here and let it fail at AreDotAlgorithmInputAndOutputConversionsSupported?

Ruturaj4 added 3 commits May 30, 2026 00:55
Extends AllDevicesToTest() to include gfx942 (MI300) and gfx950
(MI355x) so the FNUZ pair gets exercised against hardware that
actually supports FNUZ FP8. Adds the FNUZ pair to MixedF8DotTest's
parameterization.

Also fixes the existing F8-requires-Hopper skip to use cc.IsCuda()
instead of the pointer-nonnull anti-pattern. A default-constructed
GpuComputeCapability holds a default CudaComputeCapability whose
IsAtLeastHopper() is false, which the old check would silently treat
as "CUDA pre-Hopper" and skip the test even on what callers think is
ROCm.

Addresses TODO(b/393299275).
IsTritonSupportedDot's only caller (IsTritonSupportedInstructionImpl)
already rejects FNUZ inputs on CUDA via IsTritonSupportedDataType, so
the additional gpu_version.IsRocm() check inside mixed_fp8_ok was dead
code. Drop it for consistency with the OCP pair (F8E5M2, F8E4M3FN)
above, which similarly relies on the upstream data-type filter.
CI clang-format check rejected the manual line-wrap on this two-call OR.
Apply the canonical formatting CI provided.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants