Skip to content

Blocked transform#5

Closed
ahurta92 wants to merge 23 commits into
devreal:mainfrom
ahurta92:blocked-transform
Closed

Blocked transform#5
ahurta92 wants to merge 23 commits into
devreal:mainfrom
ahurta92:blocked-transform

Conversation

@ahurta92

Copy link
Copy Markdown
Contributor

No description provided.

ahurta92 and others added 23 commits April 15, 2026 19:29
- Introduced L2 and L3 transform kernels with shared memory staging.
- Updated transform_bench to support level selection for benchmarking.
- Added validate_levels.hip for correctness testing across optimization levels.
- Modified .clangd for improved indexing and added HIP file parsing.
- Updated CMakeLists.txt to include new executables for level validation.
Replaces the incorrect cpu_transform3d_slabs with cpu_transform3d_blocked:
distributes A along the fastest axis, does two local GEMMs, a block-level
corner turn, a third local GEMM, and a per-block transpose on store.
Matches cpu_transform3d bit-exactly.

Includes a K=2 trace (-debug 1) that dumps every stage for a counting
tensor with B=I, so only the shuffle pattern is visible. Useful as a
smoke test when porting to GPU.

BLOCKED_TRANSFORM.md documents the algorithm frame-by-frame for an
upcoming presentation; frames/ holds the diagram images.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pass 3 now right-multiplies by B instead of left-multiplying by B^T.
The corner turn stores arriving data as (b, k') instead of (k', b), so
pass 3's output lands directly in canonical (b, c) order — the in-block
transpose is absorbed into the corner turn at no extra cost (same
loads/stores, different destination address).

Side effect: passes 2 and 3 now share the same GEMM orientation
(B on the right); only pass 1 uses the other form. One fewer MFMA
microkernel variant to maintain on the GPU port.

README frame 5 collapses from two stages to one; the bookkeeping table
loses its last row.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Implements the corner-turn algorithm on MI250X at K=16:
  1 wavefront = 1 block (K² elements);
  K=16 waves = 1 thread block = 1 tensor;
  passes 1/2/3 local; one all-to-all corner turn through LDS.

LDS layout at K=16:
  buf[K³]     32 KB  in-place ping-pong via register stash
  B_lds[K²]    2 KB  cached B matrix

Each pass computes its output into a K²/64-element per-lane register
stash, __syncthreads, then writes back to buf -- letting the single
buffer serve as both input and output.  Corner turn uses the same
stash pattern across wave boundaries.  Pass 3 fuses with the HBM
store, skipping one LDS round-trip.

Registered as level 7 in validate_levels and transformbench.
Correctness: max_rel_err < 2e-15 at N ∈ {1, 8, 64, 512, 2048}.
Perf at K=16: ~2270 GF/s steady-state, up from a 64 KB double-buffer
first cut at ~720 GF/s -- bigger jump than expected, partly because
dropping from the 64 KB LDS ceiling gave the compiler scheduling
headroom back.

Known remaining inefficiencies (to address in later passes):
  - uncoalesced HBM reads in distribute (stride K per lane)
  - LDS bank conflicts in corner-turn writes (stride 16 = 32 banks)
  - no MFMA yet; scalar FP64 ops throughout

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The scalar passes 1/2/3 are replaced with v_mfma_f64_16x16x4f64 (4
MFMAs per pass for K=16).  Correctness passes at N ∈ {1,8,64,512,2048}
with max_rel_err < 2e-15.

Perf is unchanged at ~2270 GF/s -- the kernel is not compute-bound at
this occupancy; LDS bandwidth, the corner-turn bank conflicts, and
the serialized MFMA dependency chain are the limiters.  Remaining
optimization work will target those.

Includes test_mfma_layout.hip, a small diagnostic that writes the
MFMA accumulator to host-readable memory so the per-lane output
layout can be verified empirically.  Using it, the confirmed layout
on GFX90A is:

    A_frag (M=16, K=4):  lane t -> A[t%16][t/16]         (col-major)
    B_frag (K=4, N=16):  lane t -> B[t/16][t%16]         (row-major)
    D      (M=16, N=16): lane t acc[e] -> D[(t/16)+4*e][t%16]

The comment block in the old mxm_level4.h claimed A was row-major
(A[t/4][t%4]) and D was 4-consecutive-rows per lane, which is wrong
and caused an initial ~3.7 rel_err before the probe pinned down the
actual hardware layout.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Change the within-wave row stride in buf from K to K+1 so that
stride-K-across-lanes access patterns (pass 2/3 A_frag loads, and
the corner-turn cross-wave write) no longer land on the same 128-byte
bank row on GFX90A.  LDS grows 34 KB -> 36 KB; still one block per CU.

Perf at K=16 jumps from ~2270 GF/s to ~3320 GF/s (1.46x) -- bigger
than expected because three LDS phases were conflicting, not just
the corner turn.  Correctness still PASS at N in {1,8,64,512,2048}.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each pass only touches its owning wave's buf region, so the barriers
between pass-1 compute/store, pass-1/pass-2, and pass-2/corner-turn
were enforcing ordering the wave lockstep + compiler waitcnt already
guarantees.  Only the two corner-turn barriers (stash-done-before-
cross-wave-writes, writes-done-before-pass-3) and the initial B_lds
barrier are cross-wave.

8 __syncthreads() per cube -> 3.  SQ_WAIT_INST_LDS drops 61%, GRBM
active cycles drop 9%, wall-clock perf improves ~4% (3320 -> 3460
GF/s at K=16).  The gap between counter delta and wall-clock delta
suggests the hardware was already hiding most of the barrier latency
through its 4-waves-per-SIMD scheduler.

Also add counters.txt (rocprofv2 counter set used to diagnose this).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
After pass 2, each lane t's acc[e] is at position
(row = t/16 + 4e, col = t%16) in pass-2 output.  The corner-turn
stash was reading buf at (a = (t+64e)/K, b = (t+64e)%K) in wave s's
region -- which for K=16 simplifies to (a = t/16 + 4e, b = t%16).
Exactly the same positions.  So the stash read was just reading
back what the pass-2 store had just put there in the same lane.

Skip both the pass-2 store and the stash read; write acc[e]
directly into the cross-wave destination.

LDS op savings: 4 writes + 4 reads per lane per tensor.  Wall-clock
change: ~1% (the eliminated ops were hidden in the shadow of MFMA
latency, so we improved the theoretical critical path but not the
observed one).  Still worth having -- fewer LDS unit ops leave more
room for whatever IS on the critical path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pass 1 now computes blk^T · B instead of B^T · blk.  Swapping the
operands gives MFMA output with (row=j, col=a) indexing instead of
(row=a, col=j), which coincidentally is exactly the layout pass 2's
A_frag wants at iter p.  Net effect: lane t's pass-1 acc[p] IS
pass 2's a_val at iter p -- zero LDS, zero cross-lane movement
between passes 1 and 2.

LDS ops drop another 21% (SQ_INSTS_LDS 231k -> 182k).  Wall-clock
is flat, which is the clearest signal yet that LDS is not the
critical path of this kernel.

Cumulative win since the bank-conflict pad: -40% LDS instructions,
~0% wall-clock -- the LDS pipeline was fully hidden in the shadow of
MFMA / HBM latency the whole time.  Further LDS reductions won't
help perf; the limiter lives elsewhere (likely HBM read latency or
a scheduling artifact not captured by the counters we're collecting).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
All 1024 threads cooperatively load the K^3 tensor from HBM with
stride-1 reads (coalesced, one cache line per 16 consecutive lanes),
placing elements into buf in canonical (i, j, k) layout with k
fastest.  Pass 1 read indexing updated to match.

The old per-wave distribute read HBM with stride K = 128 bytes --
each lane pulled a whole cache line for 1 useful double.  L2
absorbed the amplification via sharing across the block's 16 waves
(same lines, different k), so L2 hit rate read high (88%) and HBM
bandwidth looked underused (26% of peak) -- but the underlying
cache-line BW was saturated.

With coalesced reads:
  GRBM active cycles -44% (124k -> 70k)
  TCC_HIT -92% (1.92M -> 154k)  (fewer redundant lookups)
  HBM BW utilization 26% -> 47% of peak
  GF/s @ K=16 N=512  ~3475 -> ~9500  (2.7x)
  GF/s @ K=16 N=2048 ~3470 -> ~7400  (2.1x)

Costs:
  +1 __syncthreads (cross-wave LDS writes in the load)
  No added LDS footprint
  No added bank conflicts (canonical layout's j-stride is K+1 = 2 banks shift)

Post-corner-turn accesses (corner-turn write, pass-3 read) are
numerically identical to before because the canonical and per-wave
layouts use the same strides (K*(K+1), K+1, 1); only the MEANING of
each index changes (i,j,k vs s,i,j).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Upstream convention is that each kernel invocation launches one
block per tensor and the GPU scheduler handles distribution across
CUs.  Bringing L7 in line: grid = nfuncs, no internal cube loop,
no grid clamping via nblocks.

This gives the scheduler finer-grain parallelism.  With only 512
concurrent blocks (previous limit), all blocks moved through the
same pipeline phase together, which synchronized everyone on
whichever phase was slow.  More blocks = more overlap potential
between HBM-heavy and compute-heavy phases.

Perf:
  N=512   9500 -> 9130 GF/s   (slight regression, already saturated)
  N=2048  7400 -> 8660 GF/s   (+17%)

nblocks arg is retained in the launcher signature but ignored, for
call-site compatibility with other levels.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Each thread loads 4 CONTIGUOUS doubles as a double4, letting the
compiler emit global_load_dwordx4 (128-bit = 2 doubles/lane on
gfx90a).  With 4 doubles per thread that becomes 2 dwordx4
instructions, down from 4 dwordx2 in the prior (stride-1024)
layout.

The per-thread HBM stride changes from "stride 1024 across 4
iterations" to "stride 1 within a single 32-byte vector load";
per-wave cache-line footprint is identical (4 cache lines per
16-lane group), so HBM bandwidth is unchanged -- the win is in
instruction count.

gfx90a has no dwordx8 for global loads, so this is as wide as the
ISA allows for a single instruction.  No bank conflicts introduced;
BK=17 pad blocks ds_write_b128 (16-byte alignment required at
odd-k_start) so LDS stores stay at ds_write_b64.

Perf @ K=16:
  N=512   ~9130 -> ~9310 GF/s
  N=2048  ~8660 -> ~8880 GF/s

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
B is K²=256 doubles; with 1024 threads and stride-1 loads, the old
code had 256 threads issuing one global_load_dwordx2 each (8 B/lane)
inside a single-iteration loop.  Replace with 64 threads issuing
one global_load_dwordx4 each (16 B/lane), and one ds_write_b128
instead of ds_write_b64 for the LDS staging.

4x fewer load instructions, same HBM traffic (still 16 cache lines
of B = 2 KB, coalesced as before).  B_lds is 16-byte aligned (offset
K*WB*8 = 34816), so double4 LDS writes are safe.

Perf @ K=16 N=2048: ~8880 -> ~9340 GF/s (~5%).  Smaller win than the
distribute double4 change because B cache is a one-shot kernel-entry
load; the gain here is reducing issue-slot pressure during warmup.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Document the full optimization arc at K=16 on MI250X (720 -> 9340
GF/s, 86% of HBM roofline), call out which optimizations paid off
vs which didn't move wall-clock, and record the confirmed MFMA
fragment layouts for GFX90A.  Ends with a pointer to K=32 as the
open follow-up.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
K=32's working set (K³ = 256 KB) blows LDS (64 KB) by 4x; the K=16
algorithm does not survive as-is.  Document the budget analysis,
three candidate approaches (half-slab / streamed / atomic-add), and
recommend starting from the simplest (atomic-add 4-blocks-per-tensor)
before iterating toward streaming.

Captures the constraints and open questions so the next session on
K=32 starts from concrete numbers, not a blank page.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same algorithm as L7 but with rocwmma::mma_sync replacing the
manual v_mfma_f64_16x16x4f64 calls.  rocWMMA hides the MFMA
fragment layouts, and load_matrix_sync with col_major gives us
the B^T view for pass 1 without any hand-written swizzle.

Differences from L7:
  * buf uses L7's old per-wave layout (buf[s*WB + i*BK + j])
    because rocWMMA wants stride-1 along one matrix axis; the
    canonical (i, j, k) layout doesn't fit.
  * No register fusion between pass 1 and pass 2 (fragments are
    opaque); pass 1 output round-trips through LDS.
  * Extra __syncthreads before pass 2 and before corner-turn
    reads to cover the fragment stores.

Correctness: PASS at N ∈ {1, 8, 64, 512, 2048} with the same
max_rel_err ~1e-15 as L7.
Perf @ K=16 N=2048: ~6250 GF/s (vs L7's ~9100).  The gap is the
LDS round-trips we lose by not having hand control of fragments.

Kept for two reasons: (1) fair comparison point "vendor MFMA
wrapper vs. hand-tuned"; (2) cleaner starting point for K values
that rocWMMA can tile automatically.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Same register-fusion trick as L7 but accessed through rocWMMA's
fragment internals.  Swap pass 1's operands to compute blk^T · B;
its output is D[j][a] = temp1[a][j], so thread t's acc1.x[e] =
temp1[t%16][(t/16)+4e] = temp1[t%16][p*4+t/16] -- the exact value
pass 2's matrix_a fragment wants at iter p=e.

Pass 2's inner loop now populates FragA directly:
    a_frag_p.x[0] = acc1.x[p];
and passes it to mma_sync.  No LDS round-trip, no __syncthreads.

Also collapsed the pass-2 store and corner-turn stash read: acc2.x[e]
is already the value the cross-wave write wants (same argument as
L7's earlier commit).

Perf @ K=16 N=2048: ~6250 -> ~8500 GF/s (+36%).  Gap to hand-tuned
L7 narrows from ~31% to ~9%.  Correctness preserved.

Fragment internals (.x[i] indexing) work because rocWMMA's
Native_vec_ resolves to a clang ext_vector_type.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
K=20 and K=24 are the scientifically relevant odd-K cases; K=32 is
not (it was a suggestion for MFMA-divisibility, not a science need).
Neither K=20 nor K=24 divides by 16 for 16×16×4 MFMA, and padding to
32 blows the LDS budget (see K=32 section).

Document the 4×4×4 MFMA approach, the tile count blowup (125-216 MFMA
calls per pass vs 4 for K=16), the open question about the 4-block
lane layout (first probe showed 4 identical outputs -- need ISA
clarity on CBSZ/ABID/BLGP control bits), and an ordered set of next
steps.  Recommendation: measure L3 at K=20/24 first as a baseline
before committing to a custom kernel.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three diagnostic kernels that pin down the fragment layout of
v_mfma_f64_4x4x4f64 on gfx90a.  Result: the instruction computes
4 INDEPENDENT 4x4x4 GEMMs per issue (one per group g = (S/4)%4),
NOT the "4-broadcast" variant the original design note hypothesized.

Per group g, with S = 16*alpha + 4*g + beta:
  A_g[m=beta][k=alpha]  at lane S
  B_g[k=alpha][n=beta]  at lane S
  D_g[m=alpha][n=beta]  at lane S

The probes build in isolation (hipcc ... -o build/test_mfma_4x4x4_layout*).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
New file transform_blocked_k20.h contains four variants, listed in
order of performance on MI250X single GCD at N=2048:

  transform_kernel_blocked_k20_scalar  (1083 GF/s, correctness baseline)
  transform_kernel_blocked_k20_mfma    (2787 GF/s, pure v_mfma_f64_4x4x4f64)
  transform_kernel_blocked_k20_split   (1548 GF/s, 2 blocks/tensor,
                                        padded LDS + atomicAdd -- kept as
                                        reference; regressed vs. hybrid
                                        because atomic + memset traffic
                                        dominated)
  transform_kernel_blocked_k20_hybrid  (6900 GF/s, DEFAULT):
      - 16x16x4 MFMA on the main 16x16 sub-tile
      - 4x4x4 MFMA on 16x4 right strip, 4x16 bottom strip, 4x4 corner
      - pass 1 uses swapped operands so its accumulator layout matches
        pass 2's A-frag directly -- main + right strip are fused in
        registers across the sync (K=16 kernel uses the same trick)
      - double4 distribute for the HBM read phase

Geometry: 10 waves x 2 slabs per wave (20 k' slabs, 640 threads/block).
LDS: 8000 doubles (unpadded K^3) = 64 KB, no B_lds.

Dispatch: submit_transform_bench_blocked routes K=20 to the hybrid
kernel; L3 also gets a K=20 case for baseline measurement.

Validated bit-exact (to FP noise) against the CPU reference.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace the earlier design-note hypothesis ("4-broadcast") with the
confirmed layout (4 independent 4x4 GEMMs) and add a measured perf
table for the K=20 variants, including what didn't work:

  L3 (register block)     1967 GF/s  baseline
  L7 scalar (K=20)        1083 GF/s  correctness only
  L7 pure 4x4x4           2787 GF/s  1.42x L3
  L7 hybrid (no fusion)   3480 GF/s  1.77x L3
  L7 hybrid + fusion      6900 GF/s  3.51x L3  (DEFAULT)

Lesson captured: the 2-block-per-tensor split kernel dropped LDS
bank conflicts from 23.9-way to 0.2-way but ran 1.9x SLOWER --
atomicAdd on C + hipMemset pre-zero added 500+ MB of HBM traffic
that erased the LDS win.  The much cheaper pass-1/2 register fusion
(same layout, same bank conflicts) nearly doubled throughput by
unblocking the real critical path.  Counters showed what was
happening, not what was limiting.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@ahurta92

ahurta92 commented May 4, 2026

Copy link
Copy Markdown
Contributor Author

See #7

@ahurta92 ahurta92 closed this May 4, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant