Skip to content

[Triton] batched_gemm_a8w8_smallB_blockscale: split-K FP8 GEMM for small-B regime (fixes ROCm/aiter#3000)#2

Open
buddywhitman wants to merge 21 commits into
mainfrom
feat/triton-batched-gemm-smallB-blockscale
Open

[Triton] batched_gemm_a8w8_smallB_blockscale: split-K FP8 GEMM for small-B regime (fixes ROCm/aiter#3000)#2
buddywhitman wants to merge 21 commits into
mainfrom
feat/triton-batched-gemm-smallB-blockscale

Conversation

@buddywhitman
Copy link
Copy Markdown
Owner

@buddywhitman buddywhitman commented Jun 5, 2026

Summary

Implements batched_gemm_a8w8_smallB_blockscale — a Triton FP8 batched GEMM
kernel optimized for small-B (B=2, DeepSeek V4 grouped output LoRA wo_a).

This is a contribution targeting ROCm#3000. Unblocks ROCm/ATOM#676.

Problem

Existing aiter FP8 batched GEMM kernels lose to BF16 einsum by 1.5×–7× at B=2
due to CU starvation: grid (B=2, M_tiles=1, N_tiles=8) = 16 work-groups on a
304-CU GPU → 5% CU utilization.

Solution

1. Grid collapse + split-K=8

Flatten B into M: (B×M_tiles, N_tiles, split_k) grid.
At B=2, M=1, split_k=8: 128 work-groups → 42% CU utilization (was 5%).

2. Fused bf16 output for split_k=1

When split_k=1, the main kernel writes bf16 directly — no partial-sum buffer,
no separate reduction kernel. Eliminates ~160μs overhead at M=256.

3. Per-128-block W-scales inline

Weight scales loaded per 128-element block within the kernel.
Single-row vector load (no 128× register expansion from broadcast).

New files

  • aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json (MI300X tuned configs)
  • aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json (MI355X tuned configs)
  • op_tests/test_batched_gemm_smallB.py

Expected results on AMD MI300X (gfx942)

Correctness: tested on NVIDIA RTX 3070Ti (sm_86) using float8_e5m2.
AMD production uses float8_e4m3fnuz (set AITER_AMD_FP8=1).
NVIDIA sm_86 has no native FP8 hardware — AMD MI300X/MI355X numbers
expected to show FP8 ≥ BF16 at all M values.

Work-group count improvement at B=2, N=1024, BLOCK_N=128:

M Legacy WGs This PR WGs AMD CU util
1 16 128 5% → 42%
4 16 128 5% → 42%
16 16 64 5% → 21%
64 16 512 5% → 168% (multi-wave)

AI Assistance Disclosure

Implementation assisted by Claude Sonnet 4.6. All changes reviewed and
understood by the submitter. Tests run and passing on NVIDIA RTX 3070Ti.

Summary by CodeRabbit

  • New Features

    • FP8 batched GEMM for small batch sizes and integrated per-token/group quantization
    • gfx1250 WMMA/preshuffle GEMM support, weight preshuffle helper, and grouped top-k routing for MoE
    • Optional operand upcast for fused KV-cache paths and improved MoE reduce/masking support
  • Backward Compatibility

    • Existing GEMM entry points remain resolvable via backward-compat mappings
  • Tests

    • New correctness suites and benchmarks covering GEMM, WMMA gfx1250, routing, and MoE
  • Chores

    • CI/workflow timeout and download robustness improvements

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Jun 5, 2026

Review Change Stack

📝 Walkthrough

Walkthrough

Adds Triton FP8 small-B batched GEMM kernels and Python wrappers; extensive gfx1250 FlyDSL/WMMA kernels, tuners, and utilities; MoE routing and masked reduction plumbing; fused KV-cache upcast support; aiter_tensor_t operator ABI migration; CI/workflow/script updates; and many tests/benchmarks.

Changes

Core Features & Kernels

Layer / File(s) Summary
Triton batched GEMM kernels
aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py, aiter/ops/triton/gemm/batched/*
Adds per-token-group FP8 quant kernel, main batched GEMM kernel with split-K tiling, flat split-K reduction, Python orchestration (batched_gemm_a8w8_smallB_blockscale, per_token_group_quant, BF16 fused entry), package exports, and backward-compat mapping.
gfx1250 FlyDSL WMMA backend & tuner
aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py, aiter/ops/flydsl/gemm_tune/*, aiter/ops/flydsl/kernels/gemm_common_gfx1250.py, aiter/ops/triton/_gluon_kernels/gfx1250/*
Adds gfx1250 preshuffled WMMA GEMM backend, kernel-name utils, WMMA candidate catalogue and tuner, shared gfx1250 GEMM helpers, MXFP4 preshuffle kernel, and gluon/triton dispatch adjustments.
MoE routing, grouped-topk, and reduction
aiter/ops/triton/moe/moe_routing/*, aiter/ops/flydsl/kernels/moe_gemm_2stage.py, aiter/ops/flydsl/moe_kernels.py
Rewrites routing API to support score_mode and grouped_topk, adds _grouped_topk kernel and grouped_topk wrapper, and changes compile/launcher for MoE stage2 reduction to accept expert_mask and topk_ids with masked reduction support.
Fused KV-cache upcast support
aiter/ops/triton/fusions/fused_kv_cache.py, _triton_kernels/.../fused_kv_cache.py
Adds upcast_operand runtime flag and UPCAST_OPERAND compile-time flag to optionally upcast cos/sin to float32 in fused KV-cache kernels; adjusts kernel pid/work distribution and gating.
Operator ABI & codegen migration
csrc/include/*, csrc/kernels/*, csrc/kernels/generate_binaryop.py, csrc/pybind/aiter_operator_pybind.cu, aiter/ops/aiter_operator.py
Migrates binary operator API and codegen from torch::Tensor to aiter_tensor_t with bool success returns, updates generated dispatch/signatures, implements compiled-kernel-backed Python binary ops with fallbacks, and sets pybind stream state.

CI, Scripts, Configs, Tests & Benchmarks

Layer / File(s) Summary
CI/workflow & scripts
.github/workflows/*, .github/scripts/*
Adds multiple explicit artifact download timeouts, improves Triton wheel prepare/download with ROCm fallback, S3 manifest generation for main, and test-splitting memory-weight floors.
Hardware tuning configs
aiter/ops/triton/configs/gemm/*
Adds gfx942/gfx950 batched GEMM configs and simplifies gfx1250 AFP4/WFP4 preshuffle config map.
Tests & benchmarks
op_tests/*, op_tests/triton_tests/*, op_tests/op_benchmarks/*
Adds correctness tests for batched GEMM smallB and gfx1250 WMMA; extends MoE, MLA, GEMM AFP4/WFP4 tests; updates benchmarks to propagate FP8 dtype and CLI flags.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant batched_gemm_a8w8_smallB_blockscale_bf16
    participant per_token_group_quant
    participant per_token_group_quant_kernel
    participant batched_gemm_a8w8_smallB_blockscale
    participant _batched_gemm_a8w8_smallB_blockscale_kernel
    participant _split_k_reduce_flat_kernel

    Caller->>batched_gemm_a8w8_smallB_blockscale_bf16: X (BF16), B_weight, B_scale
    batched_gemm_a8w8_smallB_blockscale_bf16->>per_token_group_quant: X, group_size, transpose_bm
    per_token_group_quant->>per_token_group_quant_kernel: Launch quantization kernel
    per_token_group_quant_kernel-->>per_token_group_quant: X_q (FP8), scale (FP32)
    per_token_group_quant-->>batched_gemm_a8w8_smallB_blockscale_bf16: quantized X_q, A_scale
    batched_gemm_a8w8_smallB_blockscale_bf16->>batched_gemm_a8w8_smallB_blockscale: X_q, B_weight, A_scale, B_scale, split_k
    batched_gemm_a8w8_smallB_blockscale->>_batched_gemm_a8w8_smallB_blockscale_kernel: Launch GEMM (split_k tiles)
    _batched_gemm_a8w8_smallB_blockscale_kernel-->>batched_gemm_a8w8_smallB_blockscale: float32 partials or BF16 output
    alt split_k > 1
        batched_gemm_a8w8_smallB_blockscale->>_split_k_reduce_flat_kernel: Launch reduction
        _split_k_reduce_flat_kernel-->>batched_gemm_a8w8_smallB_blockscale: BF16 reduced result
    end
    batched_gemm_a8w8_smallB_blockscale-->>Caller: BF16 output
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Poem

🐰 In kernels deep where scalings dance,

FP8 hops in, given a chance,
Triton tiles and WMMA hum,
MoE routes and tests now come,
Aiter springs forward — code and prance.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/triton-batched-gemm-smallB-blockscale

xytpai and others added 11 commits June 5, 2026 10:31
…e co-issue (ROCm#3469)

* add 8w policy

* add selection filter

* update filter

* upd dsv3

* upd gptoss

* upd kimi&kimik2 k2 still in tune

* Update glm5_bf16_tuned_gemm.csv

* Update gemm_kernels.py: add tm80

* Update kimik2_bf16_tuned_gemm.csv

* change to fx.Pointer

* Update dsv4_bf16_tuned_gemm.csv

* Update kimi_bf16_tuned_gemm.csv

* Update kimik2_bf16_tuned_gemm.csv

* Update dsv4_bf16_tuned_gemm.csv
* Only show the error when someone calls gemm_a16w16_opus or opus_gemm_a16w16_tune on a non-gfx950 device
)

* flydsl moe: add EP reduce path with optional reduce mask

moe_sorting gains need_reduce/has_reduce_mask to size moe_buf for the
stage2 reduce epilogue (per-topk slab vs atomic accumulate vs caller-
owned). Stage2 wrapper and moe_gemm_2stage updated to drive the
compile_moe_reduction path; test_moe_ep covers the new EP variants.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

* improve

* hack topk to retrieve correct tuned result

* format

* flydsl moe_reduction: fix i32 voffset overflow for large intermediates

The compile_moe_reduction kernel addressed X [tokens, topk, model_dim] via
fx.rocdl.make_buffer_tensor + i32 per-thread voffsets. For large batches the
per-token byte offset exceeds 2^32 (e.g. 131072*6*4096*2 = 6 GiB), wrapping
the i32 voffset and corrupting the reduction.

Fold the per-WG token offset into each buffer descriptor's 48-bit base
pointer via create_buffer_resource(base_byte_offset=fx.Int64(...)). In-kernel
voffsets stay i32-safe because they only span one token's slab
(topk*model_dim*elem_bytes).

The vector path now uses direct buffer_load/buffer_store against the
base-offset-shifted resources (the make_buffer_tensor helper has no base
hook). Behavior at typical sizes is unchanged; logits_diff at token=131072
drops from 0.337 to 0.0054, matching torch.sum.

* flydsl moe_gemm_2stage: port to fx.Pointer API after main merge

Adopt main's commit 1be4ee9 (ROCm#3373) plumbing in compile_moe_gemm1
and compile_moe_gemm2: switch kernel/launcher signatures to fx.Pointer,
add a local _ptr_buffer_resource helper, precompute w_nbytes/sw_nbytes,
and replace extract_base_index with ptrtoint + index_cast.

For compile_moe_reduction, preserve the EP-reduce semantics (X, Y,
expert_mask, topk_ids inputs and the fused valid = expert_mask[topk_ids
[t,k]] != 0 gather) and the i32 voffset overflow fix from 7e1935f,
while migrating its arguments to fx.Pointer. The per-token slab offset
is folded into the descriptor base address via a helper that does the
add in i64 before create_buffer_resource_from_addr. Add a new
compile-time num_experts kwarg so the expert_mask descriptor can be
sized exactly. Update _MoeGemm2ReduceWrapper and the moe_kernels reduce
call site to pass pointers via flyc.from_c_void_p / _ptr_view_safe.

* fix

* fused_moe: replace need_reduce/has_reduce_mask with accumulate flag

Collapse the two boolean knobs on moe_sorting into a single `accumulate`
flag that mirrors how the stage2 path actually uses moe_buf, and pick it
automatically from the compiled FlyDSL stage2 kernel name.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

* fused_moe: allocate moe_out in stage2 reduce mode; unhardcode a8w4 mxfp4 activation

In FlyDSL stage2 reduce mode (no expert_mask), moe_sorting returns a
zero-shape placeholder for moe_buf since the reduce path owns the final
write. fused_moe_2stages now allocates a real [token_num, model_dim]
output tensor before invoking stage2 when it detects the placeholder.

Also splits the legacy case generator in test_moe_2stage.py so the
a8w4 mxfp4 (per_1x32, fp8, fp4x2) triple iterates args.act instead of
hardcoding Swiglu, allowing -a silu to actually take effect for -q 7.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

* fix

* fused_moe: fix black formatting

---------

Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
Co-authored-by: Sergey Solo <ssolovye@amd.com>
* CI: add timeouts for artifact downloads

Fail stuck artifact downloads promptly so CI jobs do not occupy runners indefinitely.

* CI: avoid packing memory-heavy tests together

Weight high-memory Aiter tests during shard assignment so they do not regress into the same shard after FILE_TIMES updates.
…OCm#3542)

* [mla_decode_gluon] return_lse: full decode + merged fp32 lse (remove DCP)

return_lse now runs the full decode (stage-1 + stage-2 reduce) and returns
the merged fp32 lse [batch, nhead] (e_max + log(e_sum) over the whole
sequence), matching the asm/torch references. Removes the stage-1-only DCP
path. Supported for all regimes.

- per-split lse kept in a separate fp32 mid_lse buffer (not packed into the
  bf16 O), fixing the bf16 precision loss in the stage-2 merge
- stage-2 (_mla_softmax_reducev_kernel) writes final_lse; NUM_KV_SPLITS==1
  fast path writes it directly in stage-1
- test_mla.py: validate returned lse vs lse_ref; `o` is now identical with and
  without -lse, so the gluon bh16bn64 gate is unified to just non-empty splits
  (ctx >= 256 // batch_size). README updated.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

* [mla_decode_gluon] simplify lse comments

Trim verbose comments along the call chain (drop repeated "separate fp32
tensor" / "matching logsumexp" prose, use "split" not "K").

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
* fix multithread_reduce_max_dpp

* fix
* CI: add fallback Triton wheel download

* CI: retry Aiter artifact downloads

* Revert "CI: retry Aiter artifact downloads"

This reverts commit 48d5e01.
…m#3546)

* update

* new grid layout for triton

* black format

* add upcast_operand option
…P2 (ROCm#3287)

Signed-off-by: Xavier Aguilar <Xavier.AguilarFruto@amd.com>
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (1)
aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json (1)

1-8: 💤 Low value

Note: Config is identical to gfx942.

The gfx950 config is byte-for-byte identical to gfx942. This is acceptable if the optimal tuning parameters are the same for both architectures, but worth noting in case gfx950-specific tuning is planned for the future.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In
`@aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json`
around lines 1 - 8, The file gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json is
byte-for-byte identical to the gfx942 config; to make this explicit and avoid
confusion later, add a clear metadata entry or comment inside the JSON
indicating it's based on or copied from gfx942 (for example a top-level
"based_on": "gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json" or "note":
"identical to gfx942") and keep all existing entries (e.g., the keys like
"B=2,K=4096,N=1024,M=1" and their BLOCK_M/BLOCK_N/BLOCK_K/split_k values)
unchanged so future reviewers know the duplication is intentional.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py`:
- Around line 196-201: The call to per_token_group_quant_kernel uses a hardcoded
_FP8_E4M3_MAX (448.0) even when _default_fp8_dtype() selects float8_e5m2,
causing incorrect scaling; update per_token_group_quant() to compute fp8_max
from the selected fp8_dtype (e.g., map torch.float8_e4m3fnuz -> _FP8_E4M3_MAX
and torch.float8_e5m2 -> _FP8_E5M2_MAX (57344.0)) and pass that computed fp8_max
into the per_token_group_quant_kernel invocation (replace the literal
_FP8_E4M3_MAX), ensuring the kernel uses the correct dynamic range for both
scale and clamping.

---

Nitpick comments:
In
`@aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json`:
- Around line 1-8: The file gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json is
byte-for-byte identical to the gfx942 config; to make this explicit and avoid
confusion later, add a clear metadata entry or comment inside the JSON
indicating it's based on or copied from gfx942 (for example a top-level
"based_on": "gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json" or "note":
"identical to gfx942") and keep all existing entries (e.g., the keys like
"B=2,K=4096,N=1024,M=1" and their BLOCK_M/BLOCK_N/BLOCK_K/split_k values)
unchanged so future reviewers know the duplication is intentional.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 6690b32f-001a-4dfc-9ed0-7e8cade445b5

📥 Commits

Reviewing files that changed from the base of the PR and between b5aa298 and 7b91a13.

📒 Files selected for processing (7)
  • aiter/ops/triton/__init__.py
  • aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/gemm/batched/__init__.py
  • aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • op_tests/test_batched_gemm_smallB.py

Comment on lines +196 to +201
per_token_group_quant_kernel[(B_dim * M_dim, n_groups)](
X, X_q, scale,
M_dim, B_dim, K, group_size, n_groups,
fp8_max=_FP8_E4M3_MAX,
TRANSPOSE_BM=transpose_bm,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify FP8 dtype max values
python3 << 'EOF'
import torch
print("Checking FP8 max values...")
print(f"float8_e4m3fnuz: {torch.finfo(torch.float8_e4m3fnuz).max}")
print(f"float8_e5m2: {torch.finfo(torch.float8_e5m2).max}")
EOF

Repository: buddywhitman/aiter

Length of output: 180


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate definitions/usages
rg -n "_FP8_E4M3_MAX|fp8_max=_FP8_E4M3_MAX|fp8_max=" aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py

# Find other FP8 max helpers/constants in repo
rg -n "_FP8_MAX|57344\.0|448\.0" aiter | head -n 50

# Also find any helper that maps fp8_dtype -> max
rg -n "default_fp8|_default_fp8_dtype|get_fp8_max|finfo\(torch\.float8" aiter | head -n 50

Repository: buddywhitman/aiter

Length of output: 3867


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect the target file around FP8 dtype selection and the fp8_max usage
sed -n '1,120p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n
sed -n '160,240p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n

echo "-----"

# Inspect the corresponding _triton_kernels implementation (may have correct fp8_max logic)
sed -n '1,120p' aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n
sed -n '160,240p' aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n

Repository: buddywhitman/aiter

Length of output: 15426


🏁 Script executed:

#!/bin/bash
set -euo pipefail
sed -n '1,120p' aiter/ops/triton/_triton_kernels/moe/quant_moe.py | cat -n

Repository: buddywhitman/aiter

Length of output: 4756


🏁 Script executed:

#!/bin/bash
set -euo pipefail

rg -n "_FP8_MAX|57344\.0|448\.0 if .*_AMD_FP8|_AMD_FP8.*57344" aiter | head -n 50

Repository: buddywhitman/aiter

Length of output: 276


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -n "_AMD_FP8|_FP8_MAX|57344\.0" aiter | head -n 100
rg -n "_FP8_MAX" . | head -n 100

Repository: buddywhitman/aiter

Length of output: 2002


🏁 Script executed:

#!/bin/bash
set -euo pipefail
sed -n '1,130p' op_tests/test_batched_gemm_smallB.py | cat -n

Repository: buddywhitman/aiter

Length of output: 5597


🏁 Script executed:

#!/bin/bash
set -euo pipefail
rg -n "per_token_group_quant\\b" aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py

Repository: buddywhitman/aiter

Length of output: 139


🏁 Script executed:

#!/bin/bash
set -euo pipefail
sed -n '140,320p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n

Repository: buddywhitman/aiter

Length of output: 5141


Fix fp8_max to match selected FP8 dtype (float8_e5m2 must use 57344.0, not 448.0)

aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py selects fp8_dtype via _default_fp8_dtype() (AMD: torch.float8_e4m3fnuz, else torch.float8_e5m2), but per_token_group_quant() always launches per_token_group_quant_kernel with fp8_max=_FP8_E4M3_MAX where _FP8_E4M3_MAX = 448.0. The Triton kernel uses fp8_max for both scale = abs_max / fp8_max and clamping, so float8_e5m2 gets a severely underutilized dynamic range. The reference test op_tests/test_batched_gemm_smallB.py uses _FP8_MAX = 448.0 if _AMD_FP8 else 57344.0 consistent with float8_e4m3fnuz vs float8_e5m2.

🐛 Proposed fix

Replace the hardcoded max with a dtype→max helper and pass the computed value into the kernel:

-_FP8_E4M3_MAX = 448.0
+def _get_fp8_max(dtype: torch.dtype) -> float:
+    """Return the maximum representable value for the given FP8 dtype."""
+    if dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+        return 448.0
+    elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
+        return 57344.0
+    else:
+        raise ValueError(f"Unsupported FP8 dtype: {dtype}")
+    if fp8_dtype is None:
+        fp8_dtype = _default_fp8_dtype()
+    fp8_max_val = _get_fp8_max(fp8_dtype)
+
     per_token_group_quant_kernel[(B_dim * M_dim, n_groups)](
         X, X_q, scale,
         M_dim, B_dim, K, group_size, n_groups,
-        fp8_max=_FP8_E4M3_MAX,
+        fp8_max=fp8_max_val,
         TRANSPOSE_BM=transpose_bm,
     )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py` around
lines 196 - 201, The call to per_token_group_quant_kernel uses a hardcoded
_FP8_E4M3_MAX (448.0) even when _default_fp8_dtype() selects float8_e5m2,
causing incorrect scaling; update per_token_group_quant() to compute fp8_max
from the selected fp8_dtype (e.g., map torch.float8_e4m3fnuz -> _FP8_E4M3_MAX
and torch.float8_e5m2 -> _FP8_E5M2_MAX (57344.0)) and pass that computed fp8_max
into the per_token_group_quant_kernel invocation (replace the literal
_FP8_E4M3_MAX), ensuring the kernel uses the correct dynamic range for both
scale and clamping.

Boss2002n and others added 10 commits June 5, 2026 11:43
)

* grouped topk kernel with sigmoid for deepseek r1 integration

* cleanup after rebase

* restructured grouped topk kernel

* moved grouped_topk to regular topk files

* fix testing

* refactor routing_a8w4 name to unify dsv4 and grouped topk + sigmoid

* ruff

* black formatting

* shift routing to unified routing function

* ruff

* comment fixes

* black

* verified triton kernel is functionally the same as the one needed in atom. small change to import in routing

* review adjustments

* ruff

* ruff

* small change for routing without fused shared experts in topk score_mode branch

* change to routing to not support fused shared experts
…OCm#3572)

* fix(ci): build multi-Python wheels and publish versioned S3 manifest

The aiter-test workflow builds a single wheel using rocm/pytorch:latest
(Python 3.12), but downstream consumers like ATOM run Python 3.10.
Installing a cp312 wheel into a cp310 environment fails.

Add a build matrix (py3.10 + py3.12) to build_aiter_wheels and split
the S3 upload into a dedicated upload_s3_manifest job that collects
both wheels and generates a versioned latest.json manifest with a
per-Python-tag `wheels` field. The legacy top-level wheel_name/wheel_url
fields are preserved for backward compatibility.

* chore(ci): bump py3.10 build image from rocm7.2.2 to rocm7.2.4

* fix(ci): skip py3.10 wheel build on non-main branches
* [module_aiter_operator] refactor

* update
ROCm#3501)

* [feat](pa_mqa_logits) add gfx1250 (RDNA/WMMA wave32) support to deepgemm fp8 paged MQA logits

* [pa_mqa_logits] strip redundant gfx1250 comments

---------

Co-authored-by: Felix Li <felix.li@amd.com>
* feat(gemm): add FlyDSL MXFP8 dense GEMM for gfx1250

Add gemm_a8w8_mxscale, the OCP MXFP8 (E4M3 activation/weight, 1x32 E8M0
scales) dense GEMM for gfx1250, as an a8w8-family variant in gemm_op_a8w8.py.
The op is gfx1250-only and auto-routes via the a8w8_mxscale tuned CSV, falling
back to a safe default kernel for untuned shapes.

- backend: aiter/ops/flydsl/mxscale_gemm.py (launch + kernel-name encode/parse)
  with WMMA host layout helpers in mxscale_layout.py; vendored gfx1250 kernel
  under ops/flydsl/kernels/ (gemm_fp8fp4_gfx1250, gemm_common_gfx1250,
  pipeline_utils). Weight and scales are preshuffled into WMMA layout inside the
  backend, so callers pass row-major / unshuffled tensors.
- routing/config: AITER_CONFIG_GEMM_MXSCALE + a8w8_mxscale_{tuned,untuned}_gemm.csv.
- tuner: FlyDSL-only driver + candidate catalogue under ops/flydsl/gemm_tune/.
- AOT: mxscale branch in aot/flydsl/gemm.py (arch taken from the kernel name).
- tests: op correctness (op_tests/test_gemm_mxscale.py) and CPU-only layout
  checks (op_tests/test_mxscale_layout.py).

MX A8W4 is not integrated this round (vendored kernel supports it, but no public
op / routing / tuner path).

* fix(gemm): correct FlyDSL MXFP8 gfx1250 split_k

- split_k cross-workgroup accumulation now uses a device (agent) scoped
  atomicrmw fadd, so partial sums are coherent across CUs (the buffer
  atomic with default scope lost updates -> non-deterministic results)
- accumulate split_k>1 partials in an f32 scratch buffer and cast once,
  so bf16/f16 output keeps single-shot precision instead of degrading
  with the split count
- disable workgroup clusters in the tuner grid; keep split_k as a tuned knob

* feat(gemm): hoist MXScale B-weight preshuffle out of the per-call path

flydsl_mxscale_gemm re-padded and 16x16-preshuffled the B weight (and its
E8M0 scale) on every call -- a multi-MB layout pass that dominated latency
for memory-bound shapes (e.g. M=1 GEMV).

- split the per-call pad+preshuffle into weight-side / activation-side
  helpers (preshuffle_mxscale_weight / preshuffle_mxscale_activation),
  bit-identical to the old combined pad_mxscale_inputs path
- add public shuffle_weight_mxscale: pre-shuffle the weight once at load
  time, marked is_shuffled (mirrors shuffle_weight / bpreshuffle convention);
  the backend detects the marker, recovers N from the stashed shuffle key,
  and reuses the tensors verbatim instead of re-shuffling
- mxscale tuner caches the shuffle per kernel_name so the weight layout pass
  is excluded from measured kernel latency
- drop now-unused pad_mxscale_inputs

* fix(gemm): skip MXScale test on non-gfx1250 and trim split_k options

- test_gemm_mxscale.py: exit 0 (skip) instead of failing the CI job when run
  as a standalone script on non-gfx1250 hardware; fall back to df.to_string
  when tabulate is unavailable.
- flydsl_gemm_mxscale_gfx1250_common.py: prune split_k tuning options.

* feat(gemm): integrate MXScale GEMM into bpreshuffle path

* debug(mxscale): all gemm_bpreshuffle fallback to flydsl

* refactor(gemm): MXScale MXFP8 via bpreshuffle tuned-CSV dispatch (no K-pad)

B uses the generic shuffle_weight 16x16 layout; K is never padded (require
%128, raise otherwise); dispatch is driven purely by tuned-CSV flydsl_mxscale_*
kernel names. Drops the K-pad / weight marker / fallback machinery; adds an
early 1-byte dtype guard on A/B/scales.

TODO(tuned): shapes without a tuned-CSV entry never route to MXScale (fall back
to CK). Need tuned configs for all target shapes; K padding + untuned fallback
to be planned later.

* debug(gemm): fall back untuned bpreshuffle to flydsl MXScale instead of CK

Temporary: route untuned (config=None) shapes through the default MXScale kernel
so MXFP8 can be exercised without a tuned CSV. Revert once tuned configs land.

* [gfx1250][FlyDSL] Add WMMA a8w8 bpreshuffle FP8 GEMM

Route gemm_a8w8_bpreshuffle on gfx1250 to the FlyDSL gfx1250 WMMA kernel
(ptpc scale mode) via tuned-CSV kernelName prefix flydsl_bpreshuffle_wmma_,
with a dedicated backend and a standalone WMMA auto-tuner. Removes the
earlier MXScale op / layout / tuner.

* [debug] all gfx1250 bpreshuffle untuned kernel fallback to flydsl

* fix(gemm): fp32 atomic accumulation for FlyDSL gfx1250 split-k

split-k sums per-chunk partials via global atomic add. With a bf16/f16
output buffer the kernel rounded every partial to an 8-bit mantissa
before adding, so the error compounded with the chunk count (sk2 warned,
sk4+ failed the tuner's atol/rtol=0.01 check).

Accumulate split-k partials into an fp32 scratch buffer and narrow once
at the end, matching the non-split (sk1) result. All paths (dispatch,
untuned fallback, auto-tune) converge on run_preshuffle_gemm_a8_gfx1250,
so the single backend change covers them. Tighten test_split_k to pin
the split result to sk1 (the old cosine>0.99 check was too loose).

* Revert "[debug] all gfx1250 bpreshuffle untuned kernel fallback to flydsl"

This reverts commit 034e293.

---------

Co-authored-by: Felix Li <felix.li@amd.com>
…for small-B regime

Fixes ROCm#3000. Unblocks ROCm/ATOM#676.

Existing FP8 batched GEMM kernels dispatch only 16 work-groups on a
304-CU GPU (B=2, M=1: 5% CU utilization), losing 1.5x-7x to BF16 einsum.

Three fixes:
1. Grid collapse + split-K=8: (B*M_tiles, N_tiles, split_k) grid
   -> 128 work-groups at B=2, M=1 -> 42% CU utilization
2. Fused bf16 output for split_k=1: eliminates reduction kernel overhead
3. Per-128-block W-scales loaded inline (no dequant/requant precision loss)

New files:
- aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
- aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
- aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
- aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
- op_tests/test_batched_gemm_smallB.py

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or ATOM#773

Two additions requested by ROCm/ATOM#773 closing comment:
'the proper fix lives in aiter — extend the per-group quant path to accept
a permuted (G,S,R) input layout so the wo_a einsum can be replaced by a
fused quant + grouped GEMM without requiring a non-contiguous bmm'

1. per_token_group_quant kernel (TRANSPOSE_BM=True):
   Accepts X in (M, B, K) = (tokens, n_local_groups, d_per_group) layout
   directly, matching the natural output of DeepSeek V4 attention.
   No .transpose(0,1) needed; avoids the non-contiguous view that violated
   hipBLAS strided-batched GEMM contract and caused the GSM8K regression.

2. batched_gemm_a8w8_smallB_blockscale_bf16 entry-point:
   Accepts raw BF16 activations, runs fused per-token-group quant (Triton
   kernel) then the split-K GEMM — one call, no separate act_quant launch.

All 18 tests pass (NVIDIA RTX 3070Ti, sm_86, CUDA 13.3, Triton 3.2.0).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@buddywhitman buddywhitman force-pushed the feat/triton-batched-gemm-smallB-blockscale branch from 7b91a13 to 208ff0e Compare June 7, 2026 12:39
Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 20

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
csrc/kernels/generate_binaryop.py (1)

18-34: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Keep the generated dtype maps in sync.

DATA_TYPE_MAP still advertises float64 and bool, and cmdGenFunc() now forwards raw torch.dtype names into this generator, but AITER_DTYPE_MAP no longer has matching entries. The first binary-op JIT on one of those dtypes will blow up with a KeyError during codegen before the Python fallback gets a chance to run.

Either add the missing AITER_DTYPE_* mappings or explicitly short-circuit unsupported dtypes before invoking codegen.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/kernels/generate_binaryop.py` around lines 18 - 34, The DATA_TYPE_MAP
and AITER_DTYPE_MAP are out of sync causing a KeyError during codegen; update
AITER_DTYPE_MAP to include entries for the dtypes referenced in DATA_TYPE_MAP
(at minimum "float64" and "bool") with the correct AITER_DTYPE_* constants, or
alter cmdGenFunc() to validate/short-circuit unsupported torch.dtype names
before calling the generator so codegen is never invoked for dtypes missing in
AITER_DTYPE_MAP; reference DATA_TYPE_MAP, AITER_DTYPE_MAP and cmdGenFunc() when
making the change.
op_tests/test_mla.py (1)

667-681: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Don’t skip valid bh16bn64 cases when batch_size > 256.

The wrapper uses NUM_KV_SPLITS = max(1, 256 // batch_size), so larger batches are still valid with NUM_KV_SPLITS == 1. This gate hard-caps at 256 and silently drops those configurations from the Gluon test/benchmark.

Suggested fix
-        and 1 <= batch_size <= 256
-        and ctx_lens >= (256 // batch_size)
+        and batch_size >= 1
+        and ctx_lens >= max(1, 256 // batch_size)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/test_mla.py` around lines 667 - 681, The gate incorrectly excludes
batch_size > 256; update the condition in the Gluon bh16bn64 decode check (the
if using get_gfx(), dtype, kvtype, nhead, decode_qlen, v_head_dim, qk_head_dim,
page_size, batch_size, ctx_lens) to mirror NUM_KV_SPLITS semantics by removing
the hard cap and using ctx_lens >= max(1, 256 // batch_size) with a simple
batch_size >= 1 check (i.e., replace "1 <= batch_size <= 256 and ctx_lens >=
(256 // batch_size)" with "batch_size >= 1 and ctx_lens >= max(1, 256 //
batch_size)").
op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py (1)

55-62: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

Remove unsupported use_aot=True argument.

Line 58 passes use_aot=True to gemm_afp4wfp4_preshuffle, but the function signature (defined at lines 416-425 in aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py) does not accept a use_aot parameter. This will raise a TypeError at runtime.

🐛 Proposed fix to remove the unsupported parameter
-        ms = triton.testing.do_bench(
-            lambda: gemm_afp4wfp4_preshuffle(
-                x, w, x_scale, w_scale, c_dtype, y, use_aot=True
-            ),
+        ms = triton.testing.do_bench(
+            lambda: gemm_afp4wfp4_preshuffle(
+                x, w, x_scale, w_scale, c_dtype, y
+            ),
             warmup=25,
             rep=100,
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py` around lines 55 - 62,
The bench call is passing an unsupported keyword use_aot=True to
gemm_afp4wfp4_preshuffle which will raise a TypeError; remove the use_aot=True
argument from the lambda passed into triton.testing.do_bench so the call matches
the gemm_afp4wfp4_preshuffle signature (keep all other args: x, w, x_scale,
w_scale, c_dtype, y) and ensure the lambda still returns the function invocation
for benchmarking.
aiter/ops/triton/moe/moe_op_gemm_a8w4.py (1)

364-420: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard out_mx_quant off on gfx1250 until the Gluon kernels actually emit scales.

The new gfx1250 launch path never passes a YMxScale/HAS_MX_OUT equivalent into either Gluon kernel, so y_scale is returned uninitialized. On top of that, n_out is derived after w is transposed, so the scale shape is computed from K instead of the output width N.

Suggested fix
     use_gluon = get_arch() == "gfx1250"
+    if use_gluon and out_mx_quant:
+        raise NotImplementedError(
+            "out_mx_quant is not implemented for gfx1250 Gluon kernels yet"
+        )
+
     assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"

Also applies to: 435-538

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/ops/triton/moe/moe_op_gemm_a8w4.py` around lines 364 - 420, Disable the
out_mx_quant path when using the Gluon kernels on gfx1250 (i.e. when use_gluon
is true and the device/gfx id indicates gfx1250) so we don't return an
uninitialized y_scale; specifically, gate the existing out_mx_quant logic (the
out_dtype assignment, the allocate_output call that depends on
reduction_n_matmul, and the y_scale allocation) behind a check that clears
out_mx_quant for Gluon+gfx1250, and fix the n_out calculation to derive the
post-swiglu output width from the actual output tensor shape (use y.shape[-1] or
padded_N divided by reduction_n_matmul) instead of w.shape[-1] so the scale
buffer size is computed from N not K; refer to symbols out_mx_quant, use_gluon,
w, y, y_scale, allocate_output, and n_out when making the change.
🧹 Nitpick comments (2)
csrc/include/hip_reduce.h (1)

291-292: Tighten multithread_reduce_max_dpp’s non-__GFX9__ contract around thread_num==64

The thread_num==64 implementation in csrc/include/hip_reduce.h is compiled out for non-__GFX9__, but current callers don’t instantiate <64> there: topk_gating_kernels.cu passes WARP_SIZE (32 when not __GFX9__), and fused_qk_rmsnorm_group_quant.cu enforces group_size <= WARP_SIZE * thread_data_size, so non-__GFX9__ dispatch can’t produce ReduceThreadSize==64. Still, tightening the static_assert to thread_num <= 32 for !__GFX9__ (or an explicit compile-time error for thread_num==64) would better align the template contract with actual support.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/include/hip_reduce.h` around lines 291 - 292, The static_assert in
multithread_reduce_max_dpp currently allows thread_num up to 64 unconditionally;
update the contract so non-__GFX9__ builds disallow 64: modify the check around
the static_assert in multithread_reduce_max_dpp (symbol:
multithread_reduce_max_dpp / the static_assert line) to be conditional on
__GFX9__ — for __GFX9__ keep the existing <=64 power-of-two check, otherwise
tighten to thread_num >=1 && thread_num <=32 && power-of-two (or explicitly
static_assert(false) when thread_num==64 on non-__GFX9__), so callers cannot
instantiate a 64-thread variant on targets that don't support it.
op_tests/triton_tests/fusions/test_fused_kv_cache.py (1)

219-240: ⚡ Quick win

Add a partial-prefix decode-output case here.

This test always uses num_decode_toks_for_zeros = T, so it never exercises the case where only the first few batches should write decode_q_pe_out / q_nope_zeros_out. The new pid transpose in the fused kernels only breaks in that partial-prefix setup, so a small case like T=8, num_decode_toks_for_zeros=3 would catch the current OOB regression immediately.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/triton_tests/fusions/test_fused_kv_cache.py` around lines 219 - 240,
The test currently sets num_decode_toks_for_zeros = T so it never exercises the
partial-prefix decode-output path; modify the test around the call to
fused_qk_rope_cat_and_cache_mla to add at least one scenario where
num_decode_toks_for_zeros < T (for example T=8 and num_decode_toks_for_zeros=3)
so the function fused_qk_rope_cat_and_cache_mla (and outputs decode_q_pe_out /
q_nope_zeros_out) is exercised for partial-prefix writes; ensure the same setup
(slot_mapping, positions, shuffled_kv_cache, upcast_operand) is reused and only
num_decode_toks_for_zeros is changed so the new case triggers the pid-transpose
OOB regression.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In @.github/scripts/download_triton_wheel.sh:
- Around line 13-19: The pipeline assignment to ROCM_VERSION can fail under set
-euo pipefail, making the else fallback unreachable; modify the command that
sets ROCM_VERSION (the line assigning ROCM_VERSION using dpkg -l rocm-core |
awk) to neutralize failures (e.g., append || true) so the command never returns
non-zero and the subsequent if [[ -n "${ROCM_VERSION}" ]]; then ... else ...
block can run and correctly set ROCM_MAJOR_MINOR and TRITON_INDEX_URL or fall
back to the default ROCM version message.

In `@aiter/fused_moe.py`:
- Around line 1807-1811: The current branch that sets
extra_stage2_args["expert_mask"] and extra_stage2_args["topk_ids"] for the
_flydsl_stage2_wrapper incorrectly forwards localized ids (local_topk_ids) into
the reduction kernel; change it so the reduction path receives the original
global topk ids while kernels that need split-k localized ids still get
local_topk_ids. Concretely, in the if block guarded by stage2_func is
_flydsl_stage2_wrapper and expert_mask is not None, add/assign a distinct key
(e.g., "global_topk_ids") with the original topk_ids and leave "topk_ids"
pointing to local_topk_ids only for places that use localized routing; update
any call-sites inside fused_moe_2stages / fused_moe_ that expect global vs local
ids to read the appropriate key so the new reduce kernel uses global_topk_ids
while other kernels continue to use topk_ids (local_topk_ids).
- Around line 1420-1426: The code currently enables the FlyDSL force override by
default because _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") ==
"1"; change this to be opt-in by default (e.g. use
os.environ.get("AITER_FLYDSL_FORCE", "0") == "1" or check presence with
os.environ.get("AITER_FLYDSL_FORCE") == "1") so _flydsl_force is False when the
variable is unset; update the surrounding use_mxfp4_flydsl condition (symbols:
_flydsl_force and use_mxfp4_flydsl) accordingly so FlyDSL is only forced when
the env var is explicitly set to "1".

In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py`:
- Around line 87-110: Add device and output-shape validation before launching
the kernel: verify XQ, WQ, x_scale, w_scale, and Out are on the same device and
that Out.shape == (M, N); also ensure Out is contiguous (or call Out =
Out.contiguous()) when padded_m == M to avoid out-of-bounds writes. Implement
these checks alongside the existing dimensional/element-size checks in the
gfx1250 path (the block validating XQ/WQ, M, N, K, tile_n/tile_k and Out.dtype)
and duplicate the same guard logic in the other similar validation blocks
present later in this file (the other gfx1250/validation sections).
- Around line 119-123: Compute num_k_tiles as before, then do not coerce nb up
to 2; set nb = min(int(num_buffers), num_k_tiles) and if nb is not in
_SUPPORTED_NUM_BUFFERS pick the largest supported buffer count <= nb, but if
that yields a value < 1 or yields a supported nb that is greater than
num_k_tiles (e.g. forcing 2 when num_k_tiles == 1) raise an explicit exception
(ValueError) rejecting the unsupported pipeline depth instead of silently
forcing nb=2; update the logic around num_k_tiles, nb and _SUPPORTED_NUM_BUFFERS
and ensure callers that rely on kernel_fits_shape(...) will see the error when a
pipeline depth cannot be satisfied.

In `@aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py`:
- Around line 116-126: The CSV always records splitK as 0 because the info tuple
third slot is being hardcoded; fix by passing the real split_k when constructing
the info tuple in get_wmma_tune_task() (use ki.split_k instead of 0) and ensure
result_to_df() (and the similar block around lines 148-179) reads that third
element into splitK and writes it into the DataFrame (falling back to 0 only if
the info tuple lacks the slot). Update all places that build the info tuple so
the third element carries the actual ki.split_k, and keep result_to_df's
unpacking (keys, kernelId, splitK, kernelName, libtype) to persist the real
splitK value.

In `@aiter/ops/flydsl/kernels/moe_gemm_2stage.py`:
- Around line 3644-3647: The module_name currently ignores num_experts and the
new public ABI change; update construction of module_name (the variable
module_name) to also encode num_experts and an ABI tag so cached binaries are
unique to the expert count and signature change — include num_experts and a
short ABI identifier (e.g. "abi_expertmask_topkids" or similar) alongside
existing parts built from use_mask, dtype_str, topk and model_dim so the name
reflects the new expert_mask descriptor and the (expert_mask, topk_ids) ABI.

In `@aiter/ops/shuffle.py`:
- Around line 29-37: The 3D branch always treats w as (E, K, N) then transposes
to (E, N, K), which breaks callers that pass (E, N, K); update the branch in
shuffle.py to first detect the layout (e.g., inspect which axis lengths are
divisible by 32 vs 16 or use an explicit flag if available), then apply
divisibility checks against the correct axes (K must be divisible by 32 and N by
16) and only transpose when the input is in (E, K, N) form; finally perform the
view/permutations using the correct dimension ordering so the output layout is
consistent for both input layouts (refer to the existing variables and
operations: w, w.ndim == 3 branch, the transpose(), view(...16, K // 32, 2, 16),
permute(...), and final transpose()).

In `@aiter/ops/triton/_gluon_kernels/gfx1250/fusions/fused_kv_cache.py`:
- Around line 394-399: The transpose of pid (pid_hq = pid // B; pid_b = pid % B)
broke the original batch-based gate; update the downstream condition that
currently uses pid (e.g., the `if pid < num_decode_toks_for_zeros * QH` / OUTPUT
block that writes to decode_q_pe_out and q_nope_zeros_out) to also require the
batch predicate pid_b < num_decode_toks_for_zeros so stores are only issued for
batches inside the decode-output prefix; locate the check that previously used
pid and add `and pid_b < num_decode_toks_for_zeros` (or equivalent) to gate the
tail path that writes to decode_q_pe_out / q_nope_zeros_out.

In `@aiter/ops/triton/_gluon_kernels/gfx1250/gemm/basic/gemm_mxfp4.py`:
- Around line 279-376: The pipeline assumes k_tiles >= NUM_BUFFERS and will
overcompute when that's false; update the GEMM tile loop logic (variables
k_tiles, NUM_BUFFERS, main_iters, the prologue/epilogue sections and final WMMA)
to guard that case: if k_tiles < NUM_BUFFERS, set a local pipeline_depth =
min(NUM_BUFFERS, k_tiles) (or short-circuit into a simple non-pipelined path)
and use pipeline_depth in place of NUM_BUFFERS for prologue/epilogue waits,
buffer indexing (load_idx/compute_idx), and main_iters so you only issue/compute
exactly k_tiles tiles; ensure async_load/async_wait counts and the final WMMA
use the adjusted pipeline_depth to avoid draining nonexistent tiles.

In `@aiter/ops/triton/_triton_kernels/fusions/fused_kv_cache.py`:
- Around line 310-315: The batch vs head flattening was changed (pid_hq = pid //
B; pid_b = pid % B) which breaks the decode-output bounds check; restore correct
batch-major mapping or ensure bounds use pid_b. Either revert to the original
mapping (pid_b = pid // QH; pid_hq = pid % QH) in fused_kv_cache.py so pid_b
represents the batch index, or keep the current mapping but change the
decode/OUTPUT block checks that reference pid/flattened pid to use pid_b when
testing against num_decode_toks_for_zeros so writes to decode_q_pe_out and
q_nope_zeros_out are clamped by batch index.

In `@aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py`:
- Around line 538-551: The bias-mode top-2 computation currently suppresses the
first max by value-equality (scores_for_choice == gm1_per_e), which removes all
tied-max experts; instead compute the per-group argmax index and suppress by
index. Replace the equality-based suppression around gm1_per_e with an index
mask derived from tl.argmax over group_max1 (or the appropriate per-group argmax
of scores_for_choice using gid_eq) so only the chosen max position is masked
(use that argmax mask to create sup_3d/expanded2), preserving other tied maxima
and then compute group_max2 as before.

In `@aiter/ops/triton/attention/pa_mqa_logits.py`:
- Around line 459-460: The code currently mutates the caller-visible WavePerEU
when forcing gfx1250/non-preshuffle to compile with WavePerEU=1; instead, do not
overwrite the API parameter WavePerEU—instead introduce a new local variable
(e.g., effective_wave_per_eu) set to WavePerEU by default and set
effective_wave_per_eu = 1 only when get_gfx() == "gfx1250" and not Preshuffle;
then use effective_wave_per_eu for compilation decisions and any artifact naming
logic (the places that previously read WavePerEU for AOT artifact names and the
compilation path), leaving the original WavePerEU value untouched for callers
and tests like op_tests/op_benchmarks/triton/bench_deepgemm_attention.py.

In `@aiter/ops/triton/gluon/README.md`:
- Around line 87-89: The README for the bh16bn64 regime is ambiguous: document
the NUM_KV_SPLITS==1 fast path where the wrapper skips stage-2 reduction and
writes o directly (instead of running the two-stage reduce), and clarify how
return_lse behaves in that case; update the bh16bn64 paragraph to state that
when NUM_KV_SPLITS = 1 (e.g., batch_size >= 256) the implementation bypasses
stage-2 and mla_decode_gluon(...) returns (o, final_lse) only when
return_lse=True and the merged fp32 LSE is computed locally, otherwise the
wrapper writes o directly without a separate stage-2 reduction.

In `@aiter/ops/triton/moe/moe_routing/routing.py`:
- Around line 308-310: The small-token sort path calls
triton.next_power_of_2(num_tokens) unguarded and will crash when num_tokens ==
0; change the logic in the routing code that sets HIST_BLOCK_M and sort_fn (the
block using HIST_BLOCK_M = triton.next_power_of_2(num_tokens) and sort_fn =
sort_tokens_fused) to use a safe value such as safe_num_tokens = max(num_tokens,
1) (consistent with the fused score-mode/hash paths) and call
triton.next_power_of_2(safe_num_tokens) so empty local batches do not trigger an
error.

In `@aiter/ops/triton/moe/moe_routing/topk.py`:
- Around line 43-79: Validate expert_group mapping in grouped_topk before
launching _grouped_topk: ensure all expert_group values are in [0,
num_expert_group), compute per-group counts with torch.bincount (length
num_expert_group) and assert no empty groups, and additionally assert the
worst‑case selection of topk_group groups can provide k experts by checking that
the sum of the smallest topk_group counts >= k; raise a clear AssertionError
referencing expert_group/num_expert_group/topk_group/k when any check fails.

In `@csrc/kernels/binary_operator.cu`:
- Around line 24-41: The in-place native wrappers aiter_add_, aiter_sub_,
aiter_mul_, and aiter_div_ must reject non-contiguous destinations: before
calling binary_op_dispatch("...", input, other, input) check
input.is_contiguous() and return false when it is not; this prevents
binary_op_dispatch (and the transpose/broadcast kernels it may select) from
writing dense offsets into the wrong locations and lets the Python layer fall
back to the correct strided in-place implementation.

In `@op_tests/op_benchmarks/triton/bench_deepgemm_attention.py`:
- Around line 26-35: kv_cache_cast_to_fp8 currently hardcodes the FP8 max (240)
when computing sf; instead, derive the FP8 max from the selected fp8_dtype and
use that to compute sf. Update kv_cache_cast_to_fp8 (and where fp8_dtype is set
via get_fp8_e4m3_dtype) to compute fp8_max = torch.finfo(fp8_dtype).max (or
equivalent API for the chosen FP8 dtype) and set sf = x_amax / fp8_max before
scaling; keep the existing fallback when fp8_dtype is None and preserve
x.shape/asserts and subsequent casting to fp8_dtype.

In `@op_tests/test_mla.py`:
- Around line 582-584: The current early skip ("if return_lse: pass") prevents
the bh64 Gluon path from ever seeing return_lse, so ensure the Gluon validation
receives the flag: remove the no-op skip and forward the return_lse flag into
the Gluon test/run invocation (eg. pass return_lse into test_absorb_decode_gluon
or the Gluon runner call that handles bh64) so the Gluon branch evaluates the
merged LSE contract; update the call site(s) where test_absorb_decode_gluon or
the Gluon run is invoked to accept and propagate the return_lse parameter.

In `@op_tests/test_moe_ep.py`:
- Around line 482-509: The test for quant_label "a8w4_mxfp4" doesn't force the
fp8 runtime path, so set the fp8 runtime bound before calling run_perftest
(e.g., export or os.environ["AITER_BF16_FP8_MOE_BOUND"]="1" or a value <= token)
when quant_label == "a8w4_mxfp4" so the code path exercised by
run_perftest(fused_moe, ...) uses the fp8 EP; update the branch that sets act
and gate_mode in op_tests/test_moe_ep.py to set that environment variable (or
otherwise invoke the runtime-mode switch) prior to the run_perftest call.

---

Outside diff comments:
In `@aiter/ops/triton/moe/moe_op_gemm_a8w4.py`:
- Around line 364-420: Disable the out_mx_quant path when using the Gluon
kernels on gfx1250 (i.e. when use_gluon is true and the device/gfx id indicates
gfx1250) so we don't return an uninitialized y_scale; specifically, gate the
existing out_mx_quant logic (the out_dtype assignment, the allocate_output call
that depends on reduction_n_matmul, and the y_scale allocation) behind a check
that clears out_mx_quant for Gluon+gfx1250, and fix the n_out calculation to
derive the post-swiglu output width from the actual output tensor shape (use
y.shape[-1] or padded_N divided by reduction_n_matmul) instead of w.shape[-1] so
the scale buffer size is computed from N not K; refer to symbols out_mx_quant,
use_gluon, w, y, y_scale, allocate_output, and n_out when making the change.

In `@csrc/kernels/generate_binaryop.py`:
- Around line 18-34: The DATA_TYPE_MAP and AITER_DTYPE_MAP are out of sync
causing a KeyError during codegen; update AITER_DTYPE_MAP to include entries for
the dtypes referenced in DATA_TYPE_MAP (at minimum "float64" and "bool") with
the correct AITER_DTYPE_* constants, or alter cmdGenFunc() to
validate/short-circuit unsupported torch.dtype names before calling the
generator so codegen is never invoked for dtypes missing in AITER_DTYPE_MAP;
reference DATA_TYPE_MAP, AITER_DTYPE_MAP and cmdGenFunc() when making the
change.

In `@op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py`:
- Around line 55-62: The bench call is passing an unsupported keyword
use_aot=True to gemm_afp4wfp4_preshuffle which will raise a TypeError; remove
the use_aot=True argument from the lambda passed into triton.testing.do_bench so
the call matches the gemm_afp4wfp4_preshuffle signature (keep all other args: x,
w, x_scale, w_scale, c_dtype, y) and ensure the lambda still returns the
function invocation for benchmarking.

In `@op_tests/test_mla.py`:
- Around line 667-681: The gate incorrectly excludes batch_size > 256; update
the condition in the Gluon bh16bn64 decode check (the if using get_gfx(), dtype,
kvtype, nhead, decode_qlen, v_head_dim, qk_head_dim, page_size, batch_size,
ctx_lens) to mirror NUM_KV_SPLITS semantics by removing the hard cap and using
ctx_lens >= max(1, 256 // batch_size) with a simple batch_size >= 1 check (i.e.,
replace "1 <= batch_size <= 256 and ctx_lens >= (256 // batch_size)" with
"batch_size >= 1 and ctx_lens >= max(1, 256 // batch_size)").

---

Nitpick comments:
In `@csrc/include/hip_reduce.h`:
- Around line 291-292: The static_assert in multithread_reduce_max_dpp currently
allows thread_num up to 64 unconditionally; update the contract so non-__GFX9__
builds disallow 64: modify the check around the static_assert in
multithread_reduce_max_dpp (symbol: multithread_reduce_max_dpp / the
static_assert line) to be conditional on __GFX9__ — for __GFX9__ keep the
existing <=64 power-of-two check, otherwise tighten to thread_num >=1 &&
thread_num <=32 && power-of-two (or explicitly static_assert(false) when
thread_num==64 on non-__GFX9__), so callers cannot instantiate a 64-thread
variant on targets that don't support it.

In `@op_tests/triton_tests/fusions/test_fused_kv_cache.py`:
- Around line 219-240: The test currently sets num_decode_toks_for_zeros = T so
it never exercises the partial-prefix decode-output path; modify the test around
the call to fused_qk_rope_cat_and_cache_mla to add at least one scenario where
num_decode_toks_for_zeros < T (for example T=8 and num_decode_toks_for_zeros=3)
so the function fused_qk_rope_cat_and_cache_mla (and outputs decode_q_pe_out /
q_nope_zeros_out) is exercised for partial-prefix writes; ensure the same setup
(slot_mapping, positions, shuffled_kv_cache, upcast_operand) is reused and only
num_decode_toks_for_zeros is changed so the new case triggers the pid-transpose
OOB regression.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: fcbd3bff-42b8-49d0-bb7d-3db8e7a84668

📥 Commits

Reviewing files that changed from the base of the PR and between 7b91a13 and 208ff0e.

⛔ Files ignored due to path filters (8)
  • aiter/configs/model_configs/dsv3_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/dsv4_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/glm5_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/kimi_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/kimik2_bf16_tuned_gemm.csv is excluded by !**/*.csv
  • aiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csv is excluded by !**/*.csv
  • hsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_silu.csv is excluded by !**/*.csv
📒 Files selected for processing (69)
  • .github/scripts/download_triton_wheel.sh
  • .github/scripts/split_tests.sh
  • .github/workflows/aiter-release.yaml
  • .github/workflows/aiter-test.yaml
  • .github/workflows/amd-ci-job-monitor.yml
  • .github/workflows/docs.yml
  • .github/workflows/prepare-triton-wheel.yaml
  • .github/workflows/test-whl.yaml
  • .github/workflows/triton-test.yaml
  • .github/workflows/vllm_benchmark.yaml
  • aiter/fused_moe.py
  • aiter/ops/aiter_operator.py
  • aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py
  • aiter/ops/flydsl/gemm_kernels.py
  • aiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.py
  • aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py
  • aiter/ops/flydsl/kernels/gemm_common_gfx1250.py
  • aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py
  • aiter/ops/flydsl/kernels/moe_gemm_2stage.py
  • aiter/ops/flydsl/kernels/pipeline_utils.py
  • aiter/ops/flydsl/kernels/splitk_hgemm.py
  • aiter/ops/flydsl/moe_kernels.py
  • aiter/ops/gemm_op_a8w8.py
  • aiter/ops/opus/__init__.py
  • aiter/ops/shuffle.py
  • aiter/ops/triton/__init__.py
  • aiter/ops/triton/_gluon_kernels/gfx1250/fusions/fused_kv_cache.py
  • aiter/ops/triton/_gluon_kernels/gfx1250/gemm/basic/gemm_mxfp4.py
  • aiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.py
  • aiter/ops/triton/_triton_kernels/fusions/fused_kv_cache.py
  • aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py
  • aiter/ops/triton/attention/pa_mqa_logits.py
  • aiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.json
  • aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/fusions/fused_kv_cache.py
  • aiter/ops/triton/gemm/basic/gemm_afp4wfp4.py
  • aiter/ops/triton/gemm/batched/__init__.py
  • aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/gluon/README.md
  • aiter/ops/triton/gluon/gemm_afp4wfp4.py
  • aiter/ops/triton/gluon/mla_decode_gluon.py
  • aiter/ops/triton/gluon/pa_mqa_logits.py
  • aiter/ops/triton/moe/moe_op_gemm_a8w4.py
  • aiter/ops/triton/moe/moe_routing/routing.py
  • aiter/ops/triton/moe/moe_routing/topk.py
  • csrc/include/aiter_operator.h
  • csrc/include/binary_operator.cuh
  • csrc/include/hip_reduce.h
  • csrc/kernels/binary_operator.cu
  • csrc/kernels/generate_binaryop.py
  • csrc/kernels/topk_gating_kernels.cu
  • csrc/pybind/aiter_operator_pybind.cu
  • hsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x384.co
  • hsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_novs_silu_32x384.co
  • hsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x384.co
  • hsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_vs_silu_32x384.co
  • op_tests/op_benchmarks/triton/bench_deepgemm_attention.py
  • op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py
  • op_tests/test_batched_gemm_smallB.py
  • op_tests/test_gemm_a8w8_bpreshuffle_gfx1250.py
  • op_tests/test_mla.py
  • op_tests/test_moe_2stage.py
  • op_tests/test_moe_ep.py
  • op_tests/triton_tests/fusions/test_fused_kv_cache.py
  • op_tests/triton_tests/fusions/test_fused_routing_from_topk.py
  • op_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.py
  • op_tests/triton_tests/moe/test_moe_routing.py
💤 Files with no reviewable changes (2)
  • op_tests/triton_tests/fusions/test_fused_routing_from_topk.py
  • aiter/ops/opus/init.py
✅ Files skipped from review due to trivial changes (3)
  • csrc/pybind/aiter_operator_pybind.cu
  • aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/gemm/batched/init.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
  • aiter/ops/triton/init.py
  • op_tests/test_batched_gemm_smallB.py
  • aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
  • aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py

Comment on lines 13 to 19
ROCM_VERSION=$(dpkg -l rocm-core 2>/dev/null | awk '/^ii/{print $3}')
if [[ -n "${ROCM_VERSION}" ]]; then
ROCM_MAJOR_MINOR=$(echo "${ROCM_VERSION}" | cut -d. -f1,2)
TRITON_INDEX_URL="https://pypi.amd.com/triton/release_/rocm-${ROCM_MAJOR_MINOR}.0/simple/"
else
echo "rocm-core not found; using default ROCm version ${TRITON_DEFAULT_ROCM_VERSION}"
fi
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

#!/usr/bin/env bash
set -euo pipefail

echo "Inspect relevant lines:"
nl -ba .github/scripts/download_triton_wheel.sh | sed -n '11,20p'

echo
echo "Reproduce errexit+pipefail behavior in command substitution:"
set +e
bash -lc 'set -euo pipefail; v=$(false | cat); echo "SURVIVED"'
status=$?
set -e
echo "Observed exit status: ${status} (expected non-zero; fallback branch would be skipped)."

Repository: buddywhitman/aiter

Length of output: 147


🏁 Script executed:

#!/usr/bin/env bash
set -euo pipefail

echo "1) Show relevant lines from .github/scripts/download_triton_wheel.sh (lines ~1-30):"
# use cat -n instead of nl (nl not available)
cat -n .github/scripts/download_triton_wheel.sh | sed -n '1,40p'

echo
echo "2) Locate references to TRITON_DEFAULT_ROCM_VERSION and roc m-core handling:"
rg -n "TRITON_DEFAULT_ROCM_VERSION|rocm-core|ROCM_VERSION|TRITON_INDEX_URL" .github/scripts/download_triton_wheel.sh .github/workflows/prepare-triton-wheel.yaml || true

echo
echo "3) Inspect .github/workflows/prepare-triton-wheel.yaml for how the script is invoked:"
sed -n '1,200p' .github/workflows/prepare-triton-wheel.yaml | rg -n "download_triton_wheel\.sh|prepare-triton-wheel|TRITON_DEFAULT_ROCM_VERSION|ROCM_VERSION|rocm-core" || true

echo
echo "4) Reproduce errexit+pipefail behavior in command substitution (minimal cases):"
set +e

# Case A: failing command in pipeline substitution
bash -lc 'set -euo pipefail; v=$(false | cat); echo "A_SURVIVED"' ; echo "A_EXIT=$?"

# Case B: failing left side (simulating dpkg not found) piped to awk; ensure pipeline failure
bash -lc 'set -euo pipefail; v=$(command_not_found_in_subst 2>/dev/null | awk "/x/"); echo "B_SURVIVED"' ; echo "B_EXIT=$?"

# Case C: force non-zero but allow pipeline to succeed via `|| true` (proposed fix baseline)
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "C_SURVIVED"' ; echo "C_EXIT=$?"

# Case D: proposed fix pattern: command substitution with `|| true` appended
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "D_SURVIVED"' ; echo "D_EXIT=$?"

set -e

Repository: buddywhitman/aiter

Length of output: 4868


🏁 Script executed:

#!/usr/bin/env bash
set -euo pipefail

echo "1) Show relevant lines from .github/scripts/download_triton_wheel.sh (lines ~1-40):"
cat -n .github/scripts/download_triton_wheel.sh | sed -n '1,40p'

echo
echo "2) Locate references to TRITON_DEFAULT_ROCM_VERSION and rocm-core handling:"
rg -n "TRITON_DEFAULT_ROCM_VERSION|rocm-core|ROCM_VERSION|TRITON_INDEX_URL" .github/scripts/download_triton_wheel.sh .github/workflows/prepare-triton-wheel.yaml || true

echo
echo "3) Inspect .github/workflows/prepare-triton-wheel.yaml for how the script is invoked:"
sed -n '1,240p' .github/workflows/prepare-triton-wheel.yaml | rg -n "download_triton_wheel\.sh|TRITON_DEFAULT_ROCM_VERSION|ROCM_VERSION|rocm-core|prepare-triton-wheel" || true

echo
echo "4) Reproduce errexit+pipefail behavior in command substitution (minimal cases):"
set +e

# Case A: failing command in pipeline substitution
bash -lc 'set -euo pipefail; v=$(false | cat); echo "A_SURVIVED"' ; echo "A_EXIT=$?"

# Case B: missing command on left side piped to awk; pipeline fails under pipefail
bash -lc 'set -euo pipefail; v=$(command_not_found_in_subst 2>/dev/null | awk "/x/"); echo "B_SURVIVED"' ; echo "B_EXIT=$?"

# Case C: proposed style inside substitution using `|| true` on pipeline
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "C_SURVIVED"' ; echo "C_EXIT=$?"

set -e

Repository: buddywhitman/aiter

Length of output: 4748


Fix unreachable fallback in download_triton_wheel.sh (lines 13-19).

With set -euo pipefail, a non-zero status from the pipeline in ROCM_VERSION=$(dpkg -l rocm-core ... | awk ...) aborts the script before the if [[ -n "${ROCM_VERSION}" ]]; then ... else ... block executes (matching the sandbox repro for v=$(false | cat)). So the “rocm-core not found; using default …” branch at lines 17-19 won’t run when dpkg -l rocm-core returns non-zero. Add || true (or otherwise neutralize the failure) to allow the fallback logic to execute; the current diff approach is appropriate.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.github/scripts/download_triton_wheel.sh around lines 13 - 19, The pipeline
assignment to ROCM_VERSION can fail under set -euo pipefail, making the else
fallback unreachable; modify the command that sets ROCM_VERSION (the line
assigning ROCM_VERSION using dpkg -l rocm-core | awk) to neutralize failures
(e.g., append || true) so the command never returns non-zero and the subsequent
if [[ -n "${ROCM_VERSION}" ]]; then ... else ... block can run and correctly set
ROCM_MAJOR_MINOR and TRITON_INDEX_URL or fall back to the default ROCM version
message.

Comment thread aiter/fused_moe.py
Comment on lines +1420 to +1426
# Debug: AITER_FLYDSL_FORCE=1 is for debug use.
_flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"
use_mxfp4_flydsl = (
dtype in [dtypes.bf16, dtypes.fp16]
and q_type == QuantType.per_1x32
and activation == ActivationType.Swiglu
and q_dtype_a == dtypes.fp4x2
and (activation == ActivationType.Swiglu or _flydsl_force)
and q_dtype_a in (dtypes.fp4x2, dtypes.fp8)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Make the FlyDSL force override opt-in.

AITER_FLYDSL_FORCE is documented here as a debug switch, but defaulting it to "1" makes the override active even when the env var is unset. That broadens the fallback to non-SwiGLU activations on the normal path and changes kernel selection by default instead of only under explicit opt-in.

Suggested fix
-    _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"
+    _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "0") == "1"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Debug: AITER_FLYDSL_FORCE=1 is for debug use.
_flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"
use_mxfp4_flydsl = (
dtype in [dtypes.bf16, dtypes.fp16]
and q_type == QuantType.per_1x32
and activation == ActivationType.Swiglu
and q_dtype_a == dtypes.fp4x2
and (activation == ActivationType.Swiglu or _flydsl_force)
and q_dtype_a in (dtypes.fp4x2, dtypes.fp8)
# Debug: AITER_FLYDSL_FORCE=1 is for debug use.
_flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "0") == "1"
use_mxfp4_flydsl = (
dtype in [dtypes.bf16, dtypes.fp16]
and q_type == QuantType.per_1x32
and (activation == ActivationType.Swiglu or _flydsl_force)
and q_dtype_a in (dtypes.fp4x2, dtypes.fp8)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/fused_moe.py` around lines 1420 - 1426, The code currently enables the
FlyDSL force override by default because _flydsl_force =
os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"; change this to be opt-in by
default (e.g. use os.environ.get("AITER_FLYDSL_FORCE", "0") == "1" or check
presence with os.environ.get("AITER_FLYDSL_FORCE") == "1") so _flydsl_force is
False when the variable is unset; update the surrounding use_mxfp4_flydsl
condition (symbols: _flydsl_force and use_mxfp4_flydsl) accordingly so FlyDSL is
only forced when the env var is explicitly set to "1".

Comment thread aiter/fused_moe.py
Comment on lines +1807 to +1811
# EP: forward expert_mask + topk_ids to the flydsl stage2 wrapper so it can
# switch to reduce mode and fuse the validity gather in compile_moe_reduction.
if stage2_func is _flydsl_stage2_wrapper and expert_mask is not None:
extra_stage2_args["expert_mask"] = expert_mask
extra_stage2_args["topk_ids"] = topk_ids
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Pass global expert ids to FlyDSL reduce masking.

This forwards topk_ids from fused_moe_2stages, but that parameter can already be replaced with local_topk_ids in fused_moe_ for split-k bias handling. The new reduce kernel treats topk_ids as indices into expert_mask, so feeding localized ids here will mark the wrong routes valid/invalid under EP. Thread the original global topk_ids separately for the reduction path and keep the localized ids only for kernels that actually need them.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/fused_moe.py` around lines 1807 - 1811, The current branch that sets
extra_stage2_args["expert_mask"] and extra_stage2_args["topk_ids"] for the
_flydsl_stage2_wrapper incorrectly forwards localized ids (local_topk_ids) into
the reduction kernel; change it so the reduction path receives the original
global topk ids while kernels that need split-k localized ids still get
local_topk_ids. Concretely, in the if block guarded by stage2_func is
_flydsl_stage2_wrapper and expert_mask is not None, add/assign a distinct key
(e.g., "global_topk_ids") with the original topk_ids and leave "topk_ids"
pointing to local_topk_ids only for places that use localized routing; update
any call-sites inside fused_moe_2stages / fused_moe_ that expect global vs local
ids to read the appropriate key so the new reduce kernel uses global_topk_ids
while other kernels continue to use topk_ids (local_topk_ids).

Comment on lines +87 to +110
if XQ.dim() != 2 or WQ.dim() != 2:
raise RuntimeError(
f"[FlyDSL gfx1250] A/B must be 2-D, got {tuple(XQ.shape)}, {tuple(WQ.shape)}"
)
if XQ.element_size() != 1 or WQ.element_size() != 1:
raise RuntimeError("[FlyDSL gfx1250] A/B must be 1-byte fp8 storage")

M, K = XQ.shape
N = WQ.shape[0]
if K != WQ.shape[1]:
raise RuntimeError(f"[FlyDSL gfx1250] K mismatch: A.K={K} vs B.K={WQ.shape[1]}")
if N % tile_n != 0:
raise RuntimeError(f"[FlyDSL gfx1250] N={N} not a multiple of tile_n={tile_n}")
if K % _WMMA_K != 0 or K % tile_k != 0:
raise RuntimeError(
f"[FlyDSL gfx1250] K={K} must be a multiple of WMMA_K={_WMMA_K} and "
f"tile_k={tile_k}"
)

out_dtype = _OUT_DTYPE_NAME.get(Out.dtype)
if out_dtype is None:
raise RuntimeError(
f"[FlyDSL gfx1250] unsupported out dtype {Out.dtype}; expected bf16/fp16"
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate tensor device and output-shape contracts before launching.

This path never checks that XQ, WQ, x_scale, w_scale, and Out live on the same accelerator device, and it never checks that Out.shape == (M, N). With padded_m == M, Out.contiguous() is handed straight to the compiled kernel, so a mismatched Out can become an out-of-bounds device write instead of a clean Python error.

Proposed guard
     M, K = XQ.shape
     N = WQ.shape[0]
+    expected_device = XQ.device
+    tensors = {
+        "WQ": WQ,
+        "x_scale": x_scale,
+        "w_scale": w_scale,
+        "Out": Out,
+    }
+    if expected_device.type != "cuda":
+        raise RuntimeError(
+            f"[FlyDSL gfx1250] expected accelerator tensors, got {expected_device}"
+        )
+    for name, tensor in tensors.items():
+        if tensor.device != expected_device:
+            raise RuntimeError(
+                f"[FlyDSL gfx1250] {name} must be on {expected_device}, got {tensor.device}"
+            )
+    if Out.shape != (M, N):
+        raise RuntimeError(
+            f"[FlyDSL gfx1250] Out must have shape {(M, N)}, got {tuple(Out.shape)}"
+        )

Also applies to: 125-126, 159-181

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py` around lines 87 - 110, Add
device and output-shape validation before launching the kernel: verify XQ, WQ,
x_scale, w_scale, and Out are on the same device and that Out.shape == (M, N);
also ensure Out is contiguous (or call Out = Out.contiguous()) when padded_m ==
M to avoid out-of-bounds writes. Implement these checks alongside the existing
dimensional/element-size checks in the gfx1250 path (the block validating XQ/WQ,
M, N, K, tile_n/tile_k and Out.dtype) and duplicate the same guard logic in the
other similar validation blocks present later in this file (the other
gfx1250/validation sections).

Comment on lines +119 to +123
# Pipeline depth needs >= 1 K tile per buffer (per split-k chunk).
num_k_tiles = (K // split_k) // tile_k
nb = max(2, min(int(num_buffers), num_k_tiles))
if nb not in _SUPPORTED_NUM_BUFFERS:
nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject unsupported pipeline depths instead of silently forcing nb=2.

When a split-K chunk contains only one K tile, max(2, min(..., num_k_tiles)) still forces nb back to 2. That bypasses the same shape constraint enforced by kernel_fits_shape(...) and lets direct callers compile a two-stage pipeline for a one-tile chunk.

Proposed fix
     # Pipeline depth needs >= 1 K tile per buffer (per split-k chunk).
     num_k_tiles = (K // split_k) // tile_k
-    nb = max(2, min(int(num_buffers), num_k_tiles))
+    if num_k_tiles < min(_SUPPORTED_NUM_BUFFERS):
+        raise RuntimeError(
+            f"[FlyDSL gfx1250] split_k={split_k} leaves only {num_k_tiles} K tiles; "
+            f"need at least {min(_SUPPORTED_NUM_BUFFERS)} for this pipeline"
+        )
+    nb = min(int(num_buffers), num_k_tiles)
     if nb not in _SUPPORTED_NUM_BUFFERS:
         nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py` around lines 119 - 123, Compute
num_k_tiles as before, then do not coerce nb up to 2; set nb =
min(int(num_buffers), num_k_tiles) and if nb is not in _SUPPORTED_NUM_BUFFERS
pick the largest supported buffer count <= nb, but if that yields a value < 1 or
yields a supported nb that is greater than num_k_tiles (e.g. forcing 2 when
num_k_tiles == 1) raise an explicit exception (ValueError) rejecting the
unsupported pipeline depth instead of silently forcing nb=2; update the logic
around num_k_tiles, nb and _SUPPORTED_NUM_BUFFERS and ensure callers that rely
on kernel_fits_shape(...) will see the error when a pipeline depth cannot be
satisfied.

Comment on lines +43 to +79
assert num_expert_group > 1
assert (
num_expert_group <= 16
), f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported"
assert 0 < topk_group <= num_expert_group
assert 0 < k <= 16
assert score_mode in (
"softmax",
"sigmoid",
"sqrtsoftplus",
"none",
), f"unknown score_mode {score_mode!r}"
has_bias = bias is not None
if has_bias:
assert bias.dim() == 1 and bias.shape[0] == n_cols
assert bias.dtype == torch.float32
assert score_mode in (
"sqrtsoftplus",
"sigmoid",
), "bias only supported with sqrtsoftplus / sigmoid"

dev = x.device

# Default expert→group mapping = contiguous DeepSeek layout.
if expert_group is None:
assert n_cols % num_expert_group == 0, (
f"n_expts_tot ({n_cols}) not divisible by num_expert_group "
f"({num_expert_group}); pass an explicit expert_group table."
)
g_size = n_cols // num_expert_group
expert_group = (
torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size
).to(torch.int32)
else:
assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols
assert expert_group.dtype == torch.int32

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate grouped configs before launching the kernel.

grouped_topk() now accepts arbitrary expert_group, but Lines 67-79 only check shape and dtype. If the mapping contains out-of-range ids, empty groups, or a grouping where some valid topk_group selection can expose fewer than k experts, _grouped_topk will exhaust real candidates and keep emitting duplicate indices instead of padding. Please reject invalid mappings/counts up front, or add explicit sentinel handling in the kernel.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@aiter/ops/triton/moe/moe_routing/topk.py` around lines 43 - 79, Validate
expert_group mapping in grouped_topk before launching _grouped_topk: ensure all
expert_group values are in [0, num_expert_group), compute per-group counts with
torch.bincount (length num_expert_group) and assert no empty groups, and
additionally assert the worst‑case selection of topk_group groups can provide k
experts by checking that the sum of the smallest topk_group counts >= k; raise a
clear AssertionError referencing expert_group/num_expert_group/topk_group/k when
any check fails.

Comment on lines +24 to +41
bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other)
{
torch::Tensor output = initOutput(input, other);
binary_op_dispatch("div", input, other, output);
return output;
return binary_op_dispatch("add", input, other, input);
}

// inp interface
torch::Tensor aiter_add_(torch::Tensor &input, torch::Tensor &other)
bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("add", input, other, input);
return input;
return binary_op_dispatch("sub", input, other, input);
}

torch::Tensor aiter_sub_(torch::Tensor &input, torch::Tensor &other)
bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("sub", input, other, input);
return input;
return binary_op_dispatch("mul", input, other, input);
}

torch::Tensor aiter_mul_(torch::Tensor &input, torch::Tensor &other)
bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("mul", input, other, input);
return input;
}

torch::Tensor aiter_div_(torch::Tensor &input, torch::Tensor &other)
{
binary_op_dispatch("div", input, other, input);
return input;
return binary_op_dispatch("div", input, other, input);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject non-contiguous input in the in-place native wrappers.

These wrappers alias input as output, but the transpose/broadcast kernels assume the destination is laid out contiguously. A non-contiguous input can still match the transpose path, so binary_op_dispatch() will write back with dense offsets into the wrong locations. Returning false here when !input.is_contiguous() lets the Python layer fall back to input.<op>_() and preserves correct strided behavior.

Suggested guard
 bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other)
 {
+  if (!input.is_contiguous()) {
+    return false;
+  }
   return binary_op_dispatch("add", input, other, input);
 }

 bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other)
 {
+  if (!input.is_contiguous()) {
+    return false;
+  }
   return binary_op_dispatch("sub", input, other, input);
 }

 bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other)
 {
+  if (!input.is_contiguous()) {
+    return false;
+  }
   return binary_op_dispatch("mul", input, other, input);
 }

 bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other)
 {
+  if (!input.is_contiguous()) {
+    return false;
+  }
   return binary_op_dispatch("div", input, other, input);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other)
{
torch::Tensor output = initOutput(input, other);
binary_op_dispatch("div", input, other, output);
return output;
return binary_op_dispatch("add", input, other, input);
}
// inp interface
torch::Tensor aiter_add_(torch::Tensor &input, torch::Tensor &other)
bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("add", input, other, input);
return input;
return binary_op_dispatch("sub", input, other, input);
}
torch::Tensor aiter_sub_(torch::Tensor &input, torch::Tensor &other)
bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("sub", input, other, input);
return input;
return binary_op_dispatch("mul", input, other, input);
}
torch::Tensor aiter_mul_(torch::Tensor &input, torch::Tensor &other)
bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other)
{
binary_op_dispatch("mul", input, other, input);
return input;
}
torch::Tensor aiter_div_(torch::Tensor &input, torch::Tensor &other)
{
binary_op_dispatch("div", input, other, input);
return input;
return binary_op_dispatch("div", input, other, input);
bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other)
{
if (!input.is_contiguous()) {
return false;
}
return binary_op_dispatch("add", input, other, input);
}
bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other)
{
if (!input.is_contiguous()) {
return false;
}
return binary_op_dispatch("sub", input, other, input);
}
bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other)
{
if (!input.is_contiguous()) {
return false;
}
return binary_op_dispatch("mul", input, other, input);
}
bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other)
{
if (!input.is_contiguous()) {
return false;
}
return binary_op_dispatch("div", input, other, input);
}
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/kernels/binary_operator.cu` around lines 24 - 41, The in-place native
wrappers aiter_add_, aiter_sub_, aiter_mul_, and aiter_div_ must reject
non-contiguous destinations: before calling binary_op_dispatch("...", input,
other, input) check input.is_contiguous() and return false when it is not; this
prevents binary_op_dispatch (and the transpose/broadcast kernels it may select)
from writing dense offsets into the wrong locations and lets the Python layer
fall back to the correct strided in-place implementation.

Comment on lines +26 to +35
def kv_cache_cast_to_fp8(
x: torch.Tensor, padding=False, fp8_dtype=None
) -> torch.Tensor:
if fp8_dtype is None:
fp8_dtype = get_fp8_e4m3_dtype()
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 240.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fnuz)
x_scaled = (x * (1.0 / sf)).to(fp8_dtype)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🌐 Web query:

What are the max finite values (or recommended quantization ranges) for torch.float8_e4m3fnandtorch.float8_e4m3fnuz in PyTorch, and should they use different scale denominators when quantizing BF16 tensors to FP8?

💡 Result:

The max finite value for both torch.float8_e4m3fn and torch.float8_e4m3fnuz is 448.0 [1][2]. torch.float8_e4m3fn is based on the OCP (Open Compute Project) FP8 specification, where the suffix "fn" denotes "Finite, no infinity" [3][2][4]. torch.float8_e4m3fnuz (often referred to as the E4M3 FNUZ format) is an alternative variant supported by specific hardware (e.g., AMD MI300 series) where "fnuz" stands for "Finite, NaN, Unsigned, Zero" [5][6][7][8]. Both types lack representations for infinity and are generally restricted to the same nominal maximum value of 448.0 in practice when quantizing [1][2]. Regarding scale denominators, they should generally be treated identically for quantization purposes when using the same max value (448.0) [1]. Because both types share the same dynamic range, the scaling factor used to map higher-precision tensors (like bfloat16 or float32) to the FP8 range is calculated using the same logic: scale = 448.0 / amax [2] While the bit-level encodings differ between fn and fnuz variants (specifically regarding how NaNs and zeros are represented) [5][7][8], the quantization strategy—clamping values to the [–448.0, 448.0] range before casting—remains the same for both [1][9]. In standard PyTorch implementations, including those in torchao and NVIDIA Model-Optimizer, saturation is highly recommended because native casting does not always clamp by default, which can lead to unexpected overflows or NaN generation [10][11][9]. Top results: [1][5][9][2][7]

Citations:


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 1) Locate the helper and constant usage in this repo
rg -n "get_fp8_e4m3_dtype|240\.0|kv_cache_cast_to_fp8" -S . || true

# 2) Show the surrounding code in the target file around the snippet
target="op_tests/op_benchmarks/triton/bench_deepgemm_attention.py"
if [ -f "$target" ]; then
  sed -n '1,120p' "$target" | nl -ba | sed -n '1,120p'
fi

Repository: buddywhitman/aiter

Length of output: 9821


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# 3) Find implementations of get_fp8_e4m3_dtype and any fp8 scaling utilities/constants
rg -n "def get_fp8_e4m3_dtype|get_fp8_e4m3_dtype\s*\(" -S . || true

# If present, inspect likely nearby files (by searching filenames containing fp8)
fd -a "fp8" . || true

Repository: buddywhitman/aiter

Length of output: 50375


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Show the helper + dtype selection/printing area
target="op_tests/op_benchmarks/triton/bench_deepgemm_attention.py"
echo "=== kv_cache_cast_to_fp8 + dtype selection context ==="
# Print around lines 1-120
awk 'NR>=1 && NR<=120 {printf "%6d  %s\n", NR, $0}' "$target" | sed -n '1,120p'

echo "=== later FP8 dtype usage around line ~220-320 ==="
awk 'NR>=220 && NR<=320 {printf "%6d  %s\n", NR, $0}' "$target" | sed -n '1,120p'

echo "=== definitions of get_fp8_e4m3_dtype ==="
types="aiter/ops/triton/utils/types.py"
awk 'NR>=1 && NR<=200 {printf "%6d  %s\n", NR, $0}' "$types" | sed -n '1,200p'

echo "=== locate fp8 min/max bounds utilities ==="
rg -n "get_fp8_min_max_bounds|get_fp8_.*max|fp8_max|max_pos|inv_max_pos|240\.0|448\.0" aiter/ops/triton -S || true

echo "=== inspect fused_fp8_quant fp8 range logic ==="
fused="aiter/ops/triton/quant/fused_fp8_quant.py"
# show around where it mentions get_fp8_min_max_bounds and dtype
awk 'NR>=350 && NR<=520 {printf "%6d  %s\n", NR, $0}' "$fused" | sed -n '1,200p'

Repository: buddywhitman/aiter

Length of output: 42658


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== ck_tile/vec_convert.h around constants ==="
file="csrc/include/ck_tile/vec_convert.h"
if [ -f "$file" ]; then
  sed -n '1,120p' "$file" | nl -ba | sed -n '35,80p'
else
  echo "missing $file"
fi

echo "=== mx_quant_utils.h around max_pos/max_normal comments ==="
file="csrc/include/mx_quant_utils.h"
if [ -f "$file" ]; then
  sed -n '70,130p' "$file" | nl -ba | sed -n '60,120p'
else
  echo "missing $file"
fi

echo "=== aiter/utility/fp4_utils.py fp8 dtype mapping ==="
file="aiter/utility/fp4_utils.py"
if [ -f "$file" ]; then
  sed -n '60,110p' "$file" | nl -ba | sed -n '1,80p'
else
  echo "missing $file"
fi

Repository: buddywhitman/aiter

Length of output: 170


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "=== ck_tile/vec_convert.h (show lines 1-120 or whole if smaller) ==="
file="csrc/include/ck_tile/vec_convert.h"
if [ -f "$file" ]; then
  wc -l "$file"
  sed -n '35,80p' "$file"
fi

echo "=== mx_quant_utils.h (show lines 70-130) ==="
file="csrc/include/mx_quant_utils.h"
if [ -f "$file" ]; then
  wc -l "$file"
  sed -n '80,130p' "$file"
fi

echo "=== aiter/utility/fp4_utils.py (show lines 60-110) ==="
file="aiter/utility/fp4_utils.py"
if [ -f "$file" ]; then
  wc -l "$file"
  sed -n '60,110p' "$file"
fi

Repository: buddywhitman/aiter

Length of output: 7426


Use the FP8 max for the selected dtype when computing sf in kv_cache_cast_to_fp8().

kv_cache_cast_to_fp8() always scales with sf = x_amax / 240.0, but torch.float8_e4m3fn and torch.float8_e4m3fnuz have different max finite ranges in this codebase (448 vs 240). Derive the denominator from fp8_dtype (e.g., fp8_max = torch.finfo(fp8_dtype).max; sf = x_amax / fp8_max) so the KV cache is scaled consistently with the chosen FP8 encoding.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/op_benchmarks/triton/bench_deepgemm_attention.py` around lines 26 -
35, kv_cache_cast_to_fp8 currently hardcodes the FP8 max (240) when computing
sf; instead, derive the FP8 max from the selected fp8_dtype and use that to
compute sf. Update kv_cache_cast_to_fp8 (and where fp8_dtype is set via
get_fp8_e4m3_dtype) to compute fp8_max = torch.finfo(fp8_dtype).max (or
equivalent API for the chosen FP8 dtype) and set sf = x_amax / fp8_max before
scaling; keep the existing fallback when fp8_dtype is None and preserve
x.shape/asserts and subsequent casting to fp8_dtype.

Comment thread op_tests/test_mla.py
Comment on lines +582 to 584
# The ASM decode baseline aborts for these MLA configs when lse is requested
if return_lse:
pass
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

--return_lse still skips the bh64 Gluon return_lse path.

Once the ASM baseline is skipped here, the Gluon run becomes the only validation path — but test_absorb_decode_gluon() never forwards return_lse, so -lse on bh64 still times the output-only kernel and never checks the merged LSE contract.

Suggested fix
-        (attn_logits, attn_lse), us_gluon_decode = run_perftest(
+        (_, attn_lse), us_gluon_decode = run_perftest(
             mla_decode_gluon,
             q_nope,
             q_pe,
             kv_c,
             out_gluon.view(batch_size, nhead, v_head_dim),
             page_table,
             seq_info,
             sm_scale,
             use_2d_view=use_2d_view,
             min_kv_seq_len=ctx_lens,
+            return_lse=return_lse,
         )
@@
         err = checkAllclose(
             out_ref,
             out_gluon,
             msg=f"mla_decode-absorb    [golden vs gluon_mla]: {us_gluon_decode:>8.2f} us......",
         )
+        if return_lse and attn_lse is not None:
+            checkAllclose(
+                lse_ref,
+                attn_lse.reshape(total_q, nhead),
+                msg=f"mla_decode-absorb    [lse_ref vs gluon_mla_lse]: {us_gluon_decode:>8.2f} us......",
+            )
         return err, us_gluon_decode
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/test_mla.py` around lines 582 - 584, The current early skip ("if
return_lse: pass") prevents the bh64 Gluon path from ever seeing return_lse, so
ensure the Gluon validation receives the flag: remove the no-op skip and forward
the return_lse flag into the Gluon test/run invocation (eg. pass return_lse into
test_absorb_decode_gluon or the Gluon runner call that handles bh64) so the
Gluon branch evaluates the merged LSE contract; update the call site(s) where
test_absorb_decode_gluon or the Gluon run is invoked to accept and propagate the
return_lse parameter.

Comment thread op_tests/test_moe_ep.py
Comment on lines +482 to +509
# a4w4: Silu + SEPARATED -> FlyDSL fp4/fp4 (AITER_FLYDSL_FORCE=1 drops the
# Swiglu gate). a8w4: Silu + INTERLEAVE -> q_dtype_a auto-picker selects
# fp8 on gfx950 (fused_moe.py:357-361), and since the L1261 CK-Tile
# pre-emption requires Swiglu, Silu falls through to the
# swiglu_mxfp4_flydsl branch (with FLYDSL_FORCE=1) and lands on
# flydsl_moe1_afp8_wfp4_... Needs AITER_BF16_FP8_MOE_BOUND<=token.
if quant_label == "a8w4_mxfp4":
act = ActivationType.Silu
gate_mode = GateMode.INTERLEAVE.value
else:
act = ActivationType.Silu
gate_mode = GateMode.SEPARATED.value
out, us = run_perftest(
fused_moe,
input_,
w1_a,
w2_a,
topk_weights,
topk_ids,
expert_mask=expert_mask,
activation=act,
gate_mode=gate_mode,
quant_type=QuantType.per_1x32,
w1_scale=w1_s,
w2_scale=w2_s,
num_warmup=3,
num_iters=16,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Force the fp8 runtime mode in the a8w4_mxfp4 EP test.

This test never sets AITER_BF16_FP8_MOE_BOUND, so the default token=128 case can stay on the bf16 runtime path instead of the fp8 path the label/comment claim to cover. That means g1u1_a8w4_mxfp4 can pass without ever validating the new fp8 EP route.

Suggested fix
+import os
 import argparse
@@
     if quant_label == "a8w4_mxfp4":
         act = ActivationType.Silu
         gate_mode = GateMode.INTERLEAVE.value
+        _old_bound = os.environ.get("AITER_BF16_FP8_MOE_BOUND")
+        os.environ["AITER_BF16_FP8_MOE_BOUND"] = "0"
     else:
         act = ActivationType.Silu
         gate_mode = GateMode.SEPARATED.value
-    out, us = run_perftest(
-        fused_moe,
-        input_,
-        w1_a,
-        w2_a,
-        topk_weights,
-        topk_ids,
-        expert_mask=expert_mask,
-        activation=act,
-        gate_mode=gate_mode,
-        quant_type=QuantType.per_1x32,
-        w1_scale=w1_s,
-        w2_scale=w2_s,
-        num_warmup=3,
-        num_iters=16,
-    )
+    try:
+        out, us = run_perftest(
+            fused_moe,
+            input_,
+            w1_a,
+            w2_a,
+            topk_weights,
+            topk_ids,
+            expert_mask=expert_mask,
+            activation=act,
+            gate_mode=gate_mode,
+            quant_type=QuantType.per_1x32,
+            w1_scale=w1_s,
+            w2_scale=w2_s,
+            num_warmup=3,
+            num_iters=16,
+        )
+    finally:
+        if quant_label == "a8w4_mxfp4":
+            if _old_bound is None:
+                os.environ.pop("AITER_BF16_FP8_MOE_BOUND", None)
+            else:
+                os.environ["AITER_BF16_FP8_MOE_BOUND"] = _old_bound
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@op_tests/test_moe_ep.py` around lines 482 - 509, The test for quant_label
"a8w4_mxfp4" doesn't force the fp8 runtime path, so set the fp8 runtime bound
before calling run_perftest (e.g., export or
os.environ["AITER_BF16_FP8_MOE_BOUND"]="1" or a value <= token) when quant_label
== "a8w4_mxfp4" so the code path exercised by run_perftest(fused_moe, ...) uses
the fp8 EP; update the branch that sets act and gate_mode in
op_tests/test_moe_ep.py to set that environment variable (or otherwise invoke
the runtime-mode switch) prior to the run_perftest call.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.