Blocked transform #5
Closed
ahurta92 wants to merge 23 commits into
- Introduced L2 and L3 transform kernels with shared memory staging.
- Updated transform_bench to support level selection for benchmarking.
- Added validate_levels.hip for correctness testing across optimization levels.
- Modified .clangd for improved indexing and added HIP file parsing.
- Updated CMakeLists.txt to include new executables for level validation.
Replaces the incorrect cpu_transform3d_slabs with cpu_transform3d_blocked: distributes A along the fastest axis, does two local GEMMs, a block-level corner turn, a third local GEMM, and a per-block transpose on store. Matches cpu_transform3d bit-exactly.
Includes a K=2 trace (-debug 1) that dumps every stage for a counting tensor with B=I, so only the shuffle pattern is visible. Useful as a smoke test when porting to GPU.
BLOCKED_TRANSFORM.md documents the algorithm frame-by-frame for an upcoming presentation; frames/ holds the diagram images.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pass 3 now right-multiplies by B instead of left-multiplying by B^T. The corner turn stores arriving data as (b, k') instead of (k', b), so pass 3's output lands directly in canonical (b, c) order -- the in-block transpose is absorbed into the corner turn at no extra cost (same loads/stores, different destination address).
Side effect: passes 2 and 3 now share the same GEMM orientation (B on the right); only pass 1 uses the other form. One fewer MFMA microkernel variant to maintain on the GPU port.
README frame 5 collapses from two stages to one; the bookkeeping table loses its last row.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
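For orientation, here is a minimal host-side sketch of the separable math that any blocked variant has to reproduce. It is not the cpu_transform3d_blocked code from this PR. The transform definition (R[a][b][c] = sum over i,j,k of A[i][j][k] B[i][a] B[j][b] B[k][c]) and all function names are assumptions made for illustration. Each pass right-multiplies by B and rotates the leading index to the back, so three identical passes return the result in canonical order.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Direct O(K^6) reference (assumed definition of the transform):
// R[a][b][c] = sum_{i,j,k} A[i][j][k] * B[i][a] * B[j][b] * B[k][c].
Vec transform3d_direct(const Vec& A, const Vec& B, int K) {
    Vec R(static_cast<std::size_t>(K) * K * K, 0.0);
    for (int a = 0; a < K; ++a)
        for (int b = 0; b < K; ++b)
            for (int c = 0; c < K; ++c) {
                double sum = 0.0;
                for (int i = 0; i < K; ++i)
                    for (int j = 0; j < K; ++j)
                        for (int k = 0; k < K; ++k)
                            sum += A[(i * K + j) * K + k] * B[i * K + a] *
                                   B[j * K + b] * B[k * K + c];
                R[(a * K + b) * K + c] = sum;
            }
    return R;
}

// One pass: view T as a K x K^2 matrix T[i][m] and compute
// out[m][a] = sum_i T[i][m] * B[i][a]. The contracted index moves to the back,
// so the index order cycles (i,j,k) -> (j,k,a) -> (k,a,b) -> (a,b,c).
Vec pass(const Vec& T, const Vec& B, int K) {
    const int M = K * K;
    Vec out(static_cast<std::size_t>(K) * M, 0.0);
    for (int m = 0; m < M; ++m)
        for (int a = 0; a < K; ++a) {
            double sum = 0.0;
            for (int i = 0; i < K; ++i) sum += T[i * M + m] * B[i * K + a];
            out[m * K + a] = sum;
        }
    return out;
}

Vec transform3d_three_pass(const Vec& A, const Vec& B, int K) {
    return pass(pass(pass(A, B, K), B, K), B, K);
}

// Counting tensor A[n] = n, as in the K=2 trace described above.
Vec counting_tensor(int K) {
    Vec A(static_cast<std::size_t>(K) * K * K);
    for (std::size_t n = 0; n < A.size(); ++n) A[n] = static_cast<double>(n);
    return A;
}

bool three_pass_matches_direct(int K) {
    Vec A = counting_tensor(K);
    Vec B(static_cast<std::size_t>(K) * K);
    for (std::size_t n = 0; n < B.size(); ++n) B[n] = 0.1 * (n + 1) - 0.3 * (n % 2);
    const Vec x = transform3d_three_pass(A, B, K);
    const Vec y = transform3d_direct(A, B, K);
    for (std::size_t n = 0; n < x.size(); ++n)
        if (std::fabs(x[n] - y[n]) > 1e-9 * (1.0 + std::fabs(y[n]))) return false;
    return true;
}

// With B = I every pass only permutes data, so the output must equal the input.
bool identity_is_a_no_op(int K) {
    Vec B(static_cast<std::size_t>(K) * K, 0.0);
    for (int i = 0; i < K; ++i) B[i * K + i] = 1.0;
    const Vec A = counting_tensor(K);
    return transform3d_three_pass(A, B, K) == A;
}
```

The B = I check mirrors the smoke-test idea in the commit message: if the shuffle pattern is wrong, the counting tensor comes back permuted.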
Implements the corner-turn algorithm on MI250X at K=16:
1 wavefront = 1 block (K² elements);
K=16 waves = 1 thread block = 1 tensor;
passes 1/2/3 local; one all-to-all corner turn through LDS.
LDS layout at K=16:
buf[K³] 32 KB in-place ping-pong via register stash
B_lds[K²] 2 KB cached B matrix
Each pass computes its output into a K²/64-element per-lane register
stash, __syncthreads, then writes back to buf -- letting the single
buffer serve as both input and output. Corner turn uses the same
stash pattern across wave boundaries. Pass 3 fuses with the HBM
store, skipping one LDS round-trip.
Registered as level 7 in validate_levels and transformbench.
Correctness: max_rel_err < 2e-15 at N ∈ {1, 8, 64, 512, 2048}.
Perf at K=16: ~2270 GF/s steady-state, up from a 64 KB double-buffer
first cut at ~720 GF/s -- bigger jump than expected, partly because
dropping from the 64 KB LDS ceiling gave the compiler scheduling
headroom back.
Known remaining inefficiencies (to address in later passes):
- uncoalesced HBM reads in distribute (stride K per lane)
- LDS bank conflicts in corner-turn writes (stride of 16 doubles = 128 bytes = 32 banks, so lanes wrap onto the same bank)
- no MFMA yet; scalar FP64 ops throughout
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The scalar passes 1/2/3 are replaced with v_mfma_f64_16x16x4f64 (4
MFMAs per pass for K=16). Correctness passes at N ∈ {1,8,64,512,2048}
with max_rel_err < 2e-15.
Perf is unchanged at ~2270 GF/s -- the kernel is not compute-bound at
this occupancy; LDS bandwidth, the corner-turn bank conflicts, and
the serialized MFMA dependency chain are the limiters. Remaining
optimization work will target those.
Includes test_mfma_layout.hip, a small diagnostic that writes the
MFMA accumulator to host-readable memory so the per-lane output
layout can be verified empirically. Using it, the confirmed layout
on GFX90A is:
A_frag (M=16, K=4): lane t -> A[t%16][t/16] (col-major)
B_frag (K=4, N=16): lane t -> B[t/16][t%16] (row-major)
D (M=16, N=16): lane t acc[e] -> D[(t/16)+4*e][t%16]
The comment block in the old mxm_level4.h claimed A was row-major
(A[t/4][t%4]) and D was 4-consecutive-rows per lane, which is wrong
and caused an initial ~3.7 rel_err before the probe pinned down the
actual hardware layout.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
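The confirmed 16x16x4 layout can be exercised on the host with a small software model. The sketch below is not test_mfma_layout.hip; it is a plain C++ emulation, with hypothetical names, that places A and B values into 64 "lanes" using the mapping quoted above, accumulates four k-slices (the 4 MFMAs per pass at K=16), gathers D through the accumulator mapping, and compares against an ordinary GEMM.

```cpp
#include <array>
#include <cassert>

constexpr int kLanes = 64;  // one wavefront
using Mat16 = std::array<std::array<double, 16>, 16>;
using LaneVals = std::array<double, kLanes>;
using LaneAcc = std::array<std::array<double, 4>, kLanes>;

// Software model of one 16x16x4 step:
//   a_frag[t] holds A[t%16][t/16], b_frag[t] holds B[t/16][t%16],
//   acc[t][e] accumulates D[(t/16) + 4*e][t%16].
void mfma_16x16x4_model(const LaneVals& a_frag, const LaneVals& b_frag, LaneAcc& acc) {
    for (int t = 0; t < kLanes; ++t)
        for (int e = 0; e < 4; ++e) {
            const int row = t / 16 + 4 * e;
            const int col = t % 16;
            for (int k = 0; k < 4; ++k) {
                // A[row][k] lives in lane row + 16*k; B[k][col] lives in lane 16*k + col.
                acc[t][e] += a_frag[row + 16 * k] * b_frag[16 * k + col];
            }
        }
}

// Full 16x16x16 product as four chained model calls, one per k-slice p.
Mat16 gemm16_via_model(const Mat16& A, const Mat16& B) {
    LaneAcc acc{};  // zero-initialised accumulators
    for (int p = 0; p < 4; ++p) {
        LaneVals a_frag{}, b_frag{};
        for (int t = 0; t < kLanes; ++t) {
            a_frag[t] = A[t % 16][4 * p + t / 16];
            b_frag[t] = B[4 * p + t / 16][t % 16];
        }
        mfma_16x16x4_model(a_frag, b_frag, acc);
    }
    Mat16 D{};
    for (int t = 0; t < kLanes; ++t)
        for (int e = 0; e < 4; ++e) D[t / 16 + 4 * e][t % 16] = acc[t][e];
    return D;
}

Mat16 gemm16_reference(const Mat16& A, const Mat16& B) {
    Mat16 D{};
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 16; ++j)
            for (int k = 0; k < 16; ++k) D[i][j] += A[i][k] * B[k][j];
    return D;
}

// Small integer inputs keep every sum exact, so == comparison is safe here.
bool model_matches_reference() {
    Mat16 A{}, B{};
    for (int i = 0; i < 16; ++i)
        for (int j = 0; j < 16; ++j) {
            A[i][j] = (3 * i + j) % 7 - 3;
            B[i][j] = (i + 2 * j) % 5 - 2;
        }
    return gemm16_via_model(A, B) == gemm16_reference(A, B);
}
```

A model like this only checks that the three index maps are mutually consistent; whether they match the hardware is what the on-device probe establishes.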
Change the within-wave row stride in buf from K to K+1 so that
stride-K-across-lanes access patterns (pass 2/3 A_frag loads, and
the corner-turn cross-wave write) no longer land on the same 128-byte
bank row on GFX90A. LDS grows 34 KB -> 36 KB; still one block per CU.
Perf at K=16 jumps from ~2270 GF/s to ~3320 GF/s (1.46x) -- bigger
than expected because three LDS phases were conflicting, not just
the corner turn. Correctness still PASS at N in {1,8,64,512,2048}.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
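The effect of the K -> K+1 pad can be seen with a simplified bank model. The sketch below assumes 32 LDS banks of 4 bytes each (consistent with the "stride 16 = 32 banks" note earlier in this PR) and only looks at the bank each lane's double starts in; real hardware issues 64-bit accesses in several phases, so treat this as an illustration, not a conflict predictor. Function names are hypothetical.

```cpp
#include <cassert>
#include <set>

constexpr int kBanks = 32;     // assumed LDS bank count
constexpr int kBankBytes = 4;  // assumed bank width

// Number of distinct starting banks hit when `lanes` lanes each access one
// double at element index lane * stride_doubles.
int distinct_start_banks(int lanes, int stride_doubles) {
    std::set<int> banks;
    for (int lane = 0; lane < lanes; ++lane) {
        const long byte_offset = static_cast<long>(lane) * stride_doubles * 8;
        banks.insert(static_cast<int>((byte_offset / kBankBytes) % kBanks));
    }
    return static_cast<int>(banks.size());
}

// LDS footprint in bytes: K*K rows of (K + pad) doubles for buf, plus K*K doubles for B_lds.
long lds_bytes(int K, int pad) {
    return 8L * K * K * (K + pad) + 8L * K * K;
}
```

Under these assumptions, 16 lanes at stride 16 all start in the same bank, while stride 17 spreads them over 16 different banks, and the footprint goes from 34 KB to 36 KB as stated above.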
Each pass only touches its owning wave's buf region, so the barriers between pass-1 compute/store, pass-1/pass-2, and pass-2/corner-turn were enforcing ordering the wave lockstep + compiler waitcnt already guarantees. Only the two corner-turn barriers (stash-done-before-cross-wave-writes, writes-done-before-pass-3) and the initial B_lds barrier are cross-wave.
8 __syncthreads() per cube -> 3. SQ_WAIT_INST_LDS drops 61%, GRBM active cycles drop 9%, wall-clock perf improves ~4% (3320 -> 3460 GF/s at K=16). The gap between counter delta and wall-clock delta suggests the hardware was already hiding most of the barrier latency through its 4-waves-per-SIMD scheduler.
Also add counters.txt (rocprofv2 counter set used to diagnose this).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After pass 2, each lane t's acc[e] is at position (row = t/16 + 4e, col = t%16) in pass-2 output. The corner-turn stash was reading buf at (a = (t+64e)/K, b = (t+64e)%K) in wave s's region -- which for K=16 simplifies to (a = t/16 + 4e, b = t%16). Exactly the same positions. So the stash read was just reading back what the pass-2 store had just put there in the same lane.
Skip both the pass-2 store and the stash read; write acc[e] directly into the cross-wave destination.
LDS op savings: 4 writes + 4 reads per lane per tensor.
Wall-clock change: ~1% (the eliminated ops were hidden in the shadow of MFMA latency, so we improved the theoretical critical path but not the observed one). Still worth having -- fewer LDS unit ops leave more room for whatever IS on the critical path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
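The index identity behind this fusion is easy to check exhaustively. The sketch below (hypothetical function name) compares the stash-read coordinates with the accumulator coordinates for every lane t and register e. It also shows that the identity is specific to K=16, since the accumulator layout is fixed by the 16-wide MFMA while the stash coordinates depend on K.

```cpp
#include <cassert>

// True when the stash read at (a, b) = ((t+64e)/K, (t+64e)%K) lands on the
// same position as the accumulator element (t/16 + 4e, t%16) for all lanes.
bool stash_read_matches_accumulator(int K) {
    for (int t = 0; t < 64; ++t)
        for (int e = 0; e < 4; ++e) {
            const int flat = t + 64 * e;
            const int stash_a = flat / K;
            const int stash_b = flat % K;
            const int acc_row = t / 16 + 4 * e;
            const int acc_col = t % 16;
            if (stash_a != acc_row || stash_b != acc_col) return false;
        }
    return true;
}
```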
Pass 1 now computes blk^T · B instead of B^T · blk. Swapping the operands gives MFMA output with (row=j, col=a) indexing instead of (row=a, col=j), which coincidentally is exactly the layout pass 2's A_frag wants at iter p. Net effect: lane t's pass-1 acc[p] IS pass 2's a_val at iter p -- zero LDS, zero cross-lane movement between passes 1 and 2.
LDS ops drop another 21% (SQ_INSTS_LDS 231k -> 182k). Wall-clock is flat, which is the clearest signal yet that LDS is not the critical path of this kernel.
Cumulative win since the bank-conflict pad: -40% LDS instructions, ~0% wall-clock -- the LDS pipeline was fully hidden in the shadow of MFMA / HBM latency the whole time. Further LDS reductions won't help perf; the limiter lives elsewhere (likely HBM read latency or a scheduling artifact not captured by the counters we're collecting).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
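The layout coincidence can be written out as a host-side check. The sketch below uses a hypothetical 16x16 array temp1[a][j] with unique entries as the pass-1 result, reads it through the D accumulator mapping (lane t, acc[p] -> D[(t/16)+4p][t%16]) for both operand orders, and compares with what the A fragment mapping (lane t -> A[t%16][4p + t/16]) asks for at k-slice p. Names are illustrative only.

```cpp
#include <cassert>

// swapped == true  : pass 1 produced D[j][a] = temp1[a][j]  (blk^T * B form)
// swapped == false : pass 1 produced D[a][j] = temp1[a][j]  (original form)
bool pass1_acc_feeds_pass2_a_frag(bool swapped) {
    double temp1[16][16];
    for (int a = 0; a < 16; ++a)
        for (int j = 0; j < 16; ++j) temp1[a][j] = 16.0 * a + j;  // unique tags

    for (int t = 0; t < 64; ++t)
        for (int p = 0; p < 4; ++p) {
            const int d_row = t / 16 + 4 * p;  // accumulator mapping
            const int d_col = t % 16;
            const double held = swapped ? temp1[d_col][d_row] : temp1[d_row][d_col];
            const double wanted = temp1[t % 16][4 * p + t / 16];  // A_frag mapping
            if (held != wanted) return false;
        }
    return true;
}
```

Only the swapped form passes, which is why the operand swap removes the LDS round-trip between passes 1 and 2.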
All 1024 threads cooperatively load the K^3 tensor from HBM with stride-1 reads (coalesced, one cache line per 16 consecutive lanes), placing elements into buf in canonical (i, j, k) layout with k fastest. Pass 1 read indexing updated to match.
The old per-wave distribute read HBM with stride K = 128 bytes -- each lane pulled a whole cache line for 1 useful double. L2 absorbed the amplification via sharing across the block's 16 waves (same lines, different k), so L2 hit rate read high (88%) and HBM bandwidth looked underused (26% of peak) -- but the underlying cache-line BW was saturated.
With coalesced reads:
GRBM active cycles      -44% (124k -> 70k)
TCC_HIT                 -92% (1.92M -> 154k) (fewer redundant lookups)
HBM BW utilization      26% -> 47% of peak
GF/s @ K=16 N=512       ~3475 -> ~9500 (2.7x)
GF/s @ K=16 N=2048      ~3470 -> ~7400 (2.1x)
Costs:
+1 __syncthreads (cross-wave LDS writes in the load)
No added LDS footprint
No added bank conflicts (canonical layout's j-stride is K+1 = 2 banks shift)
Post-corner-turn accesses (corner-turn write, pass-3 read) are numerically identical to before because the canonical and per-wave layouts use the same strides (K*(K+1), K+1, 1); only the MEANING of each index changes (i,j,k vs s,i,j).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
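The read amplification described above follows from counting cache lines per wave. The sketch below assumes 128-byte cache lines (as implied by "one cache line per 16 consecutive lanes") and a hypothetical helper name; it counts how many distinct lines a 64-lane wave touches when each lane reads one double at a given element stride.

```cpp
#include <cassert>
#include <set>

// Distinct cache lines touched when lane L reads the double at element
// index L * stride_doubles. line_bytes = 128 is an assumption.
int cache_lines_touched(int lanes, int stride_doubles, int line_bytes = 128) {
    std::set<long> lines;
    for (int lane = 0; lane < lanes; ++lane) {
        const long byte_offset = static_cast<long>(lane) * stride_doubles * 8;
        lines.insert(byte_offset / line_bytes);
    }
    return static_cast<int>(lines.size());
}
```

At stride K=16 each of the 64 lanes lands on its own line (8 useful bytes out of 128), while stride 1 needs only 4 lines for the same 64 doubles.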
Upstream convention is that each kernel invocation launches one block per tensor and the GPU scheduler handles distribution across CUs. Bringing L7 in line: grid = nfuncs, no internal cube loop, no grid clamping via nblocks.
This gives the scheduler finer-grain parallelism. With only 512 concurrent blocks (previous limit), all blocks moved through the same pipeline phase together, which synchronized everyone on whichever phase was slow. More blocks = more overlap potential between HBM-heavy and compute-heavy phases.
Perf:
N=512   9500 -> 9130 GF/s (slight regression, already saturated)
N=2048  7400 -> 8660 GF/s (+17%)
nblocks arg is retained in the launcher signature but ignored, for call-site compatibility with other levels.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each thread loads 4 CONTIGUOUS doubles as a double4, letting the compiler emit global_load_dwordx4 (128-bit = 2 doubles/lane on gfx90a). With 4 doubles per thread that becomes 2 dwordx4 instructions, down from 4 dwordx2 in the prior (stride-1024) layout.
The per-thread HBM stride changes from "stride 1024 across 4 iterations" to "stride 1 within a single 32-byte vector load"; per-wave cache-line footprint is identical (4 cache lines per 16-lane group), so HBM bandwidth is unchanged -- the win is in instruction count.
gfx90a has no dwordx8 for global loads, so this is as wide as the ISA allows for a single instruction.
No bank conflicts introduced; BK=17 pad blocks ds_write_b128 (16-byte alignment required at odd-k_start) so LDS stores stay at ds_write_b64.
Perf @ K=16:
N=512   ~9130 -> ~9310 GF/s
N=2048  ~8660 -> ~8880 GF/s
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
B is K²=256 doubles; with 1024 threads and stride-1 loads, the old code had 256 threads issuing one global_load_dwordx2 each (8 B/lane) inside a single-iteration loop. Replace with 64 threads issuing one global_load_dwordx4 each (16 B/lane), and one ds_write_b128 instead of ds_write_b64 for the LDS staging.
4x fewer load instructions, same HBM traffic (still 16 cache lines of B = 2 KB, coalesced as before). B_lds is 16-byte aligned (offset K*WB*8 = 34816), so double4 LDS writes are safe.
Perf @ K=16 N=2048: ~8880 -> ~9340 GF/s (~5%). Smaller win than the distribute double4 change because B cache is a one-shot kernel-entry load; the gain here is reducing issue-slot pressure during warmup.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Document the full optimization arc at K=16 on MI250X (720 -> 9340 GF/s, 86% of HBM roofline), call out which optimizations paid off vs which didn't move wall-clock, and record the confirmed MFMA fragment layouts for GFX90A. Ends with a pointer to K=32 as the open follow-up.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
K=32's working set (K³ = 256 KB) blows LDS (64 KB) by 4x; the K=16 algorithm does not survive as-is. Document the budget analysis, three candidate approaches (half-slab / streamed / atomic-add), and recommend starting from the simplest (atomic-add 4-blocks-per-tensor) before iterating toward streaming.
Captures the constraints and open questions so the next session on K=32 starts from concrete numbers, not a blank page.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
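The budget numbers quoted above are plain arithmetic on K³ doubles against a 64 KB LDS. A minimal sketch, with hypothetical helper names and ignoring padding and the B_lds cache:

```cpp
#include <cassert>

constexpr long kLdsBytes = 64L * 1024;  // per-workgroup LDS limit quoted above

// Bytes needed to hold one unpadded K^3 tensor of doubles.
long cube_bytes(int K) { return 8L * K * K * K; }

bool fits_in_lds(int K) { return cube_bytes(K) <= kLdsBytes; }
```

K=16 needs 32 KB and fits; K=32 needs 256 KB, four times the limit, which is why the single-buffer scheme cannot carry over unchanged.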
Same algorithm as L7 but with rocwmma::mma_sync replacing the
manual v_mfma_f64_16x16x4f64 calls. rocWMMA hides the MFMA
fragment layouts, and load_matrix_sync with col_major gives us
the B^T view for pass 1 without any hand-written swizzle.
Differences from L7:
* buf uses L7's old per-wave layout (buf[s*WB + i*BK + j])
because rocWMMA wants stride-1 along one matrix axis; the
canonical (i, j, k) layout doesn't fit.
* No register fusion between pass 1 and pass 2 (fragments are
opaque); pass 1 output round-trips through LDS.
* Extra __syncthreads before pass 2 and before corner-turn
reads to cover the fragment stores.
Correctness: PASS at N ∈ {1, 8, 64, 512, 2048} with the same
max_rel_err ~1e-15 as L7.
Perf @ K=16 N=2048: ~6250 GF/s (vs L7's ~9100). The gap is the
LDS round-trips we lose by not having hand control of fragments.
Kept for two reasons: (1) fair comparison point "vendor MFMA
wrapper vs. hand-tuned"; (2) cleaner starting point for K values
that rocWMMA can tile automatically.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same register-fusion trick as L7 but accessed through rocWMMA's
fragment internals. Swap pass 1's operands to compute blk^T · B;
its output is D[j][a] = temp1[a][j], so thread t's acc1.x[e] =
temp1[t%16][(t/16)+4e] = temp1[t%16][p*4+t/16] -- the exact value
pass 2's matrix_a fragment wants at iter p=e.
Pass 2's inner loop now populates FragA directly:
a_frag_p.x[0] = acc1.x[p];
and passes it to mma_sync. No LDS round-trip, no __syncthreads.
Also collapsed the pass-2 store and corner-turn stash read: acc2.x[e]
is already the value the cross-wave write wants (same argument as
L7's earlier commit).
Perf @ K=16 N=2048: ~6250 -> ~8500 GF/s (+36%). Gap to hand-tuned
L7 narrows from ~31% to ~9%. Correctness preserved.
Fragment internals (.x[i] indexing) work because rocWMMA's
Native_vec_ resolves to a clang ext_vector_type.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
K=20 and K=24 are the scientifically relevant cases that are not multiples of 16; K=32 is not (it was a suggestion for MFMA-divisibility, not a science need). Neither K=20 nor K=24 is divisible by 16 as the 16×16×4 MFMA requires, and padding to 32 blows the LDS budget (see K=32 section).
Document the 4×4×4 MFMA approach, the tile count blowup (125-216 MFMA calls per pass vs 4 for K=16), the open question about the 4-block lane layout (first probe showed 4 identical outputs -- need ISA clarity on CBSZ/ABID/BLGP control bits), and an ordered set of next steps.
Recommendation: measure L3 at K=20/24 first as a baseline before committing to a custom kernel.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
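The tile-count figures follow from tiling one K x K by K x K product with a fixed MFMA shape. The sketch below (hypothetical helper) counts tile products as (K/m)(K/n)(K/k) and returns -1 when the shape does not divide K; it deliberately ignores that a single 4x4x4 instruction may carry more than one tile.

```cpp
#include <cassert>

// Number of m x n x k tile products needed for one K x K by K x K GEMM,
// or -1 if the tile shape does not divide K evenly.
long mfma_tile_products(int K, int m, int n, int k) {
    if (K % m != 0 || K % n != 0 || K % k != 0) return -1;
    return static_cast<long>(K / m) * (K / n) * (K / k);
}
```

This reproduces the 4 calls at K=16 with 16x16x4 tiles and the 125 and 216 figures for K=20 and K=24 with 4x4x4 tiles.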
Three diagnostic kernels that pin down the fragment layout of v_mfma_f64_4x4x4f64 on gfx90a. Result: the instruction computes 4 INDEPENDENT 4x4x4 GEMMs per issue (one per group g = (S/4)%4), NOT the "4-broadcast" variant the original design note hypothesized.
Per group g, with S = 16*alpha + 4*g + beta:
A_g[m=beta][k=alpha] at lane S
B_g[k=alpha][n=beta] at lane S
D_g[m=alpha][n=beta] at lane S
The probes build in isolation (hipcc ... -o build/test_mfma_4x4x4_layout*).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
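The "four independent GEMMs" reading can be expressed as a host-side model. The sketch below is not one of the diagnostic kernels; it is a C++ emulation with hypothetical names that packs four different 4x4 problems into 64 lanes using the mapping above, evaluates the model, and checks each group against its own ordinary 4x4 product.

```cpp
#include <array>
#include <cassert>

using Lane64 = std::array<double, 64>;
using Mat4 = std::array<std::array<double, 4>, 4>;
using Mat4x4Group = std::array<Mat4, 4>;  // one 4x4 matrix per group g

// Model: lane S = 16*alpha + 4*g + beta receives D_g[alpha][beta].
Lane64 mfma_4x4x4_model(const Lane64& a, const Lane64& b) {
    Lane64 d{};
    for (int S = 0; S < 64; ++S) {
        const int alpha = S / 16, g = (S / 4) % 4, beta = S % 4;
        for (int k = 0; k < 4; ++k) {
            // A_g[m=alpha][k] sits in the lane with (alpha'=k, beta'=alpha);
            // B_g[k][n=beta]  sits in the lane with (alpha'=k, beta'=beta).
            d[S] += a[16 * k + 4 * g + alpha] * b[16 * k + 4 * g + beta];
        }
    }
    return d;
}

bool four_independent_gemms_match() {
    Mat4x4Group A{}, B{};
    for (int g = 0; g < 4; ++g)
        for (int i = 0; i < 4; ++i)
            for (int j = 0; j < 4; ++j) {
                A[g][i][j] = (i + 2 * j + 3 * g) % 5 - 2;  // differs per group
                B[g][i][j] = (2 * i + j + g) % 7 - 3;
            }
    Lane64 a{}, b{};
    for (int S = 0; S < 64; ++S) {
        const int alpha = S / 16, g = (S / 4) % 4, beta = S % 4;
        a[S] = A[g][beta][alpha];  // A_g[m=beta][k=alpha]
        b[S] = B[g][alpha][beta];  // B_g[k=alpha][n=beta]
    }
    const Lane64 d = mfma_4x4x4_model(a, b);
    for (int S = 0; S < 64; ++S) {
        const int alpha = S / 16, g = (S / 4) % 4, beta = S % 4;
        double ref = 0.0;
        for (int k = 0; k < 4; ++k) ref += A[g][alpha][k] * B[g][k][beta];
        if (d[S] != ref) return false;  // small integers, so sums are exact
    }
    return true;
}
```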
New file transform_blocked_k20.h contains four variants
(GF/s measured on a single MI250X GCD at N=2048):
transform_kernel_blocked_k20_scalar (1083 GF/s, correctness baseline)
transform_kernel_blocked_k20_mfma (2787 GF/s, pure v_mfma_f64_4x4x4f64)
transform_kernel_blocked_k20_split (1548 GF/s, 2 blocks/tensor,
padded LDS + atomicAdd -- kept as
reference; regressed vs. hybrid
because atomic + memset traffic
dominated)
transform_kernel_blocked_k20_hybrid (6900 GF/s, DEFAULT):
- 16x16x4 MFMA on the main 16x16 sub-tile
- 4x4x4 MFMA on 16x4 right strip, 4x16 bottom strip, 4x4 corner
- pass 1 uses swapped operands so its accumulator layout matches
pass 2's A-frag directly -- main + right strip are fused in
registers across the sync (K=16 kernel uses the same trick)
- double4 distribute for the HBM read phase
Geometry: 10 waves x 2 slabs per wave (20 k' slabs, 640 threads/block).
LDS: 8000 doubles (unpadded K^3) = 64,000 bytes (62.5 KiB), no B_lds.
Dispatch: submit_transform_bench_blocked routes K=20 to the hybrid
kernel; L3 also gets a K=20 case for baseline measurement.
Validated bit-exact (to FP noise) against the CPU reference.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
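The hybrid decomposition of a 20x20 output into one 16x16 main tile plus three edge regions can be sanity-checked on the host. The sketch below uses hypothetical names and only verifies coverage: every output element belongs to exactly one region and the region sizes add up to 400. It says nothing about how many MFMA instructions each region costs.

```cpp
#include <cassert>

enum class Region { Main16x16, RightStrip16x4, BottomStrip4x16, Corner4x4 };

// Classify an output element of the 20x20 result.
Region region_of(int row, int col) {
    const bool top = row < 16;
    const bool left = col < 16;
    if (top && left) return Region::Main16x16;       // handled by 16x16x4 MFMA
    if (top) return Region::RightStrip16x4;          // handled by 4x4x4 MFMA
    if (left) return Region::BottomStrip4x16;        // handled by 4x4x4 MFMA
    return Region::Corner4x4;                        // handled by 4x4x4 MFMA
}

int count_region(Region r) {
    int n = 0;
    for (int row = 0; row < 20; ++row)
        for (int col = 0; col < 20; ++col)
            if (region_of(row, col) == r) ++n;
    return n;
}
```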
Replace the earlier design-note hypothesis ("4-broadcast") with the
confirmed layout (4 independent 4x4 GEMMs) and add a measured perf
table for the K=20 variants, including what didn't work:
L3 (register block) 1967 GF/s baseline
L7 scalar (K=20) 1083 GF/s correctness only
L7 pure 4x4x4 2787 GF/s 1.42x L3
L7 hybrid (no fusion) 3480 GF/s 1.77x L3
L7 hybrid + fusion 6900 GF/s 3.51x L3 (DEFAULT)
Lesson captured: the 2-block-per-tensor split kernel dropped LDS
bank conflicts from 23.9-way to 0.2-way but ran 1.9x SLOWER --
atomicAdd on C + hipMemset pre-zero added 500+ MB of HBM traffic
that erased the LDS win. The much cheaper pass-1/2 register fusion
(same layout, same bank conflicts) nearly doubled throughput by
unblocking the real critical path. Counters showed what was
happening, not what was limiting.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
See #7