[Triton] batched_gemm_a8w8_smallB_blockscale: split-K FP8 GEMM for small-B regime (fixes ROCm/aiter#3000)#2
Conversation
📝 WalkthroughWalkthroughAdds Triton FP8 small-B batched GEMM kernels and Python wrappers; extensive gfx1250 FlyDSL/WMMA kernels, tuners, and utilities; MoE routing and masked reduction plumbing; fused KV-cache upcast support; aiter_tensor_t operator ABI migration; CI/workflow/script updates; and many tests/benchmarks. ChangesCore Features & Kernels
CI, Scripts, Configs, Tests & Benchmarks
Sequence Diagram(s)sequenceDiagram
participant Caller
participant batched_gemm_a8w8_smallB_blockscale_bf16
participant per_token_group_quant
participant per_token_group_quant_kernel
participant batched_gemm_a8w8_smallB_blockscale
participant _batched_gemm_a8w8_smallB_blockscale_kernel
participant _split_k_reduce_flat_kernel
Caller->>batched_gemm_a8w8_smallB_blockscale_bf16: X (BF16), B_weight, B_scale
batched_gemm_a8w8_smallB_blockscale_bf16->>per_token_group_quant: X, group_size, transpose_bm
per_token_group_quant->>per_token_group_quant_kernel: Launch quantization kernel
per_token_group_quant_kernel-->>per_token_group_quant: X_q (FP8), scale (FP32)
per_token_group_quant-->>batched_gemm_a8w8_smallB_blockscale_bf16: quantized X_q, A_scale
batched_gemm_a8w8_smallB_blockscale_bf16->>batched_gemm_a8w8_smallB_blockscale: X_q, B_weight, A_scale, B_scale, split_k
batched_gemm_a8w8_smallB_blockscale->>_batched_gemm_a8w8_smallB_blockscale_kernel: Launch GEMM (split_k tiles)
_batched_gemm_a8w8_smallB_blockscale_kernel-->>batched_gemm_a8w8_smallB_blockscale: float32 partials or BF16 output
alt split_k > 1
batched_gemm_a8w8_smallB_blockscale->>_split_k_reduce_flat_kernel: Launch reduction
_split_k_reduce_flat_kernel-->>batched_gemm_a8w8_smallB_blockscale: BF16 reduced result
end
batched_gemm_a8w8_smallB_blockscale-->>Caller: BF16 output
Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes Poem
✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
|
…e co-issue (ROCm#3469) * add 8w policy * add selection filter * update filter * upd dsv3 * upd gptoss * upd kimi&kimik2 k2 still in tune * Update glm5_bf16_tuned_gemm.csv * Update gemm_kernels.py: add tm80 * Update kimik2_bf16_tuned_gemm.csv * change to fx.Pointer * Update dsv4_bf16_tuned_gemm.csv * Update kimi_bf16_tuned_gemm.csv * Update kimik2_bf16_tuned_gemm.csv * Update dsv4_bf16_tuned_gemm.csv
* Only show the error when someone calls gemm_a16w16_opus or opus_gemm_a16w16_tune on a non-gfx950 device
) * flydsl moe: add EP reduce path with optional reduce mask moe_sorting gains need_reduce/has_reduce_mask to size moe_buf for the stage2 reduce epilogue (per-topk slab vs atomic accumulate vs caller- owned). Stage2 wrapper and moe_gemm_2stage updated to drive the compile_moe_reduction path; test_moe_ep covers the new EP variants. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com> * improve * hack topk to retrieve correct tuned result * format * flydsl moe_reduction: fix i32 voffset overflow for large intermediates The compile_moe_reduction kernel addressed X [tokens, topk, model_dim] via fx.rocdl.make_buffer_tensor + i32 per-thread voffsets. For large batches the per-token byte offset exceeds 2^32 (e.g. 131072*6*4096*2 = 6 GiB), wrapping the i32 voffset and corrupting the reduction. Fold the per-WG token offset into each buffer descriptor's 48-bit base pointer via create_buffer_resource(base_byte_offset=fx.Int64(...)). In-kernel voffsets stay i32-safe because they only span one token's slab (topk*model_dim*elem_bytes). The vector path now uses direct buffer_load/buffer_store against the base-offset-shifted resources (the make_buffer_tensor helper has no base hook). Behavior at typical sizes is unchanged; logits_diff at token=131072 drops from 0.337 to 0.0054, matching torch.sum. * flydsl moe_gemm_2stage: port to fx.Pointer API after main merge Adopt main's commit 1be4ee9 (ROCm#3373) plumbing in compile_moe_gemm1 and compile_moe_gemm2: switch kernel/launcher signatures to fx.Pointer, add a local _ptr_buffer_resource helper, precompute w_nbytes/sw_nbytes, and replace extract_base_index with ptrtoint + index_cast. For compile_moe_reduction, preserve the EP-reduce semantics (X, Y, expert_mask, topk_ids inputs and the fused valid = expert_mask[topk_ids [t,k]] != 0 gather) and the i32 voffset overflow fix from 7e1935f, while migrating its arguments to fx.Pointer. The per-token slab offset is folded into the descriptor base address via a helper that does the add in i64 before create_buffer_resource_from_addr. Add a new compile-time num_experts kwarg so the expert_mask descriptor can be sized exactly. Update _MoeGemm2ReduceWrapper and the moe_kernels reduce call site to pass pointers via flyc.from_c_void_p / _ptr_view_safe. * fix * fused_moe: replace need_reduce/has_reduce_mask with accumulate flag Collapse the two boolean knobs on moe_sorting into a single `accumulate` flag that mirrors how the stage2 path actually uses moe_buf, and pick it automatically from the compiled FlyDSL stage2 kernel name. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com> * fused_moe: allocate moe_out in stage2 reduce mode; unhardcode a8w4 mxfp4 activation In FlyDSL stage2 reduce mode (no expert_mask), moe_sorting returns a zero-shape placeholder for moe_buf since the reduce path owns the final write. fused_moe_2stages now allocates a real [token_num, model_dim] output tensor before invoking stage2 when it detects the placeholder. Also splits the legacy case generator in test_moe_2stage.py so the a8w4 mxfp4 (per_1x32, fp8, fp4x2) triple iterates args.act instead of hardcoding Swiglu, allowing -a silu to actually take effect for -q 7. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com> * fix * fused_moe: fix black formatting --------- Co-authored-by: Claude Opus 4 <noreply@anthropic.com>
Co-authored-by: Sergey Solo <ssolovye@amd.com>
* CI: add timeouts for artifact downloads Fail stuck artifact downloads promptly so CI jobs do not occupy runners indefinitely. * CI: avoid packing memory-heavy tests together Weight high-memory Aiter tests during shard assignment so they do not regress into the same shard after FILE_TIMES updates.
…OCm#3542) * [mla_decode_gluon] return_lse: full decode + merged fp32 lse (remove DCP) return_lse now runs the full decode (stage-1 + stage-2 reduce) and returns the merged fp32 lse [batch, nhead] (e_max + log(e_sum) over the whole sequence), matching the asm/torch references. Removes the stage-1-only DCP path. Supported for all regimes. - per-split lse kept in a separate fp32 mid_lse buffer (not packed into the bf16 O), fixing the bf16 precision loss in the stage-2 merge - stage-2 (_mla_softmax_reducev_kernel) writes final_lse; NUM_KV_SPLITS==1 fast path writes it directly in stage-1 - test_mla.py: validate returned lse vs lse_ref; `o` is now identical with and without -lse, so the gluon bh16bn64 gate is unified to just non-empty splits (ctx >= 256 // batch_size). README updated. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> * [mla_decode_gluon] simplify lse comments Trim verbose comments along the call chain (drop repeated "separate fp32 tensor" / "matching logsumexp" prose, use "split" not "K"). Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
* fix multithread_reduce_max_dpp * fix
* CI: add fallback Triton wheel download * CI: retry Aiter artifact downloads * Revert "CI: retry Aiter artifact downloads" This reverts commit 48d5e01.
…m#3546) * update * new grid layout for triton * black format * add upcast_operand option
…P2 (ROCm#3287) Signed-off-by: Xavier Aguilar <Xavier.AguilarFruto@amd.com>
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json (1)
1-8: 💤 Low valueNote: Config is identical to gfx942.
The gfx950 config is byte-for-byte identical to gfx942. This is acceptable if the optimal tuning parameters are the same for both architectures, but worth noting in case gfx950-specific tuning is planned for the future.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json` around lines 1 - 8, The file gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json is byte-for-byte identical to the gfx942 config; to make this explicit and avoid confusion later, add a clear metadata entry or comment inside the JSON indicating it's based on or copied from gfx942 (for example a top-level "based_on": "gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json" or "note": "identical to gfx942") and keep all existing entries (e.g., the keys like "B=2,K=4096,N=1024,M=1" and their BLOCK_M/BLOCK_N/BLOCK_K/split_k values) unchanged so future reviewers know the duplication is intentional.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py`:
- Around line 196-201: The call to per_token_group_quant_kernel uses a hardcoded
_FP8_E4M3_MAX (448.0) even when _default_fp8_dtype() selects float8_e5m2,
causing incorrect scaling; update per_token_group_quant() to compute fp8_max
from the selected fp8_dtype (e.g., map torch.float8_e4m3fnuz -> _FP8_E4M3_MAX
and torch.float8_e5m2 -> _FP8_E5M2_MAX (57344.0)) and pass that computed fp8_max
into the per_token_group_quant_kernel invocation (replace the literal
_FP8_E4M3_MAX), ensuring the kernel uses the correct dynamic range for both
scale and clamping.
---
Nitpick comments:
In
`@aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json`:
- Around line 1-8: The file gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json is
byte-for-byte identical to the gfx942 config; to make this explicit and avoid
confusion later, add a clear metadata entry or comment inside the JSON
indicating it's based on or copied from gfx942 (for example a top-level
"based_on": "gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json" or "note":
"identical to gfx942") and keep all existing entries (e.g., the keys like
"B=2,K=4096,N=1024,M=1" and their BLOCK_M/BLOCK_N/BLOCK_K/split_k values)
unchanged so future reviewers know the duplication is intentional.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 6690b32f-001a-4dfc-9ed0-7e8cade445b5
📒 Files selected for processing (7)
aiter/ops/triton/__init__.pyaiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyaiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.jsonaiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.jsonaiter/ops/triton/gemm/batched/__init__.pyaiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyop_tests/test_batched_gemm_smallB.py
| per_token_group_quant_kernel[(B_dim * M_dim, n_groups)]( | ||
| X, X_q, scale, | ||
| M_dim, B_dim, K, group_size, n_groups, | ||
| fp8_max=_FP8_E4M3_MAX, | ||
| TRANSPOSE_BM=transpose_bm, | ||
| ) |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify FP8 dtype max values
python3 << 'EOF'
import torch
print("Checking FP8 max values...")
print(f"float8_e4m3fnuz: {torch.finfo(torch.float8_e4m3fnuz).max}")
print(f"float8_e5m2: {torch.finfo(torch.float8_e5m2).max}")
EOFRepository: buddywhitman/aiter
Length of output: 180
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate definitions/usages
rg -n "_FP8_E4M3_MAX|fp8_max=_FP8_E4M3_MAX|fp8_max=" aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
# Find other FP8 max helpers/constants in repo
rg -n "_FP8_MAX|57344\.0|448\.0" aiter | head -n 50
# Also find any helper that maps fp8_dtype -> max
rg -n "default_fp8|_default_fp8_dtype|get_fp8_max|finfo\(torch\.float8" aiter | head -n 50Repository: buddywhitman/aiter
Length of output: 3867
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Inspect the target file around FP8 dtype selection and the fp8_max usage
sed -n '1,120p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n
sed -n '160,240p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n
echo "-----"
# Inspect the corresponding _triton_kernels implementation (may have correct fp8_max logic)
sed -n '1,120p' aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -n
sed -n '160,240p' aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -nRepository: buddywhitman/aiter
Length of output: 15426
🏁 Script executed:
#!/bin/bash
set -euo pipefail
sed -n '1,120p' aiter/ops/triton/_triton_kernels/moe/quant_moe.py | cat -nRepository: buddywhitman/aiter
Length of output: 4756
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n "_FP8_MAX|57344\.0|448\.0 if .*_AMD_FP8|_AMD_FP8.*57344" aiter | head -n 50Repository: buddywhitman/aiter
Length of output: 276
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n "_AMD_FP8|_FP8_MAX|57344\.0" aiter | head -n 100
rg -n "_FP8_MAX" . | head -n 100Repository: buddywhitman/aiter
Length of output: 2002
🏁 Script executed:
#!/bin/bash
set -euo pipefail
sed -n '1,130p' op_tests/test_batched_gemm_smallB.py | cat -nRepository: buddywhitman/aiter
Length of output: 5597
🏁 Script executed:
#!/bin/bash
set -euo pipefail
rg -n "per_token_group_quant\\b" aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyRepository: buddywhitman/aiter
Length of output: 139
🏁 Script executed:
#!/bin/bash
set -euo pipefail
sed -n '140,320p' aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py | cat -nRepository: buddywhitman/aiter
Length of output: 5141
Fix fp8_max to match selected FP8 dtype (float8_e5m2 must use 57344.0, not 448.0)
aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py selects fp8_dtype via _default_fp8_dtype() (AMD: torch.float8_e4m3fnuz, else torch.float8_e5m2), but per_token_group_quant() always launches per_token_group_quant_kernel with fp8_max=_FP8_E4M3_MAX where _FP8_E4M3_MAX = 448.0. The Triton kernel uses fp8_max for both scale = abs_max / fp8_max and clamping, so float8_e5m2 gets a severely underutilized dynamic range. The reference test op_tests/test_batched_gemm_smallB.py uses _FP8_MAX = 448.0 if _AMD_FP8 else 57344.0 consistent with float8_e4m3fnuz vs float8_e5m2.
🐛 Proposed fix
Replace the hardcoded max with a dtype→max helper and pass the computed value into the kernel:
-_FP8_E4M3_MAX = 448.0
+def _get_fp8_max(dtype: torch.dtype) -> float:
+ """Return the maximum representable value for the given FP8 dtype."""
+ if dtype == torch.float8_e4m3fnuz or dtype == torch.float8_e4m3fn:
+ return 448.0
+ elif dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz:
+ return 57344.0
+ else:
+ raise ValueError(f"Unsupported FP8 dtype: {dtype}")+ if fp8_dtype is None:
+ fp8_dtype = _default_fp8_dtype()
+ fp8_max_val = _get_fp8_max(fp8_dtype)
+
per_token_group_quant_kernel[(B_dim * M_dim, n_groups)](
X, X_q, scale,
M_dim, B_dim, K, group_size, n_groups,
- fp8_max=_FP8_E4M3_MAX,
+ fp8_max=fp8_max_val,
TRANSPOSE_BM=transpose_bm,
)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py` around
lines 196 - 201, The call to per_token_group_quant_kernel uses a hardcoded
_FP8_E4M3_MAX (448.0) even when _default_fp8_dtype() selects float8_e5m2,
causing incorrect scaling; update per_token_group_quant() to compute fp8_max
from the selected fp8_dtype (e.g., map torch.float8_e4m3fnuz -> _FP8_E4M3_MAX
and torch.float8_e5m2 -> _FP8_E5M2_MAX (57344.0)) and pass that computed fp8_max
into the per_token_group_quant_kernel invocation (replace the literal
_FP8_E4M3_MAX), ensuring the kernel uses the correct dynamic range for both
scale and clamping.
) * grouped topk kernel with sigmoid for deepseek r1 integration * cleanup after rebase * restructured grouped topk kernel * moved grouped_topk to regular topk files * fix testing * refactor routing_a8w4 name to unify dsv4 and grouped topk + sigmoid * ruff * black formatting * shift routing to unified routing function * ruff * comment fixes * black * verified triton kernel is functionally the same as the one needed in atom. small change to import in routing * review adjustments * ruff * ruff * small change for routing without fused shared experts in topk score_mode branch * change to routing to not support fused shared experts
…OCm#3572) * fix(ci): build multi-Python wheels and publish versioned S3 manifest The aiter-test workflow builds a single wheel using rocm/pytorch:latest (Python 3.12), but downstream consumers like ATOM run Python 3.10. Installing a cp312 wheel into a cp310 environment fails. Add a build matrix (py3.10 + py3.12) to build_aiter_wheels and split the S3 upload into a dedicated upload_s3_manifest job that collects both wheels and generates a versioned latest.json manifest with a per-Python-tag `wheels` field. The legacy top-level wheel_name/wheel_url fields are preserved for backward compatibility. * chore(ci): bump py3.10 build image from rocm7.2.2 to rocm7.2.4 * fix(ci): skip py3.10 wheel build on non-main branches
* [module_aiter_operator] refactor * update
ROCm#3501) * [feat](pa_mqa_logits) add gfx1250 (RDNA/WMMA wave32) support to deepgemm fp8 paged MQA logits * [pa_mqa_logits] strip redundant gfx1250 comments --------- Co-authored-by: Felix Li <felix.li@amd.com>
* feat(gemm): add FlyDSL MXFP8 dense GEMM for gfx1250
Add gemm_a8w8_mxscale, the OCP MXFP8 (E4M3 activation/weight, 1x32 E8M0
scales) dense GEMM for gfx1250, as an a8w8-family variant in gemm_op_a8w8.py.
The op is gfx1250-only and auto-routes via the a8w8_mxscale tuned CSV, falling
back to a safe default kernel for untuned shapes.
- backend: aiter/ops/flydsl/mxscale_gemm.py (launch + kernel-name encode/parse)
with WMMA host layout helpers in mxscale_layout.py; vendored gfx1250 kernel
under ops/flydsl/kernels/ (gemm_fp8fp4_gfx1250, gemm_common_gfx1250,
pipeline_utils). Weight and scales are preshuffled into WMMA layout inside the
backend, so callers pass row-major / unshuffled tensors.
- routing/config: AITER_CONFIG_GEMM_MXSCALE + a8w8_mxscale_{tuned,untuned}_gemm.csv.
- tuner: FlyDSL-only driver + candidate catalogue under ops/flydsl/gemm_tune/.
- AOT: mxscale branch in aot/flydsl/gemm.py (arch taken from the kernel name).
- tests: op correctness (op_tests/test_gemm_mxscale.py) and CPU-only layout
checks (op_tests/test_mxscale_layout.py).
MX A8W4 is not integrated this round (vendored kernel supports it, but no public
op / routing / tuner path).
* fix(gemm): correct FlyDSL MXFP8 gfx1250 split_k
- split_k cross-workgroup accumulation now uses a device (agent) scoped
atomicrmw fadd, so partial sums are coherent across CUs (the buffer
atomic with default scope lost updates -> non-deterministic results)
- accumulate split_k>1 partials in an f32 scratch buffer and cast once,
so bf16/f16 output keeps single-shot precision instead of degrading
with the split count
- disable workgroup clusters in the tuner grid; keep split_k as a tuned knob
* feat(gemm): hoist MXScale B-weight preshuffle out of the per-call path
flydsl_mxscale_gemm re-padded and 16x16-preshuffled the B weight (and its
E8M0 scale) on every call -- a multi-MB layout pass that dominated latency
for memory-bound shapes (e.g. M=1 GEMV).
- split the per-call pad+preshuffle into weight-side / activation-side
helpers (preshuffle_mxscale_weight / preshuffle_mxscale_activation),
bit-identical to the old combined pad_mxscale_inputs path
- add public shuffle_weight_mxscale: pre-shuffle the weight once at load
time, marked is_shuffled (mirrors shuffle_weight / bpreshuffle convention);
the backend detects the marker, recovers N from the stashed shuffle key,
and reuses the tensors verbatim instead of re-shuffling
- mxscale tuner caches the shuffle per kernel_name so the weight layout pass
is excluded from measured kernel latency
- drop now-unused pad_mxscale_inputs
* fix(gemm): skip MXScale test on non-gfx1250 and trim split_k options
- test_gemm_mxscale.py: exit 0 (skip) instead of failing the CI job when run
as a standalone script on non-gfx1250 hardware; fall back to df.to_string
when tabulate is unavailable.
- flydsl_gemm_mxscale_gfx1250_common.py: prune split_k tuning options.
* feat(gemm): integrate MXScale GEMM into bpreshuffle path
* debug(mxscale): all gemm_bpreshuffle fallback to flydsl
* refactor(gemm): MXScale MXFP8 via bpreshuffle tuned-CSV dispatch (no K-pad)
B uses the generic shuffle_weight 16x16 layout; K is never padded (require
%128, raise otherwise); dispatch is driven purely by tuned-CSV flydsl_mxscale_*
kernel names. Drops the K-pad / weight marker / fallback machinery; adds an
early 1-byte dtype guard on A/B/scales.
TODO(tuned): shapes without a tuned-CSV entry never route to MXScale (fall back
to CK). Need tuned configs for all target shapes; K padding + untuned fallback
to be planned later.
* debug(gemm): fall back untuned bpreshuffle to flydsl MXScale instead of CK
Temporary: route untuned (config=None) shapes through the default MXScale kernel
so MXFP8 can be exercised without a tuned CSV. Revert once tuned configs land.
* [gfx1250][FlyDSL] Add WMMA a8w8 bpreshuffle FP8 GEMM
Route gemm_a8w8_bpreshuffle on gfx1250 to the FlyDSL gfx1250 WMMA kernel
(ptpc scale mode) via tuned-CSV kernelName prefix flydsl_bpreshuffle_wmma_,
with a dedicated backend and a standalone WMMA auto-tuner. Removes the
earlier MXScale op / layout / tuner.
* [debug] all gfx1250 bpreshuffle untuned kernel fallback to flydsl
* fix(gemm): fp32 atomic accumulation for FlyDSL gfx1250 split-k
split-k sums per-chunk partials via global atomic add. With a bf16/f16
output buffer the kernel rounded every partial to an 8-bit mantissa
before adding, so the error compounded with the chunk count (sk2 warned,
sk4+ failed the tuner's atol/rtol=0.01 check).
Accumulate split-k partials into an fp32 scratch buffer and narrow once
at the end, matching the non-split (sk1) result. All paths (dispatch,
untuned fallback, auto-tune) converge on run_preshuffle_gemm_a8_gfx1250,
so the single backend change covers them. Tighten test_split_k to pin
the split result to sk1 (the old cosine>0.99 check was too loose).
* Revert "[debug] all gfx1250 bpreshuffle untuned kernel fallback to flydsl"
This reverts commit 034e293.
---------
Co-authored-by: Felix Li <felix.li@amd.com>
…for small-B regime Fixes ROCm#3000. Unblocks ROCm/ATOM#676. Existing FP8 batched GEMM kernels dispatch only 16 work-groups on a 304-CU GPU (B=2, M=1: 5% CU utilization), losing 1.5x-7x to BF16 einsum. Three fixes: 1. Grid collapse + split-K=8: (B*M_tiles, N_tiles, split_k) grid -> 128 work-groups at B=2, M=1 -> 42% CU utilization 2. Fused bf16 output for split_k=1: eliminates reduction kernel overhead 3. Per-128-block W-scales loaded inline (no dequant/requant precision loss) New files: - aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py - aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py - aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json - aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json - op_tests/test_batched_gemm_smallB.py Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…or ATOM#773 Two additions requested by ROCm/ATOM#773 closing comment: 'the proper fix lives in aiter — extend the per-group quant path to accept a permuted (G,S,R) input layout so the wo_a einsum can be replaced by a fused quant + grouped GEMM without requiring a non-contiguous bmm' 1. per_token_group_quant kernel (TRANSPOSE_BM=True): Accepts X in (M, B, K) = (tokens, n_local_groups, d_per_group) layout directly, matching the natural output of DeepSeek V4 attention. No .transpose(0,1) needed; avoids the non-contiguous view that violated hipBLAS strided-batched GEMM contract and caused the GSM8K regression. 2. batched_gemm_a8w8_smallB_blockscale_bf16 entry-point: Accepts raw BF16 activations, runs fused per-token-group quant (Triton kernel) then the split-K GEMM — one call, no separate act_quant launch. All 18 tests pass (NVIDIA RTX 3070Ti, sm_86, CUDA 13.3, Triton 3.2.0). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
7b91a13 to
208ff0e
Compare
There was a problem hiding this comment.
Actionable comments posted: 20
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
csrc/kernels/generate_binaryop.py (1)
18-34:⚠️ Potential issue | 🟠 Major | ⚡ Quick winKeep the generated dtype maps in sync.
DATA_TYPE_MAPstill advertisesfloat64andbool, andcmdGenFunc()now forwards rawtorch.dtypenames into this generator, butAITER_DTYPE_MAPno longer has matching entries. The first binary-op JIT on one of those dtypes will blow up with aKeyErrorduring codegen before the Python fallback gets a chance to run.Either add the missing
AITER_DTYPE_*mappings or explicitly short-circuit unsupported dtypes before invoking codegen.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/kernels/generate_binaryop.py` around lines 18 - 34, The DATA_TYPE_MAP and AITER_DTYPE_MAP are out of sync causing a KeyError during codegen; update AITER_DTYPE_MAP to include entries for the dtypes referenced in DATA_TYPE_MAP (at minimum "float64" and "bool") with the correct AITER_DTYPE_* constants, or alter cmdGenFunc() to validate/short-circuit unsupported torch.dtype names before calling the generator so codegen is never invoked for dtypes missing in AITER_DTYPE_MAP; reference DATA_TYPE_MAP, AITER_DTYPE_MAP and cmdGenFunc() when making the change.op_tests/test_mla.py (1)
667-681:⚠️ Potential issue | 🟡 Minor | ⚡ Quick winDon’t skip valid
bh16bn64cases whenbatch_size > 256.The wrapper uses
NUM_KV_SPLITS = max(1, 256 // batch_size), so larger batches are still valid withNUM_KV_SPLITS == 1. This gate hard-caps at 256 and silently drops those configurations from the Gluon test/benchmark.Suggested fix
- and 1 <= batch_size <= 256 - and ctx_lens >= (256 // batch_size) + and batch_size >= 1 + and ctx_lens >= max(1, 256 // batch_size)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@op_tests/test_mla.py` around lines 667 - 681, The gate incorrectly excludes batch_size > 256; update the condition in the Gluon bh16bn64 decode check (the if using get_gfx(), dtype, kvtype, nhead, decode_qlen, v_head_dim, qk_head_dim, page_size, batch_size, ctx_lens) to mirror NUM_KV_SPLITS semantics by removing the hard cap and using ctx_lens >= max(1, 256 // batch_size) with a simple batch_size >= 1 check (i.e., replace "1 <= batch_size <= 256 and ctx_lens >= (256 // batch_size)" with "batch_size >= 1 and ctx_lens >= max(1, 256 // batch_size)").op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py (1)
55-62:⚠️ Potential issue | 🔴 Critical | ⚡ Quick winRemove unsupported
use_aot=Trueargument.Line 58 passes
use_aot=Truetogemm_afp4wfp4_preshuffle, but the function signature (defined at lines 416-425 inaiter/ops/triton/gemm/basic/gemm_afp4wfp4.py) does not accept ause_aotparameter. This will raise aTypeErrorat runtime.🐛 Proposed fix to remove the unsupported parameter
- ms = triton.testing.do_bench( - lambda: gemm_afp4wfp4_preshuffle( - x, w, x_scale, w_scale, c_dtype, y, use_aot=True - ), + ms = triton.testing.do_bench( + lambda: gemm_afp4wfp4_preshuffle( + x, w, x_scale, w_scale, c_dtype, y + ), warmup=25, rep=100, )🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py` around lines 55 - 62, The bench call is passing an unsupported keyword use_aot=True to gemm_afp4wfp4_preshuffle which will raise a TypeError; remove the use_aot=True argument from the lambda passed into triton.testing.do_bench so the call matches the gemm_afp4wfp4_preshuffle signature (keep all other args: x, w, x_scale, w_scale, c_dtype, y) and ensure the lambda still returns the function invocation for benchmarking.aiter/ops/triton/moe/moe_op_gemm_a8w4.py (1)
364-420:⚠️ Potential issue | 🟠 Major | ⚡ Quick winGuard
out_mx_quantoff on gfx1250 until the Gluon kernels actually emit scales.The new gfx1250 launch path never passes a
YMxScale/HAS_MX_OUTequivalent into either Gluon kernel, soy_scaleis returned uninitialized. On top of that,n_outis derived afterwis transposed, so the scale shape is computed fromKinstead of the output widthN.Suggested fix
use_gluon = get_arch() == "gfx1250" + if use_gluon and out_mx_quant: + raise NotImplementedError( + "out_mx_quant is not implemented for gfx1250 Gluon kernels yet" + ) + assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp"Also applies to: 435-538
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@aiter/ops/triton/moe/moe_op_gemm_a8w4.py` around lines 364 - 420, Disable the out_mx_quant path when using the Gluon kernels on gfx1250 (i.e. when use_gluon is true and the device/gfx id indicates gfx1250) so we don't return an uninitialized y_scale; specifically, gate the existing out_mx_quant logic (the out_dtype assignment, the allocate_output call that depends on reduction_n_matmul, and the y_scale allocation) behind a check that clears out_mx_quant for Gluon+gfx1250, and fix the n_out calculation to derive the post-swiglu output width from the actual output tensor shape (use y.shape[-1] or padded_N divided by reduction_n_matmul) instead of w.shape[-1] so the scale buffer size is computed from N not K; refer to symbols out_mx_quant, use_gluon, w, y, y_scale, allocate_output, and n_out when making the change.
🧹 Nitpick comments (2)
csrc/include/hip_reduce.h (1)
291-292: Tightenmultithread_reduce_max_dpp’s non-__GFX9__contract aroundthread_num==64The
thread_num==64implementation incsrc/include/hip_reduce.his compiled out for non-__GFX9__, but current callers don’t instantiate<64>there:topk_gating_kernels.cupassesWARP_SIZE(32 when not__GFX9__), andfused_qk_rmsnorm_group_quant.cuenforcesgroup_size <= WARP_SIZE * thread_data_size, so non-__GFX9__dispatch can’t produceReduceThreadSize==64. Still, tightening thestatic_asserttothread_num <= 32for!__GFX9__(or an explicit compile-time error forthread_num==64) would better align the template contract with actual support.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/include/hip_reduce.h` around lines 291 - 292, The static_assert in multithread_reduce_max_dpp currently allows thread_num up to 64 unconditionally; update the contract so non-__GFX9__ builds disallow 64: modify the check around the static_assert in multithread_reduce_max_dpp (symbol: multithread_reduce_max_dpp / the static_assert line) to be conditional on __GFX9__ — for __GFX9__ keep the existing <=64 power-of-two check, otherwise tighten to thread_num >=1 && thread_num <=32 && power-of-two (or explicitly static_assert(false) when thread_num==64 on non-__GFX9__), so callers cannot instantiate a 64-thread variant on targets that don't support it.op_tests/triton_tests/fusions/test_fused_kv_cache.py (1)
219-240: ⚡ Quick winAdd a partial-prefix decode-output case here.
This test always uses
num_decode_toks_for_zeros = T, so it never exercises the case where only the first few batches should writedecode_q_pe_out/q_nope_zeros_out. The new pid transpose in the fused kernels only breaks in that partial-prefix setup, so a small case likeT=8, num_decode_toks_for_zeros=3would catch the current OOB regression immediately.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@op_tests/triton_tests/fusions/test_fused_kv_cache.py` around lines 219 - 240, The test currently sets num_decode_toks_for_zeros = T so it never exercises the partial-prefix decode-output path; modify the test around the call to fused_qk_rope_cat_and_cache_mla to add at least one scenario where num_decode_toks_for_zeros < T (for example T=8 and num_decode_toks_for_zeros=3) so the function fused_qk_rope_cat_and_cache_mla (and outputs decode_q_pe_out / q_nope_zeros_out) is exercised for partial-prefix writes; ensure the same setup (slot_mapping, positions, shuffled_kv_cache, upcast_operand) is reused and only num_decode_toks_for_zeros is changed so the new case triggers the pid-transpose OOB regression.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In @.github/scripts/download_triton_wheel.sh:
- Around line 13-19: The pipeline assignment to ROCM_VERSION can fail under set
-euo pipefail, making the else fallback unreachable; modify the command that
sets ROCM_VERSION (the line assigning ROCM_VERSION using dpkg -l rocm-core |
awk) to neutralize failures (e.g., append || true) so the command never returns
non-zero and the subsequent if [[ -n "${ROCM_VERSION}" ]]; then ... else ...
block can run and correctly set ROCM_MAJOR_MINOR and TRITON_INDEX_URL or fall
back to the default ROCM version message.
In `@aiter/fused_moe.py`:
- Around line 1807-1811: The current branch that sets
extra_stage2_args["expert_mask"] and extra_stage2_args["topk_ids"] for the
_flydsl_stage2_wrapper incorrectly forwards localized ids (local_topk_ids) into
the reduction kernel; change it so the reduction path receives the original
global topk ids while kernels that need split-k localized ids still get
local_topk_ids. Concretely, in the if block guarded by stage2_func is
_flydsl_stage2_wrapper and expert_mask is not None, add/assign a distinct key
(e.g., "global_topk_ids") with the original topk_ids and leave "topk_ids"
pointing to local_topk_ids only for places that use localized routing; update
any call-sites inside fused_moe_2stages / fused_moe_ that expect global vs local
ids to read the appropriate key so the new reduce kernel uses global_topk_ids
while other kernels continue to use topk_ids (local_topk_ids).
- Around line 1420-1426: The code currently enables the FlyDSL force override by
default because _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") ==
"1"; change this to be opt-in by default (e.g. use
os.environ.get("AITER_FLYDSL_FORCE", "0") == "1" or check presence with
os.environ.get("AITER_FLYDSL_FORCE") == "1") so _flydsl_force is False when the
variable is unset; update the surrounding use_mxfp4_flydsl condition (symbols:
_flydsl_force and use_mxfp4_flydsl) accordingly so FlyDSL is only forced when
the env var is explicitly set to "1".
In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py`:
- Around line 87-110: Add device and output-shape validation before launching
the kernel: verify XQ, WQ, x_scale, w_scale, and Out are on the same device and
that Out.shape == (M, N); also ensure Out is contiguous (or call Out =
Out.contiguous()) when padded_m == M to avoid out-of-bounds writes. Implement
these checks alongside the existing dimensional/element-size checks in the
gfx1250 path (the block validating XQ/WQ, M, N, K, tile_n/tile_k and Out.dtype)
and duplicate the same guard logic in the other similar validation blocks
present later in this file (the other gfx1250/validation sections).
- Around line 119-123: Compute num_k_tiles as before, then do not coerce nb up
to 2; set nb = min(int(num_buffers), num_k_tiles) and if nb is not in
_SUPPORTED_NUM_BUFFERS pick the largest supported buffer count <= nb, but if
that yields a value < 1 or yields a supported nb that is greater than
num_k_tiles (e.g. forcing 2 when num_k_tiles == 1) raise an explicit exception
(ValueError) rejecting the unsupported pipeline depth instead of silently
forcing nb=2; update the logic around num_k_tiles, nb and _SUPPORTED_NUM_BUFFERS
and ensure callers that rely on kernel_fits_shape(...) will see the error when a
pipeline depth cannot be satisfied.
In `@aiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.py`:
- Around line 116-126: The CSV always records splitK as 0 because the info tuple
third slot is being hardcoded; fix by passing the real split_k when constructing
the info tuple in get_wmma_tune_task() (use ki.split_k instead of 0) and ensure
result_to_df() (and the similar block around lines 148-179) reads that third
element into splitK and writes it into the DataFrame (falling back to 0 only if
the info tuple lacks the slot). Update all places that build the info tuple so
the third element carries the actual ki.split_k, and keep result_to_df's
unpacking (keys, kernelId, splitK, kernelName, libtype) to persist the real
splitK value.
In `@aiter/ops/flydsl/kernels/moe_gemm_2stage.py`:
- Around line 3644-3647: The module_name currently ignores num_experts and the
new public ABI change; update construction of module_name (the variable
module_name) to also encode num_experts and an ABI tag so cached binaries are
unique to the expert count and signature change — include num_experts and a
short ABI identifier (e.g. "abi_expertmask_topkids" or similar) alongside
existing parts built from use_mask, dtype_str, topk and model_dim so the name
reflects the new expert_mask descriptor and the (expert_mask, topk_ids) ABI.
In `@aiter/ops/shuffle.py`:
- Around line 29-37: The 3D branch always treats w as (E, K, N) then transposes
to (E, N, K), which breaks callers that pass (E, N, K); update the branch in
shuffle.py to first detect the layout (e.g., inspect which axis lengths are
divisible by 32 vs 16 or use an explicit flag if available), then apply
divisibility checks against the correct axes (K must be divisible by 32 and N by
16) and only transpose when the input is in (E, K, N) form; finally perform the
view/permutations using the correct dimension ordering so the output layout is
consistent for both input layouts (refer to the existing variables and
operations: w, w.ndim == 3 branch, the transpose(), view(...16, K // 32, 2, 16),
permute(...), and final transpose()).
In `@aiter/ops/triton/_gluon_kernels/gfx1250/fusions/fused_kv_cache.py`:
- Around line 394-399: The transpose of pid (pid_hq = pid // B; pid_b = pid % B)
broke the original batch-based gate; update the downstream condition that
currently uses pid (e.g., the `if pid < num_decode_toks_for_zeros * QH` / OUTPUT
block that writes to decode_q_pe_out and q_nope_zeros_out) to also require the
batch predicate pid_b < num_decode_toks_for_zeros so stores are only issued for
batches inside the decode-output prefix; locate the check that previously used
pid and add `and pid_b < num_decode_toks_for_zeros` (or equivalent) to gate the
tail path that writes to decode_q_pe_out / q_nope_zeros_out.
In `@aiter/ops/triton/_gluon_kernels/gfx1250/gemm/basic/gemm_mxfp4.py`:
- Around line 279-376: The pipeline assumes k_tiles >= NUM_BUFFERS and will
overcompute when that's false; update the GEMM tile loop logic (variables
k_tiles, NUM_BUFFERS, main_iters, the prologue/epilogue sections and final WMMA)
to guard that case: if k_tiles < NUM_BUFFERS, set a local pipeline_depth =
min(NUM_BUFFERS, k_tiles) (or short-circuit into a simple non-pipelined path)
and use pipeline_depth in place of NUM_BUFFERS for prologue/epilogue waits,
buffer indexing (load_idx/compute_idx), and main_iters so you only issue/compute
exactly k_tiles tiles; ensure async_load/async_wait counts and the final WMMA
use the adjusted pipeline_depth to avoid draining nonexistent tiles.
In `@aiter/ops/triton/_triton_kernels/fusions/fused_kv_cache.py`:
- Around line 310-315: The batch vs head flattening was changed (pid_hq = pid //
B; pid_b = pid % B) which breaks the decode-output bounds check; restore correct
batch-major mapping or ensure bounds use pid_b. Either revert to the original
mapping (pid_b = pid // QH; pid_hq = pid % QH) in fused_kv_cache.py so pid_b
represents the batch index, or keep the current mapping but change the
decode/OUTPUT block checks that reference pid/flattened pid to use pid_b when
testing against num_decode_toks_for_zeros so writes to decode_q_pe_out and
q_nope_zeros_out are clamped by batch index.
In `@aiter/ops/triton/_triton_kernels/moe/moe_routing/topk.py`:
- Around line 538-551: The bias-mode top-2 computation currently suppresses the
first max by value-equality (scores_for_choice == gm1_per_e), which removes all
tied-max experts; instead compute the per-group argmax index and suppress by
index. Replace the equality-based suppression around gm1_per_e with an index
mask derived from tl.argmax over group_max1 (or the appropriate per-group argmax
of scores_for_choice using gid_eq) so only the chosen max position is masked
(use that argmax mask to create sup_3d/expanded2), preserving other tied maxima
and then compute group_max2 as before.
In `@aiter/ops/triton/attention/pa_mqa_logits.py`:
- Around line 459-460: The code currently mutates the caller-visible WavePerEU
when forcing gfx1250/non-preshuffle to compile with WavePerEU=1; instead, do not
overwrite the API parameter WavePerEU—instead introduce a new local variable
(e.g., effective_wave_per_eu) set to WavePerEU by default and set
effective_wave_per_eu = 1 only when get_gfx() == "gfx1250" and not Preshuffle;
then use effective_wave_per_eu for compilation decisions and any artifact naming
logic (the places that previously read WavePerEU for AOT artifact names and the
compilation path), leaving the original WavePerEU value untouched for callers
and tests like op_tests/op_benchmarks/triton/bench_deepgemm_attention.py.
In `@aiter/ops/triton/gluon/README.md`:
- Around line 87-89: The README for the bh16bn64 regime is ambiguous: document
the NUM_KV_SPLITS==1 fast path where the wrapper skips stage-2 reduction and
writes o directly (instead of running the two-stage reduce), and clarify how
return_lse behaves in that case; update the bh16bn64 paragraph to state that
when NUM_KV_SPLITS = 1 (e.g., batch_size >= 256) the implementation bypasses
stage-2 and mla_decode_gluon(...) returns (o, final_lse) only when
return_lse=True and the merged fp32 LSE is computed locally, otherwise the
wrapper writes o directly without a separate stage-2 reduction.
In `@aiter/ops/triton/moe/moe_routing/routing.py`:
- Around line 308-310: The small-token sort path calls
triton.next_power_of_2(num_tokens) unguarded and will crash when num_tokens ==
0; change the logic in the routing code that sets HIST_BLOCK_M and sort_fn (the
block using HIST_BLOCK_M = triton.next_power_of_2(num_tokens) and sort_fn =
sort_tokens_fused) to use a safe value such as safe_num_tokens = max(num_tokens,
1) (consistent with the fused score-mode/hash paths) and call
triton.next_power_of_2(safe_num_tokens) so empty local batches do not trigger an
error.
In `@aiter/ops/triton/moe/moe_routing/topk.py`:
- Around line 43-79: Validate expert_group mapping in grouped_topk before
launching _grouped_topk: ensure all expert_group values are in [0,
num_expert_group), compute per-group counts with torch.bincount (length
num_expert_group) and assert no empty groups, and additionally assert the
worst‑case selection of topk_group groups can provide k experts by checking that
the sum of the smallest topk_group counts >= k; raise a clear AssertionError
referencing expert_group/num_expert_group/topk_group/k when any check fails.
In `@csrc/kernels/binary_operator.cu`:
- Around line 24-41: The in-place native wrappers aiter_add_, aiter_sub_,
aiter_mul_, and aiter_div_ must reject non-contiguous destinations: before
calling binary_op_dispatch("...", input, other, input) check
input.is_contiguous() and return false when it is not; this prevents
binary_op_dispatch (and the transpose/broadcast kernels it may select) from
writing dense offsets into the wrong locations and lets the Python layer fall
back to the correct strided in-place implementation.
In `@op_tests/op_benchmarks/triton/bench_deepgemm_attention.py`:
- Around line 26-35: kv_cache_cast_to_fp8 currently hardcodes the FP8 max (240)
when computing sf; instead, derive the FP8 max from the selected fp8_dtype and
use that to compute sf. Update kv_cache_cast_to_fp8 (and where fp8_dtype is set
via get_fp8_e4m3_dtype) to compute fp8_max = torch.finfo(fp8_dtype).max (or
equivalent API for the chosen FP8 dtype) and set sf = x_amax / fp8_max before
scaling; keep the existing fallback when fp8_dtype is None and preserve
x.shape/asserts and subsequent casting to fp8_dtype.
In `@op_tests/test_mla.py`:
- Around line 582-584: The current early skip ("if return_lse: pass") prevents
the bh64 Gluon path from ever seeing return_lse, so ensure the Gluon validation
receives the flag: remove the no-op skip and forward the return_lse flag into
the Gluon test/run invocation (eg. pass return_lse into test_absorb_decode_gluon
or the Gluon runner call that handles bh64) so the Gluon branch evaluates the
merged LSE contract; update the call site(s) where test_absorb_decode_gluon or
the Gluon run is invoked to accept and propagate the return_lse parameter.
In `@op_tests/test_moe_ep.py`:
- Around line 482-509: The test for quant_label "a8w4_mxfp4" doesn't force the
fp8 runtime path, so set the fp8 runtime bound before calling run_perftest
(e.g., export or os.environ["AITER_BF16_FP8_MOE_BOUND"]="1" or a value <= token)
when quant_label == "a8w4_mxfp4" so the code path exercised by
run_perftest(fused_moe, ...) uses the fp8 EP; update the branch that sets act
and gate_mode in op_tests/test_moe_ep.py to set that environment variable (or
otherwise invoke the runtime-mode switch) prior to the run_perftest call.
---
Outside diff comments:
In `@aiter/ops/triton/moe/moe_op_gemm_a8w4.py`:
- Around line 364-420: Disable the out_mx_quant path when using the Gluon
kernels on gfx1250 (i.e. when use_gluon is true and the device/gfx id indicates
gfx1250) so we don't return an uninitialized y_scale; specifically, gate the
existing out_mx_quant logic (the out_dtype assignment, the allocate_output call
that depends on reduction_n_matmul, and the y_scale allocation) behind a check
that clears out_mx_quant for Gluon+gfx1250, and fix the n_out calculation to
derive the post-swiglu output width from the actual output tensor shape (use
y.shape[-1] or padded_N divided by reduction_n_matmul) instead of w.shape[-1] so
the scale buffer size is computed from N not K; refer to symbols out_mx_quant,
use_gluon, w, y, y_scale, allocate_output, and n_out when making the change.
In `@csrc/kernels/generate_binaryop.py`:
- Around line 18-34: The DATA_TYPE_MAP and AITER_DTYPE_MAP are out of sync
causing a KeyError during codegen; update AITER_DTYPE_MAP to include entries for
the dtypes referenced in DATA_TYPE_MAP (at minimum "float64" and "bool") with
the correct AITER_DTYPE_* constants, or alter cmdGenFunc() to
validate/short-circuit unsupported torch.dtype names before calling the
generator so codegen is never invoked for dtypes missing in AITER_DTYPE_MAP;
reference DATA_TYPE_MAP, AITER_DTYPE_MAP and cmdGenFunc() when making the
change.
In `@op_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.py`:
- Around line 55-62: The bench call is passing an unsupported keyword
use_aot=True to gemm_afp4wfp4_preshuffle which will raise a TypeError; remove
the use_aot=True argument from the lambda passed into triton.testing.do_bench so
the call matches the gemm_afp4wfp4_preshuffle signature (keep all other args: x,
w, x_scale, w_scale, c_dtype, y) and ensure the lambda still returns the
function invocation for benchmarking.
In `@op_tests/test_mla.py`:
- Around line 667-681: The gate incorrectly excludes batch_size > 256; update
the condition in the Gluon bh16bn64 decode check (the if using get_gfx(), dtype,
kvtype, nhead, decode_qlen, v_head_dim, qk_head_dim, page_size, batch_size,
ctx_lens) to mirror NUM_KV_SPLITS semantics by removing the hard cap and using
ctx_lens >= max(1, 256 // batch_size) with a simple batch_size >= 1 check (i.e.,
replace "1 <= batch_size <= 256 and ctx_lens >= (256 // batch_size)" with
"batch_size >= 1 and ctx_lens >= max(1, 256 // batch_size)").
---
Nitpick comments:
In `@csrc/include/hip_reduce.h`:
- Around line 291-292: The static_assert in multithread_reduce_max_dpp currently
allows thread_num up to 64 unconditionally; update the contract so non-__GFX9__
builds disallow 64: modify the check around the static_assert in
multithread_reduce_max_dpp (symbol: multithread_reduce_max_dpp / the
static_assert line) to be conditional on __GFX9__ — for __GFX9__ keep the
existing <=64 power-of-two check, otherwise tighten to thread_num >=1 &&
thread_num <=32 && power-of-two (or explicitly static_assert(false) when
thread_num==64 on non-__GFX9__), so callers cannot instantiate a 64-thread
variant on targets that don't support it.
In `@op_tests/triton_tests/fusions/test_fused_kv_cache.py`:
- Around line 219-240: The test currently sets num_decode_toks_for_zeros = T so
it never exercises the partial-prefix decode-output path; modify the test around
the call to fused_qk_rope_cat_and_cache_mla to add at least one scenario where
num_decode_toks_for_zeros < T (for example T=8 and num_decode_toks_for_zeros=3)
so the function fused_qk_rope_cat_and_cache_mla (and outputs decode_q_pe_out /
q_nope_zeros_out) is exercised for partial-prefix writes; ensure the same setup
(slot_mapping, positions, shuffled_kv_cache, upcast_operand) is reused and only
num_decode_toks_for_zeros is changed so the new case triggers the pid-transpose
OOB regression.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: fcbd3bff-42b8-49d0-bb7d-3db8e7a84668
⛔ Files ignored due to path filters (8)
aiter/configs/model_configs/dsv3_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/dsv4_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/glm5_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/gptoss_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/kimi_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/kimik2_bf16_tuned_gemm.csvis excluded by!**/*.csvaiter/configs/model_configs/kimik2_fp4_tuned_fmoe.csvis excluded by!**/*.csvhsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_silu.csvis excluded by!**/*.csv
📒 Files selected for processing (69)
.github/scripts/download_triton_wheel.sh.github/scripts/split_tests.sh.github/workflows/aiter-release.yaml.github/workflows/aiter-test.yaml.github/workflows/amd-ci-job-monitor.yml.github/workflows/docs.yml.github/workflows/prepare-triton-wheel.yaml.github/workflows/test-whl.yaml.github/workflows/triton-test.yaml.github/workflows/vllm_benchmark.yamlaiter/fused_moe.pyaiter/ops/aiter_operator.pyaiter/ops/flydsl/bpreshuffle_gemm_gfx1250.pyaiter/ops/flydsl/gemm_kernels.pyaiter/ops/flydsl/gemm_tune/flydsl_gemm_a8w8_bpreshuffle_wmma_common.pyaiter/ops/flydsl/gemm_tune/gemm_a8w8_bpreshuffle_wmma_tune.pyaiter/ops/flydsl/kernels/gemm_common_gfx1250.pyaiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.pyaiter/ops/flydsl/kernels/moe_gemm_2stage.pyaiter/ops/flydsl/kernels/pipeline_utils.pyaiter/ops/flydsl/kernels/splitk_hgemm.pyaiter/ops/flydsl/moe_kernels.pyaiter/ops/gemm_op_a8w8.pyaiter/ops/opus/__init__.pyaiter/ops/shuffle.pyaiter/ops/triton/__init__.pyaiter/ops/triton/_gluon_kernels/gfx1250/fusions/fused_kv_cache.pyaiter/ops/triton/_gluon_kernels/gfx1250/gemm/basic/gemm_mxfp4.pyaiter/ops/triton/_gluon_kernels/gfx1250/moe/moe_op_gemm_a8w4.pyaiter/ops/triton/_triton_kernels/fusions/fused_kv_cache.pyaiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyaiter/ops/triton/_triton_kernels/moe/moe_routing/topk.pyaiter/ops/triton/attention/pa_mqa_logits.pyaiter/ops/triton/configs/gemm/gfx1250-GEMM-AFP4WFP4_PRESHUFFLED.jsonaiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.jsonaiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.jsonaiter/ops/triton/fusions/fused_kv_cache.pyaiter/ops/triton/gemm/basic/gemm_afp4wfp4.pyaiter/ops/triton/gemm/batched/__init__.pyaiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyaiter/ops/triton/gluon/README.mdaiter/ops/triton/gluon/gemm_afp4wfp4.pyaiter/ops/triton/gluon/mla_decode_gluon.pyaiter/ops/triton/gluon/pa_mqa_logits.pyaiter/ops/triton/moe/moe_op_gemm_a8w4.pyaiter/ops/triton/moe/moe_routing/routing.pyaiter/ops/triton/moe/moe_routing/topk.pycsrc/include/aiter_operator.hcsrc/include/binary_operator.cuhcsrc/include/hip_reduce.hcsrc/kernels/binary_operator.cucsrc/kernels/generate_binaryop.pycsrc/kernels/topk_gating_kernels.cucsrc/pybind/aiter_operator_pybind.cuhsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_novs_silu_1tg_ps_32x384.cohsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_novs_silu_32x384.cohsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_vs_silu_1tg_ps_32x384.cohsa/gfx950/fmoe/silu/fmoe_bf16_blockscaleFp8_g1u1_vs_silu_32x384.coop_tests/op_benchmarks/triton/bench_deepgemm_attention.pyop_tests/op_benchmarks/triton/bench_gemm_afp4wfp4.pyop_tests/test_batched_gemm_smallB.pyop_tests/test_gemm_a8w8_bpreshuffle_gfx1250.pyop_tests/test_mla.pyop_tests/test_moe_2stage.pyop_tests/test_moe_ep.pyop_tests/triton_tests/fusions/test_fused_kv_cache.pyop_tests/triton_tests/fusions/test_fused_routing_from_topk.pyop_tests/triton_tests/gemm/basic/test_gemm_afp4wfp4.pyop_tests/triton_tests/moe/test_moe_routing.py
💤 Files with no reviewable changes (2)
- op_tests/triton_tests/fusions/test_fused_routing_from_topk.py
- aiter/ops/opus/init.py
✅ Files skipped from review due to trivial changes (3)
- csrc/pybind/aiter_operator_pybind.cu
- aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
- aiter/ops/triton/gemm/batched/init.py
🚧 Files skipped from review as they are similar to previous changes (5)
- aiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json
- aiter/ops/triton/init.py
- op_tests/test_batched_gemm_smallB.py
- aiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
- aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.py
| ROCM_VERSION=$(dpkg -l rocm-core 2>/dev/null | awk '/^ii/{print $3}') | ||
| if [[ -n "${ROCM_VERSION}" ]]; then | ||
| ROCM_MAJOR_MINOR=$(echo "${ROCM_VERSION}" | cut -d. -f1,2) | ||
| TRITON_INDEX_URL="https://pypi.amd.com/triton/release_/rocm-${ROCM_MAJOR_MINOR}.0/simple/" | ||
| else | ||
| echo "rocm-core not found; using default ROCm version ${TRITON_DEFAULT_ROCM_VERSION}" | ||
| fi |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/usr/bin/env bash
set -euo pipefail
echo "Inspect relevant lines:"
nl -ba .github/scripts/download_triton_wheel.sh | sed -n '11,20p'
echo
echo "Reproduce errexit+pipefail behavior in command substitution:"
set +e
bash -lc 'set -euo pipefail; v=$(false | cat); echo "SURVIVED"'
status=$?
set -e
echo "Observed exit status: ${status} (expected non-zero; fallback branch would be skipped)."Repository: buddywhitman/aiter
Length of output: 147
🏁 Script executed:
#!/usr/bin/env bash
set -euo pipefail
echo "1) Show relevant lines from .github/scripts/download_triton_wheel.sh (lines ~1-30):"
# use cat -n instead of nl (nl not available)
cat -n .github/scripts/download_triton_wheel.sh | sed -n '1,40p'
echo
echo "2) Locate references to TRITON_DEFAULT_ROCM_VERSION and roc m-core handling:"
rg -n "TRITON_DEFAULT_ROCM_VERSION|rocm-core|ROCM_VERSION|TRITON_INDEX_URL" .github/scripts/download_triton_wheel.sh .github/workflows/prepare-triton-wheel.yaml || true
echo
echo "3) Inspect .github/workflows/prepare-triton-wheel.yaml for how the script is invoked:"
sed -n '1,200p' .github/workflows/prepare-triton-wheel.yaml | rg -n "download_triton_wheel\.sh|prepare-triton-wheel|TRITON_DEFAULT_ROCM_VERSION|ROCM_VERSION|rocm-core" || true
echo
echo "4) Reproduce errexit+pipefail behavior in command substitution (minimal cases):"
set +e
# Case A: failing command in pipeline substitution
bash -lc 'set -euo pipefail; v=$(false | cat); echo "A_SURVIVED"' ; echo "A_EXIT=$?"
# Case B: failing left side (simulating dpkg not found) piped to awk; ensure pipeline failure
bash -lc 'set -euo pipefail; v=$(command_not_found_in_subst 2>/dev/null | awk "/x/"); echo "B_SURVIVED"' ; echo "B_EXIT=$?"
# Case C: force non-zero but allow pipeline to succeed via `|| true` (proposed fix baseline)
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "C_SURVIVED"' ; echo "C_EXIT=$?"
# Case D: proposed fix pattern: command substitution with `|| true` appended
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "D_SURVIVED"' ; echo "D_EXIT=$?"
set -eRepository: buddywhitman/aiter
Length of output: 4868
🏁 Script executed:
#!/usr/bin/env bash
set -euo pipefail
echo "1) Show relevant lines from .github/scripts/download_triton_wheel.sh (lines ~1-40):"
cat -n .github/scripts/download_triton_wheel.sh | sed -n '1,40p'
echo
echo "2) Locate references to TRITON_DEFAULT_ROCM_VERSION and rocm-core handling:"
rg -n "TRITON_DEFAULT_ROCM_VERSION|rocm-core|ROCM_VERSION|TRITON_INDEX_URL" .github/scripts/download_triton_wheel.sh .github/workflows/prepare-triton-wheel.yaml || true
echo
echo "3) Inspect .github/workflows/prepare-triton-wheel.yaml for how the script is invoked:"
sed -n '1,240p' .github/workflows/prepare-triton-wheel.yaml | rg -n "download_triton_wheel\.sh|TRITON_DEFAULT_ROCM_VERSION|ROCM_VERSION|rocm-core|prepare-triton-wheel" || true
echo
echo "4) Reproduce errexit+pipefail behavior in command substitution (minimal cases):"
set +e
# Case A: failing command in pipeline substitution
bash -lc 'set -euo pipefail; v=$(false | cat); echo "A_SURVIVED"' ; echo "A_EXIT=$?"
# Case B: missing command on left side piped to awk; pipeline fails under pipefail
bash -lc 'set -euo pipefail; v=$(command_not_found_in_subst 2>/dev/null | awk "/x/"); echo "B_SURVIVED"' ; echo "B_EXIT=$?"
# Case C: proposed style inside substitution using `|| true` on pipeline
bash -lc 'set -euo pipefail; v=$(false | cat || true); echo "C_SURVIVED"' ; echo "C_EXIT=$?"
set -eRepository: buddywhitman/aiter
Length of output: 4748
Fix unreachable fallback in download_triton_wheel.sh (lines 13-19).
With set -euo pipefail, a non-zero status from the pipeline in ROCM_VERSION=$(dpkg -l rocm-core ... | awk ...) aborts the script before the if [[ -n "${ROCM_VERSION}" ]]; then ... else ... block executes (matching the sandbox repro for v=$(false | cat)). So the “rocm-core not found; using default …” branch at lines 17-19 won’t run when dpkg -l rocm-core returns non-zero. Add || true (or otherwise neutralize the failure) to allow the fallback logic to execute; the current diff approach is appropriate.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In @.github/scripts/download_triton_wheel.sh around lines 13 - 19, The pipeline
assignment to ROCM_VERSION can fail under set -euo pipefail, making the else
fallback unreachable; modify the command that sets ROCM_VERSION (the line
assigning ROCM_VERSION using dpkg -l rocm-core | awk) to neutralize failures
(e.g., append || true) so the command never returns non-zero and the subsequent
if [[ -n "${ROCM_VERSION}" ]]; then ... else ... block can run and correctly set
ROCM_MAJOR_MINOR and TRITON_INDEX_URL or fall back to the default ROCM version
message.
| # Debug: AITER_FLYDSL_FORCE=1 is for debug use. | ||
| _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1" | ||
| use_mxfp4_flydsl = ( | ||
| dtype in [dtypes.bf16, dtypes.fp16] | ||
| and q_type == QuantType.per_1x32 | ||
| and activation == ActivationType.Swiglu | ||
| and q_dtype_a == dtypes.fp4x2 | ||
| and (activation == ActivationType.Swiglu or _flydsl_force) | ||
| and q_dtype_a in (dtypes.fp4x2, dtypes.fp8) |
There was a problem hiding this comment.
Make the FlyDSL force override opt-in.
AITER_FLYDSL_FORCE is documented here as a debug switch, but defaulting it to "1" makes the override active even when the env var is unset. That broadens the fallback to non-SwiGLU activations on the normal path and changes kernel selection by default instead of only under explicit opt-in.
Suggested fix
- _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"
+ _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "0") == "1"📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| # Debug: AITER_FLYDSL_FORCE=1 is for debug use. | |
| _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "1") == "1" | |
| use_mxfp4_flydsl = ( | |
| dtype in [dtypes.bf16, dtypes.fp16] | |
| and q_type == QuantType.per_1x32 | |
| and activation == ActivationType.Swiglu | |
| and q_dtype_a == dtypes.fp4x2 | |
| and (activation == ActivationType.Swiglu or _flydsl_force) | |
| and q_dtype_a in (dtypes.fp4x2, dtypes.fp8) | |
| # Debug: AITER_FLYDSL_FORCE=1 is for debug use. | |
| _flydsl_force = os.environ.get("AITER_FLYDSL_FORCE", "0") == "1" | |
| use_mxfp4_flydsl = ( | |
| dtype in [dtypes.bf16, dtypes.fp16] | |
| and q_type == QuantType.per_1x32 | |
| and (activation == ActivationType.Swiglu or _flydsl_force) | |
| and q_dtype_a in (dtypes.fp4x2, dtypes.fp8) |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/fused_moe.py` around lines 1420 - 1426, The code currently enables the
FlyDSL force override by default because _flydsl_force =
os.environ.get("AITER_FLYDSL_FORCE", "1") == "1"; change this to be opt-in by
default (e.g. use os.environ.get("AITER_FLYDSL_FORCE", "0") == "1" or check
presence with os.environ.get("AITER_FLYDSL_FORCE") == "1") so _flydsl_force is
False when the variable is unset; update the surrounding use_mxfp4_flydsl
condition (symbols: _flydsl_force and use_mxfp4_flydsl) accordingly so FlyDSL is
only forced when the env var is explicitly set to "1".
| # EP: forward expert_mask + topk_ids to the flydsl stage2 wrapper so it can | ||
| # switch to reduce mode and fuse the validity gather in compile_moe_reduction. | ||
| if stage2_func is _flydsl_stage2_wrapper and expert_mask is not None: | ||
| extra_stage2_args["expert_mask"] = expert_mask | ||
| extra_stage2_args["topk_ids"] = topk_ids |
There was a problem hiding this comment.
Pass global expert ids to FlyDSL reduce masking.
This forwards topk_ids from fused_moe_2stages, but that parameter can already be replaced with local_topk_ids in fused_moe_ for split-k bias handling. The new reduce kernel treats topk_ids as indices into expert_mask, so feeding localized ids here will mark the wrong routes valid/invalid under EP. Thread the original global topk_ids separately for the reduction path and keep the localized ids only for kernels that actually need them.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/fused_moe.py` around lines 1807 - 1811, The current branch that sets
extra_stage2_args["expert_mask"] and extra_stage2_args["topk_ids"] for the
_flydsl_stage2_wrapper incorrectly forwards localized ids (local_topk_ids) into
the reduction kernel; change it so the reduction path receives the original
global topk ids while kernels that need split-k localized ids still get
local_topk_ids. Concretely, in the if block guarded by stage2_func is
_flydsl_stage2_wrapper and expert_mask is not None, add/assign a distinct key
(e.g., "global_topk_ids") with the original topk_ids and leave "topk_ids"
pointing to local_topk_ids only for places that use localized routing; update
any call-sites inside fused_moe_2stages / fused_moe_ that expect global vs local
ids to read the appropriate key so the new reduce kernel uses global_topk_ids
while other kernels continue to use topk_ids (local_topk_ids).
| if XQ.dim() != 2 or WQ.dim() != 2: | ||
| raise RuntimeError( | ||
| f"[FlyDSL gfx1250] A/B must be 2-D, got {tuple(XQ.shape)}, {tuple(WQ.shape)}" | ||
| ) | ||
| if XQ.element_size() != 1 or WQ.element_size() != 1: | ||
| raise RuntimeError("[FlyDSL gfx1250] A/B must be 1-byte fp8 storage") | ||
|
|
||
| M, K = XQ.shape | ||
| N = WQ.shape[0] | ||
| if K != WQ.shape[1]: | ||
| raise RuntimeError(f"[FlyDSL gfx1250] K mismatch: A.K={K} vs B.K={WQ.shape[1]}") | ||
| if N % tile_n != 0: | ||
| raise RuntimeError(f"[FlyDSL gfx1250] N={N} not a multiple of tile_n={tile_n}") | ||
| if K % _WMMA_K != 0 or K % tile_k != 0: | ||
| raise RuntimeError( | ||
| f"[FlyDSL gfx1250] K={K} must be a multiple of WMMA_K={_WMMA_K} and " | ||
| f"tile_k={tile_k}" | ||
| ) | ||
|
|
||
| out_dtype = _OUT_DTYPE_NAME.get(Out.dtype) | ||
| if out_dtype is None: | ||
| raise RuntimeError( | ||
| f"[FlyDSL gfx1250] unsupported out dtype {Out.dtype}; expected bf16/fp16" | ||
| ) |
There was a problem hiding this comment.
Validate tensor device and output-shape contracts before launching.
This path never checks that XQ, WQ, x_scale, w_scale, and Out live on the same accelerator device, and it never checks that Out.shape == (M, N). With padded_m == M, Out.contiguous() is handed straight to the compiled kernel, so a mismatched Out can become an out-of-bounds device write instead of a clean Python error.
Proposed guard
M, K = XQ.shape
N = WQ.shape[0]
+ expected_device = XQ.device
+ tensors = {
+ "WQ": WQ,
+ "x_scale": x_scale,
+ "w_scale": w_scale,
+ "Out": Out,
+ }
+ if expected_device.type != "cuda":
+ raise RuntimeError(
+ f"[FlyDSL gfx1250] expected accelerator tensors, got {expected_device}"
+ )
+ for name, tensor in tensors.items():
+ if tensor.device != expected_device:
+ raise RuntimeError(
+ f"[FlyDSL gfx1250] {name} must be on {expected_device}, got {tensor.device}"
+ )
+ if Out.shape != (M, N):
+ raise RuntimeError(
+ f"[FlyDSL gfx1250] Out must have shape {(M, N)}, got {tuple(Out.shape)}"
+ )Also applies to: 125-126, 159-181
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py` around lines 87 - 110, Add
device and output-shape validation before launching the kernel: verify XQ, WQ,
x_scale, w_scale, and Out are on the same device and that Out.shape == (M, N);
also ensure Out is contiguous (or call Out = Out.contiguous()) when padded_m ==
M to avoid out-of-bounds writes. Implement these checks alongside the existing
dimensional/element-size checks in the gfx1250 path (the block validating XQ/WQ,
M, N, K, tile_n/tile_k and Out.dtype) and duplicate the same guard logic in the
other similar validation blocks present later in this file (the other
gfx1250/validation sections).
| # Pipeline depth needs >= 1 K tile per buffer (per split-k chunk). | ||
| num_k_tiles = (K // split_k) // tile_k | ||
| nb = max(2, min(int(num_buffers), num_k_tiles)) | ||
| if nb not in _SUPPORTED_NUM_BUFFERS: | ||
| nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb) |
There was a problem hiding this comment.
Reject unsupported pipeline depths instead of silently forcing nb=2.
When a split-K chunk contains only one K tile, max(2, min(..., num_k_tiles)) still forces nb back to 2. That bypasses the same shape constraint enforced by kernel_fits_shape(...) and lets direct callers compile a two-stage pipeline for a one-tile chunk.
Proposed fix
# Pipeline depth needs >= 1 K tile per buffer (per split-k chunk).
num_k_tiles = (K // split_k) // tile_k
- nb = max(2, min(int(num_buffers), num_k_tiles))
+ if num_k_tiles < min(_SUPPORTED_NUM_BUFFERS):
+ raise RuntimeError(
+ f"[FlyDSL gfx1250] split_k={split_k} leaves only {num_k_tiles} K tiles; "
+ f"need at least {min(_SUPPORTED_NUM_BUFFERS)} for this pipeline"
+ )
+ nb = min(int(num_buffers), num_k_tiles)
if nb not in _SUPPORTED_NUM_BUFFERS:
nb = max(b for b in _SUPPORTED_NUM_BUFFERS if b <= nb)🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/ops/flydsl/bpreshuffle_gemm_gfx1250.py` around lines 119 - 123, Compute
num_k_tiles as before, then do not coerce nb up to 2; set nb =
min(int(num_buffers), num_k_tiles) and if nb is not in _SUPPORTED_NUM_BUFFERS
pick the largest supported buffer count <= nb, but if that yields a value < 1 or
yields a supported nb that is greater than num_k_tiles (e.g. forcing 2 when
num_k_tiles == 1) raise an explicit exception (ValueError) rejecting the
unsupported pipeline depth instead of silently forcing nb=2; update the logic
around num_k_tiles, nb and _SUPPORTED_NUM_BUFFERS and ensure callers that rely
on kernel_fits_shape(...) will see the error when a pipeline depth cannot be
satisfied.
| assert num_expert_group > 1 | ||
| assert ( | ||
| num_expert_group <= 16 | ||
| ), f"NUM_EXPERT_GROUP ({num_expert_group}) > 16 not supported" | ||
| assert 0 < topk_group <= num_expert_group | ||
| assert 0 < k <= 16 | ||
| assert score_mode in ( | ||
| "softmax", | ||
| "sigmoid", | ||
| "sqrtsoftplus", | ||
| "none", | ||
| ), f"unknown score_mode {score_mode!r}" | ||
| has_bias = bias is not None | ||
| if has_bias: | ||
| assert bias.dim() == 1 and bias.shape[0] == n_cols | ||
| assert bias.dtype == torch.float32 | ||
| assert score_mode in ( | ||
| "sqrtsoftplus", | ||
| "sigmoid", | ||
| ), "bias only supported with sqrtsoftplus / sigmoid" | ||
|
|
||
| dev = x.device | ||
|
|
||
| # Default expert→group mapping = contiguous DeepSeek layout. | ||
| if expert_group is None: | ||
| assert n_cols % num_expert_group == 0, ( | ||
| f"n_expts_tot ({n_cols}) not divisible by num_expert_group " | ||
| f"({num_expert_group}); pass an explicit expert_group table." | ||
| ) | ||
| g_size = n_cols // num_expert_group | ||
| expert_group = ( | ||
| torch.arange(n_cols, device=dev, dtype=torch.int32) // g_size | ||
| ).to(torch.int32) | ||
| else: | ||
| assert expert_group.dim() == 1 and expert_group.shape[0] == n_cols | ||
| assert expert_group.dtype == torch.int32 | ||
|
|
There was a problem hiding this comment.
Validate grouped configs before launching the kernel.
grouped_topk() now accepts arbitrary expert_group, but Lines 67-79 only check shape and dtype. If the mapping contains out-of-range ids, empty groups, or a grouping where some valid topk_group selection can expose fewer than k experts, _grouped_topk will exhaust real candidates and keep emitting duplicate indices instead of padding. Please reject invalid mappings/counts up front, or add explicit sentinel handling in the kernel.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@aiter/ops/triton/moe/moe_routing/topk.py` around lines 43 - 79, Validate
expert_group mapping in grouped_topk before launching _grouped_topk: ensure all
expert_group values are in [0, num_expert_group), compute per-group counts with
torch.bincount (length num_expert_group) and assert no empty groups, and
additionally assert the worst‑case selection of topk_group groups can provide k
experts by checking that the sum of the smallest topk_group counts >= k; raise a
clear AssertionError referencing expert_group/num_expert_group/topk_group/k when
any check fails.
| bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other) | ||
| { | ||
| torch::Tensor output = initOutput(input, other); | ||
| binary_op_dispatch("div", input, other, output); | ||
| return output; | ||
| return binary_op_dispatch("add", input, other, input); | ||
| } | ||
|
|
||
| // inp interface | ||
| torch::Tensor aiter_add_(torch::Tensor &input, torch::Tensor &other) | ||
| bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other) | ||
| { | ||
| binary_op_dispatch("add", input, other, input); | ||
| return input; | ||
| return binary_op_dispatch("sub", input, other, input); | ||
| } | ||
|
|
||
| torch::Tensor aiter_sub_(torch::Tensor &input, torch::Tensor &other) | ||
| bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other) | ||
| { | ||
| binary_op_dispatch("sub", input, other, input); | ||
| return input; | ||
| return binary_op_dispatch("mul", input, other, input); | ||
| } | ||
|
|
||
| torch::Tensor aiter_mul_(torch::Tensor &input, torch::Tensor &other) | ||
| bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other) | ||
| { | ||
| binary_op_dispatch("mul", input, other, input); | ||
| return input; | ||
| } | ||
|
|
||
| torch::Tensor aiter_div_(torch::Tensor &input, torch::Tensor &other) | ||
| { | ||
| binary_op_dispatch("div", input, other, input); | ||
| return input; | ||
| return binary_op_dispatch("div", input, other, input); |
There was a problem hiding this comment.
Reject non-contiguous input in the in-place native wrappers.
These wrappers alias input as output, but the transpose/broadcast kernels assume the destination is laid out contiguously. A non-contiguous input can still match the transpose path, so binary_op_dispatch() will write back with dense offsets into the wrong locations. Returning false here when !input.is_contiguous() lets the Python layer fall back to input.<op>_() and preserves correct strided behavior.
Suggested guard
bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other)
{
+ if (!input.is_contiguous()) {
+ return false;
+ }
return binary_op_dispatch("add", input, other, input);
}
bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other)
{
+ if (!input.is_contiguous()) {
+ return false;
+ }
return binary_op_dispatch("sub", input, other, input);
}
bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other)
{
+ if (!input.is_contiguous()) {
+ return false;
+ }
return binary_op_dispatch("mul", input, other, input);
}
bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other)
{
+ if (!input.is_contiguous()) {
+ return false;
+ }
return binary_op_dispatch("div", input, other, input);
}📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| torch::Tensor output = initOutput(input, other); | |
| binary_op_dispatch("div", input, other, output); | |
| return output; | |
| return binary_op_dispatch("add", input, other, input); | |
| } | |
| // inp interface | |
| torch::Tensor aiter_add_(torch::Tensor &input, torch::Tensor &other) | |
| bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| binary_op_dispatch("add", input, other, input); | |
| return input; | |
| return binary_op_dispatch("sub", input, other, input); | |
| } | |
| torch::Tensor aiter_sub_(torch::Tensor &input, torch::Tensor &other) | |
| bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| binary_op_dispatch("sub", input, other, input); | |
| return input; | |
| return binary_op_dispatch("mul", input, other, input); | |
| } | |
| torch::Tensor aiter_mul_(torch::Tensor &input, torch::Tensor &other) | |
| bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| binary_op_dispatch("mul", input, other, input); | |
| return input; | |
| } | |
| torch::Tensor aiter_div_(torch::Tensor &input, torch::Tensor &other) | |
| { | |
| binary_op_dispatch("div", input, other, input); | |
| return input; | |
| return binary_op_dispatch("div", input, other, input); | |
| bool aiter_add_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| if (!input.is_contiguous()) { | |
| return false; | |
| } | |
| return binary_op_dispatch("add", input, other, input); | |
| } | |
| bool aiter_sub_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| if (!input.is_contiguous()) { | |
| return false; | |
| } | |
| return binary_op_dispatch("sub", input, other, input); | |
| } | |
| bool aiter_mul_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| if (!input.is_contiguous()) { | |
| return false; | |
| } | |
| return binary_op_dispatch("mul", input, other, input); | |
| } | |
| bool aiter_div_(aiter_tensor_t &input, aiter_tensor_t &other) | |
| { | |
| if (!input.is_contiguous()) { | |
| return false; | |
| } | |
| return binary_op_dispatch("div", input, other, input); | |
| } |
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@csrc/kernels/binary_operator.cu` around lines 24 - 41, The in-place native
wrappers aiter_add_, aiter_sub_, aiter_mul_, and aiter_div_ must reject
non-contiguous destinations: before calling binary_op_dispatch("...", input,
other, input) check input.is_contiguous() and return false when it is not; this
prevents binary_op_dispatch (and the transpose/broadcast kernels it may select)
from writing dense offsets into the wrong locations and lets the Python layer
fall back to the correct strided in-place implementation.
| def kv_cache_cast_to_fp8( | ||
| x: torch.Tensor, padding=False, fp8_dtype=None | ||
| ) -> torch.Tensor: | ||
| if fp8_dtype is None: | ||
| fp8_dtype = get_fp8_e4m3_dtype() | ||
| num_blocks, block_size, num_heads, head_dim = x.shape | ||
| assert num_heads == 1 | ||
| x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) | ||
| sf = x_amax / 240.0 | ||
| x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fnuz) | ||
| x_scaled = (x * (1.0 / sf)).to(fp8_dtype) |
There was a problem hiding this comment.
🧩 Analysis chain
🌐 Web query:
What are the max finite values (or recommended quantization ranges) for torch.float8_e4m3fnandtorch.float8_e4m3fnuz in PyTorch, and should they use different scale denominators when quantizing BF16 tensors to FP8?
💡 Result:
The max finite value for both torch.float8_e4m3fn and torch.float8_e4m3fnuz is 448.0 [1][2]. torch.float8_e4m3fn is based on the OCP (Open Compute Project) FP8 specification, where the suffix "fn" denotes "Finite, no infinity" [3][2][4]. torch.float8_e4m3fnuz (often referred to as the E4M3 FNUZ format) is an alternative variant supported by specific hardware (e.g., AMD MI300 series) where "fnuz" stands for "Finite, NaN, Unsigned, Zero" [5][6][7][8]. Both types lack representations for infinity and are generally restricted to the same nominal maximum value of 448.0 in practice when quantizing [1][2]. Regarding scale denominators, they should generally be treated identically for quantization purposes when using the same max value (448.0) [1]. Because both types share the same dynamic range, the scaling factor used to map higher-precision tensors (like bfloat16 or float32) to the FP8 range is calculated using the same logic: scale = 448.0 / amax [2] While the bit-level encodings differ between fn and fnuz variants (specifically regarding how NaNs and zeros are represented) [5][7][8], the quantization strategy—clamping values to the [–448.0, 448.0] range before casting—remains the same for both [1][9]. In standard PyTorch implementations, including those in torchao and NVIDIA Model-Optimizer, saturation is highly recommended because native casting does not always clamp by default, which can lead to unexpected overflows or NaN generation [10][11][9]. Top results: [1][5][9][2][7]
Citations:
- 1: https://github.com/pytorch/ao/blob/main/torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cuh
- 2: https://docs.nvidia.com/cuda/nvmath-python/0.5.0/tutorials/notebooks/matmul/04_fp8.html
- 3: https://dev-discuss.pytorch.org/t/float8-in-pytorch-1-x/1815
- 4: https://github.com/pytorch/rfcs/blob/master/RFC-0030-native-fp8-dtype.md
- 5: https://github.com/pytorch/pytorch/blob/main/torch/headeronly/util/Float8_e4m3fnuz.h
- 6: https://github.com/pytorch/ao/blob/main/torchao/float8/config.py
- 7: pytorch/pytorch@1c3c1f2
- 8: Add support for float8_e4m3fnuz and _e5m2fnuz pytorch/pytorch#107586
- 9: https://github.com/pytorch/ao/blob/main/torchao/float8/float8_utils.py
- 10: Add clamped FP8 (E4M3) cast for overflow-safe inference pytorch/pytorch#154028
- 11: https://github.com/NVIDIA/Model-Optimizer/blob/e5ce0ae8/modelopt/torch/quantization/tensor_quant.py
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 1) Locate the helper and constant usage in this repo
rg -n "get_fp8_e4m3_dtype|240\.0|kv_cache_cast_to_fp8" -S . || true
# 2) Show the surrounding code in the target file around the snippet
target="op_tests/op_benchmarks/triton/bench_deepgemm_attention.py"
if [ -f "$target" ]; then
sed -n '1,120p' "$target" | nl -ba | sed -n '1,120p'
fiRepository: buddywhitman/aiter
Length of output: 9821
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# 3) Find implementations of get_fp8_e4m3_dtype and any fp8 scaling utilities/constants
rg -n "def get_fp8_e4m3_dtype|get_fp8_e4m3_dtype\s*\(" -S . || true
# If present, inspect likely nearby files (by searching filenames containing fp8)
fd -a "fp8" . || trueRepository: buddywhitman/aiter
Length of output: 50375
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Show the helper + dtype selection/printing area
target="op_tests/op_benchmarks/triton/bench_deepgemm_attention.py"
echo "=== kv_cache_cast_to_fp8 + dtype selection context ==="
# Print around lines 1-120
awk 'NR>=1 && NR<=120 {printf "%6d %s\n", NR, $0}' "$target" | sed -n '1,120p'
echo "=== later FP8 dtype usage around line ~220-320 ==="
awk 'NR>=220 && NR<=320 {printf "%6d %s\n", NR, $0}' "$target" | sed -n '1,120p'
echo "=== definitions of get_fp8_e4m3_dtype ==="
types="aiter/ops/triton/utils/types.py"
awk 'NR>=1 && NR<=200 {printf "%6d %s\n", NR, $0}' "$types" | sed -n '1,200p'
echo "=== locate fp8 min/max bounds utilities ==="
rg -n "get_fp8_min_max_bounds|get_fp8_.*max|fp8_max|max_pos|inv_max_pos|240\.0|448\.0" aiter/ops/triton -S || true
echo "=== inspect fused_fp8_quant fp8 range logic ==="
fused="aiter/ops/triton/quant/fused_fp8_quant.py"
# show around where it mentions get_fp8_min_max_bounds and dtype
awk 'NR>=350 && NR<=520 {printf "%6d %s\n", NR, $0}' "$fused" | sed -n '1,200p'Repository: buddywhitman/aiter
Length of output: 42658
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== ck_tile/vec_convert.h around constants ==="
file="csrc/include/ck_tile/vec_convert.h"
if [ -f "$file" ]; then
sed -n '1,120p' "$file" | nl -ba | sed -n '35,80p'
else
echo "missing $file"
fi
echo "=== mx_quant_utils.h around max_pos/max_normal comments ==="
file="csrc/include/mx_quant_utils.h"
if [ -f "$file" ]; then
sed -n '70,130p' "$file" | nl -ba | sed -n '60,120p'
else
echo "missing $file"
fi
echo "=== aiter/utility/fp4_utils.py fp8 dtype mapping ==="
file="aiter/utility/fp4_utils.py"
if [ -f "$file" ]; then
sed -n '60,110p' "$file" | nl -ba | sed -n '1,80p'
else
echo "missing $file"
fiRepository: buddywhitman/aiter
Length of output: 170
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "=== ck_tile/vec_convert.h (show lines 1-120 or whole if smaller) ==="
file="csrc/include/ck_tile/vec_convert.h"
if [ -f "$file" ]; then
wc -l "$file"
sed -n '35,80p' "$file"
fi
echo "=== mx_quant_utils.h (show lines 70-130) ==="
file="csrc/include/mx_quant_utils.h"
if [ -f "$file" ]; then
wc -l "$file"
sed -n '80,130p' "$file"
fi
echo "=== aiter/utility/fp4_utils.py (show lines 60-110) ==="
file="aiter/utility/fp4_utils.py"
if [ -f "$file" ]; then
wc -l "$file"
sed -n '60,110p' "$file"
fiRepository: buddywhitman/aiter
Length of output: 7426
Use the FP8 max for the selected dtype when computing sf in kv_cache_cast_to_fp8().
kv_cache_cast_to_fp8() always scales with sf = x_amax / 240.0, but torch.float8_e4m3fn and torch.float8_e4m3fnuz have different max finite ranges in this codebase (448 vs 240). Derive the denominator from fp8_dtype (e.g., fp8_max = torch.finfo(fp8_dtype).max; sf = x_amax / fp8_max) so the KV cache is scaled consistently with the chosen FP8 encoding.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@op_tests/op_benchmarks/triton/bench_deepgemm_attention.py` around lines 26 -
35, kv_cache_cast_to_fp8 currently hardcodes the FP8 max (240) when computing
sf; instead, derive the FP8 max from the selected fp8_dtype and use that to
compute sf. Update kv_cache_cast_to_fp8 (and where fp8_dtype is set via
get_fp8_e4m3_dtype) to compute fp8_max = torch.finfo(fp8_dtype).max (or
equivalent API for the chosen FP8 dtype) and set sf = x_amax / fp8_max before
scaling; keep the existing fallback when fp8_dtype is None and preserve
x.shape/asserts and subsequent casting to fp8_dtype.
| # The ASM decode baseline aborts for these MLA configs when lse is requested | ||
| if return_lse: | ||
| pass |
There was a problem hiding this comment.
--return_lse still skips the bh64 Gluon return_lse path.
Once the ASM baseline is skipped here, the Gluon run becomes the only validation path — but test_absorb_decode_gluon() never forwards return_lse, so -lse on bh64 still times the output-only kernel and never checks the merged LSE contract.
Suggested fix
- (attn_logits, attn_lse), us_gluon_decode = run_perftest(
+ (_, attn_lse), us_gluon_decode = run_perftest(
mla_decode_gluon,
q_nope,
q_pe,
kv_c,
out_gluon.view(batch_size, nhead, v_head_dim),
page_table,
seq_info,
sm_scale,
use_2d_view=use_2d_view,
min_kv_seq_len=ctx_lens,
+ return_lse=return_lse,
)
@@
err = checkAllclose(
out_ref,
out_gluon,
msg=f"mla_decode-absorb [golden vs gluon_mla]: {us_gluon_decode:>8.2f} us......",
)
+ if return_lse and attn_lse is not None:
+ checkAllclose(
+ lse_ref,
+ attn_lse.reshape(total_q, nhead),
+ msg=f"mla_decode-absorb [lse_ref vs gluon_mla_lse]: {us_gluon_decode:>8.2f} us......",
+ )
return err, us_gluon_decode🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@op_tests/test_mla.py` around lines 582 - 584, The current early skip ("if
return_lse: pass") prevents the bh64 Gluon path from ever seeing return_lse, so
ensure the Gluon validation receives the flag: remove the no-op skip and forward
the return_lse flag into the Gluon test/run invocation (eg. pass return_lse into
test_absorb_decode_gluon or the Gluon runner call that handles bh64) so the
Gluon branch evaluates the merged LSE contract; update the call site(s) where
test_absorb_decode_gluon or the Gluon run is invoked to accept and propagate the
return_lse parameter.
| # a4w4: Silu + SEPARATED -> FlyDSL fp4/fp4 (AITER_FLYDSL_FORCE=1 drops the | ||
| # Swiglu gate). a8w4: Silu + INTERLEAVE -> q_dtype_a auto-picker selects | ||
| # fp8 on gfx950 (fused_moe.py:357-361), and since the L1261 CK-Tile | ||
| # pre-emption requires Swiglu, Silu falls through to the | ||
| # swiglu_mxfp4_flydsl branch (with FLYDSL_FORCE=1) and lands on | ||
| # flydsl_moe1_afp8_wfp4_... Needs AITER_BF16_FP8_MOE_BOUND<=token. | ||
| if quant_label == "a8w4_mxfp4": | ||
| act = ActivationType.Silu | ||
| gate_mode = GateMode.INTERLEAVE.value | ||
| else: | ||
| act = ActivationType.Silu | ||
| gate_mode = GateMode.SEPARATED.value | ||
| out, us = run_perftest( | ||
| fused_moe, | ||
| input_, | ||
| w1_a, | ||
| w2_a, | ||
| topk_weights, | ||
| topk_ids, | ||
| expert_mask=expert_mask, | ||
| activation=act, | ||
| gate_mode=gate_mode, | ||
| quant_type=QuantType.per_1x32, | ||
| w1_scale=w1_s, | ||
| w2_scale=w2_s, | ||
| num_warmup=3, | ||
| num_iters=16, | ||
| ) |
There was a problem hiding this comment.
Force the fp8 runtime mode in the a8w4_mxfp4 EP test.
This test never sets AITER_BF16_FP8_MOE_BOUND, so the default token=128 case can stay on the bf16 runtime path instead of the fp8 path the label/comment claim to cover. That means g1u1_a8w4_mxfp4 can pass without ever validating the new fp8 EP route.
Suggested fix
+import os
import argparse
@@
if quant_label == "a8w4_mxfp4":
act = ActivationType.Silu
gate_mode = GateMode.INTERLEAVE.value
+ _old_bound = os.environ.get("AITER_BF16_FP8_MOE_BOUND")
+ os.environ["AITER_BF16_FP8_MOE_BOUND"] = "0"
else:
act = ActivationType.Silu
gate_mode = GateMode.SEPARATED.value
- out, us = run_perftest(
- fused_moe,
- input_,
- w1_a,
- w2_a,
- topk_weights,
- topk_ids,
- expert_mask=expert_mask,
- activation=act,
- gate_mode=gate_mode,
- quant_type=QuantType.per_1x32,
- w1_scale=w1_s,
- w2_scale=w2_s,
- num_warmup=3,
- num_iters=16,
- )
+ try:
+ out, us = run_perftest(
+ fused_moe,
+ input_,
+ w1_a,
+ w2_a,
+ topk_weights,
+ topk_ids,
+ expert_mask=expert_mask,
+ activation=act,
+ gate_mode=gate_mode,
+ quant_type=QuantType.per_1x32,
+ w1_scale=w1_s,
+ w2_scale=w2_s,
+ num_warmup=3,
+ num_iters=16,
+ )
+ finally:
+ if quant_label == "a8w4_mxfp4":
+ if _old_bound is None:
+ os.environ.pop("AITER_BF16_FP8_MOE_BOUND", None)
+ else:
+ os.environ["AITER_BF16_FP8_MOE_BOUND"] = _old_bound🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@op_tests/test_moe_ep.py` around lines 482 - 509, The test for quant_label
"a8w4_mxfp4" doesn't force the fp8 runtime path, so set the fp8 runtime bound
before calling run_perftest (e.g., export or
os.environ["AITER_BF16_FP8_MOE_BOUND"]="1" or a value <= token) when quant_label
== "a8w4_mxfp4" so the code path exercised by run_perftest(fused_moe, ...) uses
the fp8 EP; update the branch that sets act and gate_mode in
op_tests/test_moe_ep.py to set that environment variable (or otherwise invoke
the runtime-mode switch) prior to the run_perftest call.
Summary
Implements
batched_gemm_a8w8_smallB_blockscale— a Triton FP8 batched GEMMkernel optimized for small-B (B=2, DeepSeek V4 grouped output LoRA wo_a).
This is a contribution targeting ROCm#3000. Unblocks ROCm/ATOM#676.
Problem
Existing aiter FP8 batched GEMM kernels lose to BF16 einsum by 1.5×–7× at B=2
due to CU starvation: grid
(B=2, M_tiles=1, N_tiles=8)= 16 work-groups on a304-CU GPU → 5% CU utilization.
Solution
1. Grid collapse + split-K=8
Flatten B into M:
(B×M_tiles, N_tiles, split_k)grid.At B=2, M=1, split_k=8: 128 work-groups → 42% CU utilization (was 5%).
2. Fused bf16 output for split_k=1
When split_k=1, the main kernel writes bf16 directly — no partial-sum buffer,
no separate reduction kernel. Eliminates ~160μs overhead at M=256.
3. Per-128-block W-scales inline
Weight scales loaded per 128-element block within the kernel.
Single-row vector load (no 128× register expansion from broadcast).
New files
aiter/ops/triton/_triton_kernels/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyaiter/ops/triton/gemm/batched/batched_gemm_a8w8_smallB_blockscale.pyaiter/ops/triton/configs/gemm/gfx942-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json(MI300X tuned configs)aiter/ops/triton/configs/gemm/gfx950-BATCHED_GEMM-A8W8_SMALLB_BLOCKSCALE.json(MI355X tuned configs)op_tests/test_batched_gemm_smallB.pyExpected results on AMD MI300X (gfx942)
Correctness: tested on NVIDIA RTX 3070Ti (sm_86) using float8_e5m2.
AMD production uses float8_e4m3fnuz (set
AITER_AMD_FP8=1).NVIDIA sm_86 has no native FP8 hardware — AMD MI300X/MI355X numbers
expected to show FP8 ≥ BF16 at all M values.
Work-group count improvement at B=2, N=1024, BLOCK_N=128:
AI Assistance Disclosure
Implementation assisted by Claude Sonnet 4.6. All changes reviewed and
understood by the submitter. Tests run and passing on NVIDIA RTX 3070Ti.
Summary by CodeRabbit
New Features
Backward Compatibility
Tests
Chores