diff --git a/.gitignore b/.gitignore index 7f3e00b..8ad8267 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,10 @@ clangd* compile_commands.json rocroof/ +# Python virtualenvs (rocprofiler tooling) +.rocprof-venv/ +rocprof-compute-local/ + +# Profiling data +prof/ +prof_*/ diff --git a/BLOCKED_TRANSFORM.md b/BLOCKED_TRANSFORM.md new file mode 100644 index 0000000..c4dbf21 --- /dev/null +++ b/BLOCKED_TRANSFORM.md @@ -0,0 +1,508 @@ +# Blocked 3D Transform — Corner-Turn Algorithm + +A data-movement-conscious restructuring of the MRA 3D transform, designed so +each of K thread-block-equivalents ("wavefronts" on AMD) owns one K×K slab of +the tensor throughout the computation. Passes 1 and 2 are fully local per +block; pass 3 requires a single all-to-all "corner turn" that transposes the +block-level ownership from `k'` to `a`. Pass 3 right-multiplies by B so the +output lands in canonical (b, c) order with no post-transpose. + +Reference implementations +- CPU (C++): `validate.hip` → `cpu_transform3d_blocked` +- NumPy: `mra_python/algorithms.py` → `transform_nd_blocked` + +Both match `cpu_transform3d` / `transform_nd` bit-exactly (CPU) or to +floating-point noise (NumPy). + +--- + +## The canonical operation + +We compute the 3D transform + +``` +result[a, b, c] = Σ_{i, j, k} A[i, j, k] · B[i, a] · B[j, b] · B[k, c] +``` + +via three sequential contractions, one per tensor axis. Standard formulations +apply this as three K²×K GEMMs over the whole tensor, reloading the input +through HBM on every pass. The blocked algorithm keeps each K×K slab +resident in registers/LDS of a single wavefront for all three passes, paying +one inter-wave exchange instead of three HBM round-trips. + +--- + +## Visual language + +A consistent set of conventions used across every frame: + +- **Axis colors.** `i'/a` red, `j'/b` green, `k'/c` blue. +- **Blocks are K×K squares.** Row-bar and column-bar outside each square + name the axis currently playing that role. 
+- **Solid box = data, dashed box = operation.** +- **Arrows carry data.** + +--- + +## Frame 1 — canonical operation + +> ![Frame 1](frames/frame1_canonical.png) + +A K×K×K input cube, a K×K matrix `B`, and a K×K×K output cube, with three +curved arrows labeled "contract `i'`", "contract `j'`", "contract `k'`". +Provides grounding; this is what all later frames compute. + +--- + +## Frame 2 — distribute (slice along `k'`) + +> ![Frame 2](frames/frame2_distribute.png) + +The input cube is sliced along the fastest axis into K independent K×K +squares. Block `s` owns the slice `A[:, :, k'=s]`, with row = `i'` and +col = `j'`. + +This establishes the row-of-K-squares motif used in every subsequent frame. + +--- + +## Frame 3 — passes 1 and 2 (local GEMMs) + +> ![Frame 3](frames/frame3_local_gemms.png) + +Two stages stacked vertically on one frame: + +``` +row of K squares ──(× B^T on left)──► row of K squares [Pass 1] + rows: i', cols: j' rows: a, cols: j' + +row of K squares ──(× B on right)──► row of K squares [Pass 2] + rows: a, cols: j' rows: a, cols: b +``` + +Both passes are **local** — no wavefront touches another wavefront's data. +A small "✓ local" badge on each block emphasizes this. + +--- + +## Frame 4 — the corner turn (the money slide) + +> ![Frame 4](frames/frame4_corner_turn.png) + +The critical data-movement step. Shown as two rows of K blocks with colored +rows flowing between them. + +**Before** (top): K blocks, each with rows colored red/green/blue/... +(= `a = 0, 1, 2, ...`). Every block has the same row-color pattern because +`a` is the within-block row index. + +**After** (bottom): K blocks, each entirely one color. Block 0 all red, +block 1 all green, block 2 all blue. `a` is now the **block** index. +**Each arriving row is stored as a column**, so within each block the row +index is `b` and the column index is `k'`. This orientation sets pass 3 +up to right-multiply by B and land in canonical order directly. 
+ +The operation is a **block-level transpose of a K×K super-matrix** whose +cells are K-vectors indexed by `b`, with an in-block transpose rolled into +the receive address (free — same data movement, different destination cell). + +Callout on the frame: *"block-level transpose: block index ↔ within-block +row."* This is the single sentence the audience should leave with. + +The main visual shows only the flow of one color (e.g. red rows, all +destined for block `a=0`) to keep the arrow count manageable; an inset +shows the full K=2 picture with all four arrows. + +--- + +## Frame 5 — pass 3 (no un-shuffle needed) + +> ![Frame 5](frames/frame5_pass3.png) + +``` +row of K squares ──(× B on right)──► canonical result + rows: b, cols: k' rows: b, cols: c +``` + +Points to emphasize: + +- Block index throughout this frame is `a` (inherited from the corner turn). +- **Right-multiply by B contracts `k'` → `c` in the column slot**, so the + output is already in canonical (b, c) order. No per-block transpose on + store — it was absorbed into the corner turn. +- Passes 2 and 3 now share the same GEMM orientation (B on the right); + only pass 1 is the odd one out. One fewer MFMA microkernel variant to + maintain on GPU. + +--- + +## Frame 6 — the bookkeeping table (takeaway slide) + +> ![Frame 6](frames/frame6_table.png) + +| stage | block idx | within-block row | within-block col | +|--------------------|-----------|------------------|------------------| +| distribute | k' | i' | j' | +| after pass 1 | k' | **a** | j' | +| after pass 2 | k' | a | **b** | +| after corner turn | **a** | **b** | **k'** | +| after pass 3 | a | b | **c** | + +Bold cells mark which slot changed on each row. 
Two annotations: + +- Rows 2–3 (passes 1 and 2) and row 5 (pass 3): *"local GEMM: one slot updates."* +- Row 4: *"corner turn: block slot ↔ row slot, and row ↔ col inside the block."* + +--- + +## Cost summary + +| item | cost per transform | +|--------------------------------------|---------------------------------------| +| compute (mathematical minimum) | 3 · 2 · K⁴ FLOPs | +| HBM reads of A | K³ doubles (once, at distribute) | +| HBM writes of result | K³ doubles (once, at pass-3 store) | +| HBM reads of B | K² doubles (once, cached in LDS) | +| inter-wave exchange (corner turn) | K³ doubles (via LDS, one pass only) | + +Comparison against the naive 3-GEMM implementation is the whole point: HBM +traffic for A drops from 3·K³ to 1·K³, and the compute path is exactly the +same three K²×K contractions — just re-organized. + +--- + +## K=2 walkthrough (self-checking) + +Run `./validate -debug 1` in the build directory. Uses the counting tensor +`A[i, j, k] = 100i + 10j + k` and `B = I`, so every GEMM is a no-op and only +the corner turn produces visible motion. Each stage dumps all K blocks; the +final line is self-checking (`MATCH` / `DIFFER`). + +Particularly useful as a "does this still work" smoke test when porting to +GPU. + +--- + +## GPU implementation — measured perf on MI250X (K=16, one GCD) + +Wired up as level 7. One block per tensor; grid = `nfuncs`. +Thread block = 64 × K threads = 1024 at K=16 (one wave per K×K slab). + +Step-by-step results through the optimization arc (N=2048, FP64): + +| version | GF/s | vs. 
prev | +|-----------------------------------------------|-------:|---------:| +| First cut (double-buffer LDS, B from HBM) | 720 | — | +| Single-buffer LDS + B cached in LDS | 2270 | 3.15× | +| + LDS row pad (kill 16-way bank conflict) | 3320 | 1.46× | +| + drop same-wave-only barriers (8 → 3) | 3460 | 1.04× | +| + fuse pass-2 store with corner-turn write | 3475 | 1.00× | +| + swap pass-1 operands → pass-2 reads acc | 3475 | 1.00× | +| + coalesced cooperative distribute | 7400 | 2.13× | +| + one block per tensor (grid = nfuncs) | 8880 | 1.20× | +| + `double4` loads for distribute + B cache | **9340** | 1.05× | + +Final result: **9340 GF/s ≈ 20 % of scalar FP64 peak ≈ 10 % of MFMA peak**. +More importantly, **86 % of the HBM-bandwidth roofline** at the kernel's +arithmetic intensity (AI = 6.14 FLOPs/byte). See `roofline.py` / `roofline.png`. + +### What we learned (not all optimizations paid off) + +A few commits are worth flagging because their lesson matters more than +their number: + +- **LDS bank-conflict pad — big win (1.46×)**. Classic 16-way conflict pattern + in the corner turn. Simple row-stride pad from K to K+1. +- **LDS instruction cuts gave ~nothing**. Dropping barriers (-60 % LDS wait), + fusing pass-2 into corner turn (-8 LDS ops/lane), register-fused pass-1→ + pass-2 (-8 more LDS ops/lane) all looked big in counters (40 % fewer LDS + insts total) but moved wall clock by <5 %. Interpretation: LDS was never + on the critical path in this kernel; those ops were fully hidden in the + shadow of MFMA/HBM latency. +- **Coalesced distribute — biggest single win (2.1× at N=2048)**. The prior + per-wave stride-K reads were being served as individual cache lines by the + hardware coalescer — 16× amplification on L2 request count. `rocprof` showed + HBM BW "only 26 % of peak" but actually the cache-line layer was saturated. + Counter-intuitive: the HBM BW counter was misleading. +- **Widening the pass-3 store via LDS staging — regression (-17 %)**. 
The 4× + narrow stores were already coalescing into 4 cache-line writes at the + hardware level; adding a barrier + LDS round-trip lost more than the + saved instructions gained. +- **MFMA swap trick (pass 1) — enables register fusion but no perf gain**. + Swapping operands in pass 1 makes its output layout match pass 2's input + exactly, so pass 2 reads acc directly. Theoretically saves LDS traffic; + wall-clock impact hidden by memory bubbles (see point 2 above). + +### Compiler / ISA notes (GFX90A) + +- MFMA `v_mfma_f64_16x16x4f64` layouts confirmed empirically via + `test_mfma_layout.hip` (the upstream L4 comment had them wrong): + - A_frag (16×4): lane t → A[t%16][t/16] + - B_frag (4×16): lane t → B[t/16][t%16] + - D output (16×16): lane t acc[e] → D[(t/16) + 4e][t%16] +- The widest global load/store on gfx90a is `global_{load,store}_dwordx4` + (128-bit, 2 doubles per lane). `double4` in code compiles to a pair + of these. +- LDS bank period: 32 banks × 4 bytes = 128 bytes. Stride-K doubles + (K = 16) hits the same bank row → 16-way conflict; stride-(K+1) shifts + by 2 banks → conflict-free. + +--- + +## K=20 / K=24 — use v_mfma_f64_4x4x4f64 (design note) + +K=20 and K=24 are the scientifically relevant odd-size cases (K=32 turns +out not to be; it was a suggestion just because the 16-divisibility +worked nicely with MFMA 16×16×4). Neither K=20 nor K=24 divides by 16, +so they can't use `v_mfma_f64_16x16x4f64` directly without padding, and +padding to K=32 blows LDS. + +The path forward is to use the smaller **`v_mfma_f64_4x4x4f64`** MFMA +variant — same hardware matrix core, 4×4 output tile instead of 16×16. +This avoids padding entirely because 4 divides both 20 and 24 cleanly. 
+ +### Tile-count cost + +Per pass, per wave: + +| | 16×16×4 path (K=16) | 4×4×4 path (K=20) | 4×4×4 path (K=24) | +|-------|---------------------|--------------------|--------------------| +| output tiles per K slice | 1 | 25 = 5×5 | 36 = 6×6 | +| K slices needed | 4 | 5 | 6 | +| MFMA calls per pass | 4 | 125 | 216 | + +The 4×4×4 approach uses ~30× more MFMA calls per pass than the K=16 +kernel. On gfx90a, matrix-core throughput for all f64 MFMA variants +is the same per-cycle-of-matrix-core, so more-but-smaller calls is a +wash on raw compute — but the **issue rate** may bottleneck at low +occupancy. + +### 4×4×4 fragment layout — CONFIRMED via probe + +`v_mfma_f64_4x4x4f64` on gfx90a returns **one double per lane**. +64 lanes × 1 = 64 output values = **four INDEPENDENT 4×4×4 GEMMs** +(NOT broadcast — earlier hypothesis disproved by `test_mfma_4x4x4_layout{,2,3}.hip`). + +With the lane decomposition `S = 16·α + 4·g + β` where `β = S%4`, +`g = (S/4)%4` and `α = S/16`, the **group index** is `g` (∈ {0,1,2,3}). +Each of the four groups computes its own 4×4×4 GEMM `D_g = A_g · B_g`: + +| fragment | lane S holds | +|-----------|------------------------------| +| A (4×4) | `A_g[m = β ][k = α]` | +| B (4×4) | `B_g[k = α ][n = β]` | +| D (4×4) | `D_g[m = α ][n = β]` | + +So A is column-major across 16 lanes, B row-major, D row-major, and +`t/16`/`(t/4)%4` each play different roles for the input vs. output +fragments — A's column index is where D's row index comes from, and +vice versa. The four `g`-groups share nothing: `g=1` reads none of +`g=0`'s lanes. + +This is strictly better for our K=20 case than 4×-broadcast would have +been: we can pack 4 independent 4×4 output tiles into a single +instruction, if we can lay out A/B so the four groups see different +tiles. 
+ +### Algorithm shape (once the 4×4×4 layout is understood) + +Assuming we find a CBSZ/BLGP combination that gives 4 independent tiles: + +- Replace the single 16×16 `mma_sync` in each of the three passes + with a nested pair of tile loops: + ``` + for (row_tile = 0; row_tile < K/4; ++row_tile) + for (col_tile = 0; col_tile < K/4; ++col_tile) + for (k_slice = 0; k_slice < K/4; ++k_slice) + MFMA_4x4x4(...) + ``` + At K=20: 5×5×5 = 125 calls per pass per wave. +- The per-wave output is K² = 400 (K=20) or 576 (K=24) doubles. + 400/64 = 6.25 per lane — **not integer**. For K=20 this means + some lanes will have 6 elements, others 7. Awkward. May want + to pad to K=21 internally (6.56/lane)? Or group 2 MFMA calls + per pair of lanes (?). Ugly. + 576/64 = 9 per lane — clean for K=24. +- LDS: K³ × 8 = 20³×8 = 64 KB (K=20), 24³×8 = 110 KB (K=24). + K=20 just fits; K=24 still over. So K=24 needs the same + streaming story as K=32 (see section below). + +### Recommended next steps (in order) + +1. **Measure L3 at K=20 and K=24 first.** If L3 hits 50%+ of the HBM + roofline at these K values, a custom MFMA kernel may not be worth + the complexity. +2. **Pin down the 4×4×4 fragment layout** via a diagnostic kernel. +3. **Start with K=20 using 4×4×4** — clean LDS budget, even if awkward + lane distribution. K=24 can follow after streaming is figured out. +4. Consider a hybrid **16×16×4 main + 4×4×4 remainder** tiling (e.g. + K=20 = 16+4 in each dim gives one 16×16 main tile + strips + + corner). Complex but avoids most of the 4×4×4 instruction overhead. + +### K=20 status: implemented (MFMA via 4 independent tiles per call) + +`transform_blocked_k20.h` ships both a scalar correctness baseline and +a 4×4×4 MFMA path. 
Design choices and constraints: + +| knob | K=20 choice | +|-------------------------|-------------------------------------------------------| +| threads/block | 64 × 10 = 640 (20 waves would exceed the 1024 cap) | +| slabs per wave | 2 (wave w owns k' ∈ {w, w+10}) | +| LDS tensor buf | K³ = 8000 doubles = 64 KB, **unpadded** | +| B caching | read from HBM (no room for B_lds at 64 KB cap) | +| tile pattern | 25 output tiles of 4×4 per slab, 7 MFMA rounds × 5 k-slices | +| wasted groups | last round has 3 unused groups (fed zero) | + +Measured @ N=2048, FP64, MI250X single GCD: + +| variant | GF/s | vs. L3 | kernel µs | MFMA busy | LDS wait/inst | +|---------------------------------------|-----:|--------:|----------:|----------:|--------------:| +| L3 (register-block) | 1967 | — | — | — | — | +| L7 scalar (K=20) | 1083 | 0.55× | — | — | — | +| L7 pure 4×4×4 | 2787 | 1.42× | 753 | 12.2 % | 5.5 cyc | +| L7 hybrid 16×16×4 + 4×4×4 | 3480 | 1.77× | 591 | 11.1 % | 19.4 cyc | +| **L7 hybrid + fusion + wide dist** | **6900** | **3.51×** | **370** | **17.8 %** | **4.1 cyc** | + +The final configuration pairs three changes: + +1. **Hybrid tiling** — 16×16×4 MFMA for the main 16×16 sub-tile, 4×4×4 MFMA + for the 16×4 / 4×16 strips and the 4×4 corner. 43 % fewer MFMA issues + than the pure 4×4×4 path. +2. **Pass 1 → Pass 2 register fusion** (operand-swap trick from the K=16 + kernel). Pass 1 computes `D = temp1^T` so the main accumulator `p1_main[e]` + at lane t directly equals `temp1[t%16][(t/16)+4e]` — exactly what pass 2's + A-frag wants. Pass 1 right strip gets the same treatment and feeds pass 2's + 5th k-slice from register `p1_right`. Bottom strip and corner can't be + fused cleanly (α/β transpose mismatch) and still round-trip through LDS. + LDS instructions dropped 32 %, LDS wait cycles dropped 79 %. +3. **double4 distribute** — K³ = 8000 doubles = 2000 `double4` loads. 
+ +Remaining headroom (all less impactful than the above): +- **MFMA busy still 18 %** — there's still ~2× to gain before matching K=16's + ~50 % MFMA-busy ceiling; further reductions in VALU overhead (mostly in + the bottom/corner 4×4×4 loops) and in bank conflicts (currently 16-way on + the stride-K=20 accesses for bottom/corner) would help. +- **No B_lds** — single-block-per-tensor has no room for a 3.2 KB B cache + after the K³ = 64 KB `buf`. Attempted 2-blocks-per-tensor with padded + layout + B_lds but atomicAdd + pre-zero overhead (+500 MB HBM) erased the + bank-conflict win — see notes below. +- **Pass 2 → Pass 3 fusion** isn't feasible: pass 3's cross-wave corner-turn + REQUIRES data in LDS. + +### Experiments that didn't pay off + +- **2 blocks per tensor + K+1 padded LDS + atomicAdd** (`transform_kernel_blocked_k20_split` + kept in the tree for reference). Bank conflicts did drop to 0.2-way, but + atomic read-modify-write on C and pre-zero memset inflated HBM traffic + 3× (258 → 752 MB), VALU went up 2.6×, and MFMA busy *fell* to 6 %. Net: + ~1.9× slower. Take-away: bank conflict counters were real but not the + critical path on the hybrid; the extra coordination costs of splitting + the k'-axis across blocks easily dominated the LDS wins. + +### 4×4×4 layout — CONFIRMED (overrides the "broadcast" hypothesis) + +From `test_mfma_4x4x4_layout{,2,3}.hip`: + +- The instruction is **4 INDEPENDENT 4×4×4 GEMMs**, not a 4-way + broadcast. Lane decomposition `S = 16α + 4g + β` with + `β = S%4, g = (S/4)%4, α = S/16`. `g` is the group id. +- A (4×4): lane S holds `A_g[m=β][k=α]` (col-major) +- B (4×4): lane S holds `B_g[k=α][n=β]` (row-major) +- D (4×4): lane S holds `D_g[m=α][n=β]` (row-major) + +Note α and β swap roles between input (A/B) and output (D) fragments — +same oddity as the 16×16×4 layout. + +### Obvious next optimisations (not pursued here) + +- **Fit B in LDS** via multi-block-per-tensor (e.g. 
2 blocks handling + 10 k'-slabs each, atomic-add in pass 3). LDS per block drops to + ~35 KB and every MFMA can use LDS B. +- **Hybrid 16×16×4 main + 4×4×4 strip/corner** tile pattern. Uses the + faster instruction on the bulk of the work. +- **Pass-1/pass-2 register fusion** (as in the K=16 path). + +--- + +## K=32 — design note (not yet implemented) + +At K=32 the K=16 algorithm's working set doesn't fit in LDS, even split +across multiple blocks. This is a real redesign, not a tweak. + +### The budget problem + +| thing | K=16 | K=32 | +|------------------------------------------------|--------------|--------------| +| full K³ tensor | 4 K dbl = 32 KB | 32 K dbl = **256 KB** | +| per-wave K×K slab | 256 dbl = 2 KB | 1024 dbl = 8 KB | +| 16 waves × slab (one block per tensor) | 32 KB | **128 KB** | +| 16 waves × slab with K+1 pad | 34 KB | **132 KB** | +| LDS cap (MI250X) | 64 KB | 64 KB | + +Even splitting into 2 blocks (each handling K/2 of k') gives 128 KB/block — +2× over cap. There's no simple split that makes the K=16 algorithm fit. + +### What must change + +The working set per wave has to shrink from K×K to something smaller. +Three candidate approaches: + +**(A) Each wave holds K × K/2 (half-slab).** + - 2 waves cooperate on a full K×K slab. + - Per-wave LDS: 512 dbl = 4 KB. 32 waves × 4 KB = 128 KB. Still over. + - 16 waves × 4 KB = 64 KB, halved coverage. Would need 2 blocks to cover. + +**(B) Stream the corner turn in k'-chunks.** + - Hold only a chunk of the tensor at a time: K × K × (K/4) = 8 KB · 4 = 32 KB. + - Pass 3 accumulates across 4 chunked corner turns into registers. + - Extra HBM passes for re-reading A, unless we also chunk passes 1/2. + - Complex but fits the budget. + +**(C) Atomic-add reduction across 4 blocks.** + - 4 blocks per tensor; each handles K/4 = 8 values of k'. + - Each block's pass 3 contributes a partial; atomicAdd to HBM C. + - 4 atomics per output element, 4K³ = 128K atomics per tensor. 
+ - Simpler code, slower arithmetic path. + +### Thread-block sizing + +At K=32, blockDim must be ≤ 1024 threads. With 64-thread waves, that's +≤ 16 waves per block. So we can't have 32 waves (one per k') in a single +block — each wave must handle ≥ 2 slabs, or blocks must handle fewer k' +values. + +### Recommended approach for the next iteration + +Start with **(C) — 4 blocks per tensor with atomic add**, because: +- Reuses the existing K=16 algorithm shape almost verbatim (per-block + pipeline is identical). +- Simplest to get correct; good baseline before optimizing. +- Perf will be bad (atomics, 4× HBM writes for reduction) but it + validates the design. + +Then iterate toward **(B)** — streaming — which preserves single-block- +per-tensor semantics and should be faster once correct. + +### Open questions worth probing first + +- What does the actual MADNESS workload look like at K=32? If `nfuncs` + is small, block-count explosion from approach (C) isn't a problem + but atomic contention might be. +- Does hipBLAS's batched DGEMM do better than anything we can write by + hand at K=32? Worth measuring L6 (kron) on real K=32 workloads as + a reference point. +- L4 / L5 (upstream MFMA + rocWMMA paths) likely have K=32 code; are + their designs worth stealing? 
+
+---
+
+## References
+
+- `validate.hip` — CPU reference and GPU L1 correctness check
+- `validate_levels.hip` — multi-level correctness test (select L7 with `-l 7`)
+- `transformbench.hip` — throughput benchmark (`-l 7`)
+- `mra_python/algorithms.py::transform_nd_blocked` — NumPy version
+- `test_mfma_layout.hip` — diagnostic that verifies MFMA fragment layouts
+- `counters.txt`, `counters_deep.txt` — rocprof counter sets
+- `roofline.py`, `roofline.png` — roofline analysis
+- `transform.h`, `transform_level{2,3}.h` — earlier GPU levels for comparison
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1935e9b..dba72f4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -103,6 +103,10 @@ if (CMAKE_HIP_COMPILER)
   add_executable(validate_levels validate_levels.hip)
   target_link_libraries(validate_levels PUBLIC libmra roc::hipblas)
 
+  # Correctness test: GPU L1 vs CPU mTxm reference (mirrors madness/transform3d.cc)
+  add_executable(validate validate.hip)
+  target_link_libraries(validate PUBLIC libmra)
+
   if (USE_SUGGEST_LAYOUT)
     # Enable using suggested layout instead of get_layout
     target_compile_definitions(transformbench_hip PUBLIC USE_SUGGEST_LAYOUT)
diff --git a/transform_blocked.h b/transform_blocked.h
new file mode 100644
index 0000000..5d73530
--- /dev/null
+++ b/transform_blocked.h
@@ -0,0 +1,260 @@
+#pragma once
+#include <type_traits>
+#include "util.h"
+
+// Entire blocked-transform implementation is HIP/AMD-only (uses MFMA
+// intrinsics). Skip under nvcc so transformbench.cu still compiles.
+#if defined(__HIP__)
+
+typedef double v4f64 __attribute__((ext_vector_type(4)));
+
+// One v_mfma_f64_16x16x4f64 step: returns the accumulator c updated with the
+// wave-wide a·b product. The builtin's three trailing operands are left at 0.
+__device__ inline v4f64 mfma_16x16x4_f64_shared(double a, double b, v4f64 c) {
+    return __builtin_amdgcn_mfma_f64_16x16x4f64(a, b, c, 0, 0, 0);
+}
+
+#include "transform_blocked_k20.h"
+
+// transform_blocked.h — block-distributed 3D transform with AMD MFMA.
+//
+// One wavefront owns one K×K block of the tensor throughout the computation:
+//   wave s initially holds A[:, :, k'=s];
+//   after the corner turn, wave s holds result[a=s, :, :].
+//
+// Flow:
+//   distribute      all threads copy A into LDS with coalesced double4
+//                   reads; wave s then works on A[:, :, k'=s]
+//   Pass 1 (local)  blk_s <- B^T · blk_s          (a, j')
+//   Pass 2 (local)  blk_s <- blk_s · B            (a, b)
+//   corner turn     LDS all-to-all: wave t ends up holding
+//                   temp2[a=t, b, k'] stored as (b, k')
+//   Pass 3 (local)  blk_t <- blk_t · B            (b, c) — canonical
+//   store           wave s writes result[s, :, :] to HBM
+//
+// Local GEMMs use v_mfma_f64_16x16x4f64 (GFX90A). One MFMA does a 16×16
+// output with K=4 contraction; for K=16 we chain 4 MFMAs per pass.
+//
+// LDS layout (36 KB at K=16; matches blocked_shmem_size):
+//   buf[K·K·(K+1)]  K³ scratch with rows padded to K+1 (34 KB), reused
+//                   in place across passes via register stash
+//   B_lds[K²]       cached B matrix (2 KB), shared across all waves
+//
+// Thread-block size: 64 × K = 1024 threads at K=16.
+
+// Same MFMA step under the name the K=16 kernel uses. This file is already
+// inside the __HIP__ guard opened above, so no nested guard is needed.
+__device__ inline v4f64 mfma_16x16x4_f64(double a, double b, v4f64 c) {
+    return mfma_16x16x4_f64_shared(a, b, c);
+}
+
+// Kernel: one thread block transforms one K×K×K tensor (blockIdx.x = tensor).
+//   nfuncs  number of tensors; blocks with blockIdx.x >= nfuncs exit early
+//   A, C    input / output; tensor n starts at element offset n·K³
+//   B       K×K matrix, indexed as B[row*K + col]
+//   (last)  workspace pointer, unused by this kernel
+// Launch with blockDim = (64, K, 1) and blocked_shmem_size<T>(K) bytes of
+// dynamic shared memory, as submit_transform_bench_blocked does.
+template <typename T, int K>
+__global__
+__launch_bounds__(K * 64, 1)
+void transform_kernel_blocked(int nfuncs,
+                              const T* __restrict__ A,
+                              const T* __restrict__ B,
+                              T* __restrict__ C,
+                              T* __restrict__ /*workspace unused*/)
+{
+    static_assert(std::is_same<T, double>::value, "MFMA path is FP64 only");
+    static_assert(K == 16, "MFMA path: K=16 only for now");
+    static_assert(K * K % 64 == 0, "K^2 must be a multiple of the wavefront size (64)");
+
+    constexpr int K2 = K * K;
+    constexpr int K3 = K * K * K;
+    constexpr int ELEMS_PER_LANE = K2 / 64;   // 4 at K=16
+    constexpr int NMFMA = K / 4;              // 4 MFMAs per K=16 pass
+
+    // LDS bank-conflict pad: give each "row" inside a wave's K×K region an
+    // extra column of dead storage so within-wave stride-K accesses (pass 2/3
+    // A_frag loads, corner-turn cross-wave writes) don't collide on the same
+    // 128-byte bank row. Per-wave region becomes K × (K+1) instead of K × K.
+    constexpr int BK = K + 1;     // padded inner (row) stride
+    constexpr int WB = K * BK;    // per-wave region size (padded)
+
+    extern __shared__ unsigned char _smem_blk[];
+    T* buf   = reinterpret_cast<T*>(_smem_blk);
+    T* B_lds = buf + K * WB;
+
+    const int s   = threadIdx.y;      // wave index (0..K-1)
+    const int t   = threadIdx.x;      // lane within wave (0..63)
+    const int tid = s * 64 + t;
+
+    // Cache B into LDS once, using wide double4 loads.
+    // K²=256 doubles; with double4 (4-wide) we need K²/4 = 64 threads -- the
+    // first wave does it, the rest wait at the barrier. A double4 is 32 bytes
+    // per lane (BLOCKED_TRANSFORM.md notes gfx90a lowers it to a pair of
+    // global_load_dwordx4). Alignment: B_lds sits K*WB*8 = 34816 bytes into
+    // the shared buffer (a multiple of 32); B is assumed to be an aligned
+    // device allocation -- TODO confirm at the call sites.
+    static_assert(K2 % 4 == 0, "B cache double4 path needs K² divisible by 4");
+    if (tid < K2 / 4) {
+        const double4* B_vec = reinterpret_cast<const double4*>(B);
+        double4* B_lds_vec   = reinterpret_cast<double4*>(B_lds);
+        B_lds_vec[tid] = B_vec[tid];
+    }
+    __syncthreads();
+
+    // One block = one tensor. Grid is sized to nfuncs at launch.
+    {
+        const int cube = blockIdx.x;
+        if (cube >= nfuncs) return;   // whole block exits together (depends only on blockIdx.x)
+        const T* a_ptr = A + static_cast<size_t>(cube) * K3;
+        T*       c_ptr = C + static_cast<size_t>(cube) * K3;
+
+        // --- Distribute: cooperative coalesced load of the K^3 tensor ---
+        // Each thread loads 4 CONTIGUOUS doubles from HBM as a single double4
+        // (32 bytes per lane; per BLOCKED_TRANSFORM.md gfx90a issues it as a
+        // pair of global_load_dwordx4). Per-thread contiguity is what lets the
+        // compiler emit a wide load; across 16 consecutive lanes this still
+        // touches just 4 cache lines (same HBM traffic as before, 4× fewer
+        // load instructions).
+        //
+        // Per-thread HBM indices: a_ptr[4*tid .. 4*tid+3].
+        // For K=16, all 4 values share the same (i, j) with consecutive k:
+        //   k_start = (4*tid) % 16  ∈ {0, 4, 8, 12}
+        //
+        // buf layout is canonical (i, j, k), k fastest:
+        //   buf[i*WB + j*BK + k] = A[i, j, k]
+        {
+            static_assert(K == 16, "wide-load distribute tuned for K=16");
+            const double4* a_ptr_vec = reinterpret_cast<const double4*>(a_ptr);
+            const int base_idx = 4 * tid;               // 0..K^3-4
+            const int i        = base_idx >> 8;         // / K^2
+            const int j        = (base_idx >> 4) & 15;
+            const int k_start  = base_idx & 15;         // 0, 4, 8, or 12
+
+            double4 v = a_ptr_vec[tid];                 // one 32-byte vector load
+            T* row = &buf[i*WB + j*BK + k_start];
+            row[0] = v.x;
+            row[1] = v.y;
+            row[2] = v.z;
+            row[3] = v.w;
+        }
+        __syncthreads();  // cross-wave writes visible before pass-1 reads
+
+        // --- Pass 1 MFMA (SWAPPED OPERANDS): compute blk^T · B instead of B^T · blk ---
+        // GFX90A v_mfma_f64_16x16x4f64 confirmed layouts:
+        //   A_frag (M=16, K=4):  lane t -> A[t%16][t/16]    (col-major)
+        //   B_frag (K=4, N=16):  lane t -> B[t/16][t%16]    (row-major)
+        //   D      (M=16, N=16): lane t acc[e] -> D[(t/16)+4*e][t%16]
+        //
+        // By feeding A_frag = blk^T (from buf, treating (i,j) as if transposed)
+        // and B_frag = B (from B_lds), MFMA produces D where
+        //   D[j][a] = sum_i blk[i][j] * B[i][a] = temp1[a][j]
+        // i.e. temp1 stored with its axes swapped in the register file:
+        //   lane t, acc[e] == temp1[t%16][(t/16) + 4*e]
+        //
+        // This is exactly the layout pass 2's A_frag wants at iter p = e,
+        // so pass 2 can consume pass 1's acc directly -- no LDS round trip.
+        // With canonical buf layout, thread t's A_frag contribution at iter p
+        // lives at buf[i*WB + j*BK + k] with i=p*4+t/16, j=t%16, k=s.
+        v4f64 acc1 = v4f64{0.0, 0.0, 0.0, 0.0};
+        #pragma unroll
+        for (int p = 0; p < NMFMA; ++p) {
+            double a_val = buf[(p*4 + (t >> 4)) * WB + (t & 15) * BK + s];
+            double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)];
+            acc1 = mfma_16x16x4_f64(a_val, b_val, acc1);
+        }
+        // No LDS writeback: acc1[p] is pass 2's a_val at iter p directly.
+
+        // --- Pass 2 MFMA: out = temp1 · B  (cols: j' -> b) ---
+        // A_frag = temp1 (from acc1), B_frag = B.
+        //   thread t at iter p uses acc1[p] = temp1[t%16][p*4 + t/16]
+        v4f64 acc = v4f64{0.0, 0.0, 0.0, 0.0};
+        #pragma unroll
+        for (int p = 0; p < NMFMA; ++p) {
+            double a_val = acc1[p];
+            double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)];
+            acc = mfma_16x16x4_f64(a_val, b_val, acc);
+        }
+        // Pass 2 output is in acc. The corner-turn stash step used to read
+        // pass 2's buf at positions (a = t/16 + 4e, b = t%16) -- exactly the
+        // positions MFMA already deposited in acc. So acc[e] IS the stash;
+        // we can skip both the pass-2 store and the stash read, and write
+        // acc directly into the cross-wave destination below.
+
+        __syncthreads();  // pass 1 (the last reader of buf) has finished in every wave before buf is overwritten
+
+        // --- Corner turn (cross-wave write, fused with pass-2 store) ---
+        //   buf[a*WB + b*BK + s] <- acc[e]   where (a, b) = (t/16 + 4e, t%16)
+        #pragma unroll
+        for (int e = 0; e < ELEMS_PER_LANE; ++e) {
+            int idx = t + e * 64;
+            int a   = idx / K;
+            int b   = idx % K;
+            buf[a*WB + b*BK + s] = acc[e];
+        }
+        __syncthreads();
+
+        // --- Pass 3 MFMA + store: out = blk · B, write directly to HBM ---
+        // Same orientation as pass 2 (blk has (b, k') layout).
+        acc = v4f64{0.0, 0.0, 0.0, 0.0};
+        #pragma unroll
+        for (int p = 0; p < NMFMA; ++p) {
+            double a_val = buf[s*WB + (t & 15) * BK + (p*4 + (t >> 4))];
+            double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)];
+            acc = mfma_16x16x4_f64(a_val, b_val, acc);
+        }
+        // No sync: acc is in registers, about to write to HBM.
+
+        #pragma unroll
+        for (int e = 0; e < ELEMS_PER_LANE; ++e) {
+            int row = (t >> 4) + 4 * e;   // b
+            int col = t & 15;             // c
+            c_ptr[s*K2 + row*K + col] = acc[e];
+        }
+    }
+}
+
+// Bytes of dynamic shared memory transform_kernel_blocked needs for a given K:
+// the padded K × K × (K+1) scratch plus the K × K copy of B, in units of T.
+template <typename T>
+inline size_type blocked_shmem_size(int K) {
+    // Padded buf (K × K × (K+1)) + B cache (K × K)
+    return static_cast<size_type>((K * K * (K + 1) + K * K) * sizeof(T));
+}
+
+// Thread-block shape for the blocked kernel: 64 lanes (x) by K waves (y).
+template <typename T>
+inline Dim3 blocked_blockdim(int K) {
+    return Dim3(64, K, 1);
+}
+
+// Host-side launcher for the blocked transform. Dispatches on the runtime K:
+// K=16 launches transform_kernel_blocked with one block per tensor, K=20
+// forwards to the K=20 implementation, any other K reports an error.
+// nblocks is accepted for signature compatibility and ignored.
+template <typename T>
+inline void submit_transform_bench_blocked(int nfuncs, int /*nblocks*/, int K,
+                                           const T* A, const T* B, T* C, T* workspace,
+                                           Stream stream)
+{
+    if (K == 16) {
+        constexpr int Kv = 16;
+        Dim3 td = blocked_blockdim<T>(Kv);
+        size_type smem = blocked_shmem_size<T>(Kv);
+        CONFIGURE_KERNEL((transform_kernel_blocked<T, Kv>), smem);
+        // Grid = nfuncs: one block per tensor. nblocks arg is ignored here
+        // (matches upstream convention for single-function-per-kernel kernels).
+        CALL_KERNEL((transform_kernel_blocked<T, Kv>), nfuncs, td, smem, stream,
+                    (nfuncs, A, B, C, workspace));
+    } else if (K == 20) {
+        submit_transform_bench_blocked_k20(nfuncs, A, B, C, workspace, stream);
+    } else {
+        // NOTE(review): assert() compiles away under NDEBUG, leaving only the
+        // message and no kernel launch -- consider a hard failure here.
+        fprintf(stderr, "blocked transform: K=%d not supported\n", K);
+        assert(false);
+    }
+}
+
+#endif // __HIP__
diff --git a/transform_blocked_k20.h b/transform_blocked_k20.h
new file mode 100644
index 0000000..f4f8462
--- /dev/null
+++ b/transform_blocked_k20.h
@@ -0,0 +1,1122 @@
+#pragma once
+#include <type_traits>
+#include "util.h"
+
+// HIP/AMD-only — uses v_mfma_f64_4x4x4f64 and v_mfma_f64_16x16x4f64 intrinsics.
+#if defined(__HIP__)
+
+// transform_blocked_k20.h — corner-turn 3D transform at K=20.
+//
+// Two variants live here:
+//   * scalar — `transform_kernel_blocked_k20_scalar` (correctness baseline)
+//   * MFMA   — `transform_kernel_blocked_k20_mfma`   (v_mfma_f64_4x4x4f64)
+//
+// Structural notes (shared):
+//
+//   * K=20 needs 20 slabs but one wave (64 lanes) per slab would require
+//     20 × 64 = 1280 threads, past the 1024/block limit. 
We use 10 waves +// and give each wave two k' slabs (w, w+10). +// +// * LDS budget is tight. K³ = 8000 doubles = 64 KB. No room for a +// B_lds, and bank-conflict padding (K → K+1) would push us to 8400 +// doubles. B lives in HBM and goes through the L1 cache; at +// K²=400 it's 3.2 KB and fits L1 comfortably. +// +// Logical flow (matches CPU reference cpu_transform3d_blocked): +// +// distribute : buf[i, j, k] ← A[i, j, k] canonical layout +// Pass 1 : buf[a, j, k] ← Σ_i B[i, a] · buf[i, j, k] +// Pass 2 : buf[a, b, k] ← Σ_j buf[a, j, k] · B[j, b] +// "corner turn": implicit -- buf is already (a, b, k); wave w +// simply reads rows a ∈ {w, w+10} in pass 3 instead of +// the k' it wrote. No LDS shuffle needed. +// Pass 3 : C[a, b, c] ← Σ_k buf[a, b, k] · B[k, c] +// (write directly to HBM) + +// ============================================================================ +// Scalar variant (correctness baseline) +// ============================================================================ + +// Scalar (non-MFMA) K=20 kernel: one thread block per tensor, selected by +// blockIdx.x; blocks with blockIdx.x >= nfuncs return before touching memory. +// A : nfuncs input tensors, K³ elements each, indexed [i*K² + j*K + k] +// B : K×K matrix indexed [row*K + col], read straight from global memory +// C : nfuncs output tensors, same layout as A +// The last (unnamed) pointer parameter is unused. +// The dynamic shared buffer must hold at least K³ elements of T. The slab +// index kp = threadIdx.y + ss*NWAVES stays below K only when blockDim.y == 10 +// -- NOTE(review): confirm the launch site uses blockDim = (64, 10, 1). +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_scalar(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; // 400 + constexpr int K3 = K * K * K; // 8000 + constexpr int NWAVES = 10; + constexpr int SPW = 2; // slabs per wave + constexpr int LANES = 64; + constexpr int MAX_SLOTS = 7; // ceil(K²/64) + + extern __shared__ unsigned char _smem_k20[]; + T* buf = reinterpret_cast(_smem_k20); + + const int w = threadIdx.y; + const int t = threadIdx.x; + const int tid = w * LANES + t; + const int nthr = blockDim.x * blockDim.y; + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + for (int idx = tid; idx < K3; idx += nthr) buf[idx] = a_ptr[idx]; + __syncthreads(); + +
// Pass 1: buf[a, j, kp] ← Σ_i B[i, a] · buf[i, j, kp]. Each wave reads and + // rewrites only its own two kp slabs; sums stay in acc[] until the barrier. + { + double acc[SPW][MAX_SLOTS]; + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) { acc[ss][slot] = 0.0; continue; } + const int a = cell_idx / K, j = cell_idx % K; + double s = 0.0; + #pragma unroll + for (int i = 0; i < K; ++i) s += B[i * K + a] * buf[i * K2 + j * K + kp]; + acc[ss][slot] = s; + } + } + __syncthreads(); + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int a = cell_idx / K, j = cell_idx % K; + buf[a * K2 + j * K + kp] = acc[ss][slot]; + } + } + } + __syncthreads(); + + // Pass 2: buf[a, b, kp] ← Σ_j buf[a, j, kp] · B[j, b] (same per-wave kp slabs as pass 1) + { + double acc[SPW][MAX_SLOTS]; + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) { acc[ss][slot] = 0.0; continue; } + const int a = cell_idx / K, b = cell_idx % K; + double s = 0.0; + #pragma unroll + for (int j = 0; j < K; ++j) s += buf[a * K2 + j * K + kp] * B[j * K + b]; + acc[ss][slot] = s; + } + } + __syncthreads(); + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int a = cell_idx / K, b = cell_idx % K; + buf[a * K2 + b * K + kp] = acc[ss][slot]; + } + } + } + __syncthreads(); + + // Pass 3 (direct to HBM): wave w now takes rows a ∈ {w, w+NWAVES} and + // contracts over every kp, i.e. it reads slabs written by all waves. + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int b = cell_idx / K, c =
cell_idx % K; + double s = 0.0; + #pragma unroll + for (int kp = 0; kp < K; ++kp) s += buf[a * K2 + b * K + kp] * B[kp * K + c]; + c_ptr[a * K2 + b * K + c] = s; + } + } +} + +// ============================================================================ +// MFMA variant — v_mfma_f64_4x4x4f64 +// +// Layout (empirically confirmed in test_mfma_4x4x4_layout*.hip): +// The instruction computes FOUR INDEPENDENT 4×4×4 GEMMs per call. +// With lane S = 16·α + 4·g + β (α = S/16, g = (S/4)%4 = group, β = S%4), +// each of the four groups g ∈ {0,1,2,3} operates on its own 4×4 A, +// 4×4 B and 4×4 D: +// A_g[m = β ][k = α] at lane S +// B_g[k = α ][n = β] at lane S +// D_g[m = α ][n = β] at lane S +// +// GEMM plan per pass, per slab: +// Output is 20×20 = 5×5 tiles of 4×4. Inner contraction K_inner=20=5×4. +// Per output tile: 5 MFMA calls to accumulate the K-slices. +// Per MFMA call: 4 tiles computed simultaneously (one per g). +// 25 tiles ⇒ 7 MFMA rounds (25/4 rounds = 6 full + 1 partial, last +// round leaves groups 1,2,3 idle -- fed zeros so their output is +// discarded cleanly).
+// ============================================================================ + +// Per-lane wrapper around __builtin_amdgcn_mfma_f64_4x4x4f64: issues one MFMA +// with operands (a, b) and accumulator c, passing 0 for the three trailing +// modifier arguments, and returns this lane's accumulated value. +__device__ inline double mfma_4x4x4_f64(double a, double b, double c) { + return __builtin_amdgcn_mfma_f64_4x4x4f64(a, b, c, 0, 0, 0); +} + +// K=20 kernel built on mfma_4x4x4_f64. One thread block per tensor (blockIdx.x); +// blocks with blockIdx.x >= nfuncs return immediately. A and C hold nfuncs +// tensors of K³ elements indexed [i*K² + j*K + k]; B is K×K indexed +// [row*K + col]; the last pointer parameter is unused. The dynamic shared buffer +// must hold at least K³ elements of T. Each pass walks the 25 output tiles in +// MFMA_ROUNDS rounds of four tiles (one per group g). +// NOTE(review): kp = threadIdx.y + ss*NWAVES assumes blockDim = (64, 10, 1) -- +// confirm at the launch site. +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_mfma(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int NWAVES = 10; + constexpr int SPW = 2; + constexpr int LANES = 64; + constexpr int TILES_DIM = 5; // K / 4 + constexpr int K_SLICES = 5; // K / 4 + constexpr int TOTAL_TILES = TILES_DIM * TILES_DIM; // 25 + constexpr int MFMA_ROUNDS = 7; // ceil(25/4) + + extern __shared__ unsigned char _smem_k20m[]; + T* buf = reinterpret_cast(_smem_k20m); + + const int w = threadIdx.y; + const int t = threadIdx.x; + const int tid = w * LANES + t; + const int nthr = blockDim.x * blockDim.y; + + // Lane decomposition: S = 16α + 4g + β with β = S%4, g = (S/4)%4, α = S/16.
+ const int beta = t & 3; // β: low 2 bits + const int g = (t >> 2) & 3; // g: middle 2 bits -- MFMA group id + const int alpha = (t >> 4) & 3; // α: high 2 bits + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // ----- Distribute ----- + for (int idx = tid; idx < K3; idx += nthr) buf[idx] = a_ptr[idx]; + __syncthreads(); + + // ----- Pass 1: buf[a, j, kp] = Σ_i B[i, a] · buf[i, j, kp] ----- + // MFMA A_g[m=β][k=α] ← B[i=4·k_tile+α, a=4·a_tile_g+β] + // MFMA B_g[k=α][n=β] ← buf[i=4·k_tile+α, j=4·j_tile_g+β, kp] + // D_g[m=α][n=β] ↦ buf[a=4·a_tile_g+α, j=4·j_tile_g+β, kp] + { + double tile_acc[SPW][MFMA_ROUNDS]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + // Groups whose tile_idx runs past the 25 tiles feed zero operands + // and are skipped in the writeback loop below. + const bool valid = tile_idx < TOTAL_TILES; + const int a_tile = valid ? tile_idx / TILES_DIM : 0; + const int j_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int k_tile = 0; k_tile < K_SLICES; ++k_tile) { + const int gi = 4 * k_tile + alpha; + const int ga = 4 * a_tile + beta; + const int gj = 4 * j_tile + beta; + const double aval = valid ? B[gi * K + ga] : 0.0; + const double bval = valid ?
buf[gi * K2 + gj * K + kp] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[ss][round] = acc; + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int a_tile = tile_idx / TILES_DIM; + const int j_tile = tile_idx % TILES_DIM; + const int ga = 4 * a_tile + alpha; + const int gj = 4 * j_tile + beta; + buf[ga * K2 + gj * K + kp] = tile_acc[ss][round]; + } + } + } + __syncthreads(); + + // ----- Pass 2: buf[a, b, kp] = Σ_j buf[a, j, kp] · B[j, b] ----- + // MFMA A_g[m=β][k=α] ← buf[a=4·a_tile_g+β, j=4·j_tile+α, kp] + // (here "m" of the mfma = "a" of our GEMM; "k" of the mfma = "j") + // MFMA B_g[k=α][n=β] ← B[j=4·j_tile+α, b=4·b_tile_g+β] + // D_g[m=α][n=β] ↦ buf[a=4·a_tile_g+α, b=4·b_tile_g+β, kp] + { + double tile_acc[SPW][MFMA_ROUNDS]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + const bool valid = tile_idx < TOTAL_TILES; + const int a_tile = valid ? tile_idx / TILES_DIM : 0; + const int b_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int j_tile = 0; j_tile < K_SLICES; ++j_tile) { + const int ga = 4 * a_tile + beta; + const int gj = 4 * j_tile + alpha; + const int gb = 4 * b_tile + beta; + const double aval = valid ? buf[ga * K2 + gj * K + kp] : 0.0; + const double bval = valid ?
B[gj * K + gb] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[ss][round] = acc; + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int a_tile = tile_idx / TILES_DIM; + const int b_tile = tile_idx % TILES_DIM; + const int ga = 4 * a_tile + alpha; + const int gb = 4 * b_tile + beta; + buf[ga * K2 + gb * K + kp] = tile_acc[ss][round]; + } + } + } + __syncthreads(); + + // ----- Pass 3: C[a, b, c] = Σ_k buf[a, b, k] · B[k, c] ----- + // After pass 2 the buf layout is (a, b, k'). Wave w now "owns" + // rows a ∈ {w, w+10} -- it reads k' slabs written by every wave in + // pass 2 (the barrier above orders those writes before these reads). + // + // For pass 3, fix a = the wave's owned row. The per-slab GEMM is + // over (b, c) with inner k'. + // + // MFMA A_g[m=β][k=α] ← buf[a=fixed, b=4·b_tile_g+β, k'=4·k_tile+α] + // ("m" = "b" of the GEMM, "k" = "k'") + // MFMA B_g[k=α][n=β] ← B[k'=4·k_tile+α, c=4·c_tile_g+β] + // D_g[m=α][n=β] ↦ C[a=fixed, b=4·b_tile_g+α, c=4·c_tile_g+β] + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; + + double tile_acc[MFMA_ROUNDS]; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + const bool valid = tile_idx < TOTAL_TILES; + const int b_tile = valid ? tile_idx / TILES_DIM : 0; + const int c_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int k_tile = 0; k_tile < K_SLICES; ++k_tile) { + const int gb = 4 * b_tile + beta; + const int gk = 4 * k_tile + alpha; + const int gc = 4 * c_tile + beta; + const double aval = valid ? buf[a * K2 + gb * K + gk] : 0.0; + const double bval = valid ? B[gk * K + gc] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[round] = acc; + } + + // No __syncthreads needed: writes go to HBM, not LDS.
+ #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int b_tile = tile_idx / TILES_DIM; + const int c_tile = tile_idx % TILES_DIM; + const int gb = 4 * b_tile + alpha; + const int gc = 4 * c_tile + beta; + c_ptr[a * K2 + gb * K + gc] = tile_acc[round]; + } + } +} + +// ============================================================================ +// Hybrid MFMA variant — 16×16×4 main + 4×4×4 strips + 4×4 corner +// +// Profile of the pure-4×4×4 path showed MFMA units at 12 % busy (starved by +// per-instruction VALU/LDS overhead), not HBM-bound. The fix is to push +// more FMAs per instruction by using v_mfma_f64_16x16x4f64 for the bulk +// 16×16 sub-tile; the remaining 16×4, 4×16 and 4×4 pieces are handled by +// v_mfma_f64_4x4x4f64 with 4, 4 and 1 groups respectively. +// +// Per slab per pass: 5 k-slices × 4 MFMA kinds = 20 MFMAs (vs. 35 for the +// pure-4×4×4 path -- a 43 % reduction in instruction issues). +// +// Tile partition of a 20×20 output slab: +// ┌────────────────┬──────┐ +// │ 16×16 main │16×4 │ right strip (4 tiles, 4 MFMA groups) +// │ │right │ +// ├────────────────┼──────┤ +// │ 4×16 bottom │ 4×4 │ corner (1 tile, 1 MFMA group, 3 wasted) +// │ │corner│ +// └────────────────┴──────┘ +// +// Lane roles (inherit both layouts): +// 16×16×4 A (16×4): lane t → A[t%16][t/16] col-major +// B ( 4×16): lane t → B[t/16][t%16] row-major +// D (16×16): lane t acc[e] → D[(t/16)+4e][t%16] +// 4×4×4 (per group g): +// A_g[m=β][k=α] at S = 16α + 4g + β +// B_g[k=α][n=β] +// D_g[m=α][n=β] +// +// Optimisations on top of the vanilla hybrid: +// (a) Wide double4 distribute from HBM. +// (b) Pass 1 → Pass 2 register fusion for the main + right strip tiles.
+// The K=16 kernel already uses this trick; we swap pass 1's main/right +// operands so the accumulator layout matches pass 2's A-frag directly: +// feed A_frag = buf, B_frag = B ⇒ D[m][n] = temp1[n][m] +// at lane t acc_main[e] = temp1[t%16][(t/16) + 4e] (main) +// at lane S acc_right = temp1[4g+β][16+α] (right strip) +// Pass 2 main slices 0..3 (j∈0..15) read acc_main[0..3] directly; slice 4 +// (j∈16..19) reads acc_right. Pass 2 right strip does the same — its +// A-frag layout at lane S is temp1[4g+β][4p+α] which matches acc_main +// for p=0..3 and acc_right for p=4. Bottom strip and corner can't be +// fused cleanly (α/β transpose mismatch) and still use LDS. +// ============================================================================ + +// Hybrid K=20 kernel: one thread block per tensor (blockIdx.x); blocks with +// blockIdx.x >= nfuncs return immediately. A and C hold nfuncs tensors of K³ +// elements indexed [i*K² + j*K + k]; B is K×K indexed [row*K + col]; the last +// pointer parameter is unused. The dynamic shared buffer must hold at least K³ +// elements of T. mfma_16x16x4_f64_shared and v4f64 are not defined in this +// header above this point -- NOTE(review): presumably they come from util.h; +// confirm. kp = threadIdx.y + ss*NWAVES assumes blockDim = (64, 10, 1). +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_hybrid(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int NWAVES = 10; + constexpr int SPW = 2; + constexpr int K_SLICES = K / 4; // 5 + + extern __shared__ unsigned char _smem_hyb[]; + T* buf = reinterpret_cast(_smem_hyb); + + const int w = threadIdx.y; + const int t = threadIdx.x; + const int tid = w * 64 + t; + const int nthr = blockDim.x * blockDim.y; + + // 4×4×4 lane decomposition: S = 16α + 4g + β + const int beta = t & 3; + const int g = (t >> 2) & 3; + const int alpha = (t >> 4) & 3; + // 16×16×4 lane decomposition + const int tmod16 = t & 15; + const int tdiv16 = t >> 4; + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Wide (double4) distribute. One load per 4 doubles = 2000 loads total. + // a_ptr[] is hipMalloc'd (16-byte aligned) and K³ = 8000 is divisible by 4.
+ // Each thread issues at most ceil(2000/640) = 4 double4 loads. + // NOTE(review): the cast to double4* below needs a_ptr aligned to + // alignof(double4); confirm the 16-byte figure above is sufficient on the + // target toolchain. + { + const double4* a_ptr_vec = reinterpret_cast(a_ptr); + constexpr int NVEC = K3 / 4; // 2000 + for (int idx = tid; idx < NVEC; idx += nthr) { + double4 v = a_ptr_vec[idx]; + const int base = idx * 4; + buf[base + 0] = v.x; + buf[base + 1] = v.y; + buf[base + 2] = v.z; + buf[base + 3] = v.w; + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 1 (swapped operands) : keep acc_main and acc_right in registers + // across the sync so pass 2 can read them directly instead of going + // through LDS. See header comment for the layout derivation. + // + // With A_frag = buf and B_frag = B, the MFMA produces D = temp1^T. + // Concretely: at lane t acc_main[e] = temp1[t%16][(t/16) + 4e]. + // + // Bottom strip and corner are NOT swapped (their α/β transpose doesn't + // align with pass 2's bottom/corner A-frag layout) and still go to LDS. + // -------------------------------------------------------------------- + v4f64 p1_main[SPW]; // pass-1 main accumulator, survives to pass 2 + double p1_right[SPW]; // pass-1 right accumulator, survives to pass 2 + + { + double acc_bottom[SPW], acc_corner[SPW]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + p1_main[ss] = v4f64{0.0, 0.0, 0.0, 0.0}; + p1_right[ss] = 0.0; + acc_bottom[ss] = 0.0; + acc_corner[ss] = 0.0; + } + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + // main 16×16 (SWAPPED: A_frag=buf, B_frag=B) + // feeds D[m=(t/16)+4e][n=t%16] = temp1[t%16][(t/16)+4e] + { + const int gi = 4*ks + tdiv16; // global i (0..19) + const int gj = tmod16; // pass-1 "m" from buf = j + const double av = buf[gi * K2 + gj * K + kp]; + const double bv = B[gi * K + gj]; // pass-1 "n" from B = a + p1_main[ss] = mfma_16x16x4_f64_shared(av, bv, p1_main[ss]); + } + // right strip 16×4 (SWAPPED) + // at lane S
acc_right = temp1[4g+β][16+α] + { + const int gi = 4*ks + alpha; + const int gj = 16 + beta; // j from buf = 16..19 + const int ga = 4*g + beta; // a from B = 0..15 + const double av = buf[gi * K2 + gj * K + kp]; + const double bv = B[gi * K + ga]; + p1_right[ss] = mfma_4x4x4_f64(av, bv, p1_right[ss]); + } + // bottom strip 4×16 (NOT swapped — goes to LDS) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 4*g + beta; + const double av = B[gi * K + ga]; + const double bv = buf[gi * K2 + gj * K + kp]; + acc_bottom[ss] = mfma_4x4x4_f64(av, bv, acc_bottom[ss]); + } + // corner 4×4 (NOT swapped — goes to LDS, group 0 only) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? B[gi * K + ga] : 0.0; + const double bv = vg ? buf[gi * K2 + gj * K + kp] : 0.0; + acc_corner[ss] = mfma_4x4x4_f64(av, bv, acc_corner[ss]); + } + } + } + + __syncthreads(); // all pass-1 LDS reads done; about to overwrite buf + + // Write only bottom + corner to LDS. Main and right stay in + // (p1_main, p1_right) registers for the pass-2 fused reads. + // buf rows a ∈ 0..15 therefore still hold pre-pass-1 data at this + // point; pass 2 only reads rows a ∈ 16..19 from buf. + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + { + const int ga = 16 + alpha; + const int gj = 4*g + beta; + buf[ga * K2 + gj * K + kp] = acc_bottom[ss]; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gj = 16 + beta; + buf[ga * K2 + gj * K + kp] = acc_corner[ss]; + } + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 2 : D[a, b] = Σ_j temp1[a, j, kp] · B_filter[j, b] + // Main & right strip read A-frag from (p1_main, p1_right) registers: + // ks=0..3 → A-frag = p1_main[ss][ks] (j ∈ 0..15) + // ks=4 → A-frag = p1_right[ss] (j ∈ 16..19) + // Bottom strip and corner still read from LDS.
+ // -------------------------------------------------------------------- + { + v4f64 acc_main[SPW]; + double acc_right[SPW], acc_bottom[SPW], acc_corner[SPW]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + acc_main[ss] = v4f64{0.0, 0.0, 0.0, 0.0}; + acc_right[ss] = 0.0; + acc_bottom[ss] = 0.0; + acc_corner[ss] = 0.0; + } + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + // A-frag for main & right: from registers (fused). + // ks=0..3 → p1_main[ss][ks] at this lane (4g+β, α encoding). + // ks=4 → p1_right[ss] at this lane. + const double a_main_av = (ks < 4) ? p1_main[ss][ks] : p1_right[ss]; + // For the 4x4x4 right strip, the A-frag value at lane S is the + // same register value (the encoding matches — see header). + const double a_right_av = (ks < 4) ? p1_main[ss][ks] : p1_right[ss]; + + // main 16×16 (a∈0..15, b∈0..15) + { + const int gj = 4*ks + tdiv16; + const int gb = tmod16; + const double bv = B[gj * K + gb]; + acc_main[ss] = mfma_16x16x4_f64_shared(a_main_av, bv, acc_main[ss]); + } + // right strip 16×4 (a∈0..15, b∈16..19) + { + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const double bv = B[gj * K + gb]; + acc_right[ss] = mfma_4x4x4_f64(a_right_av, bv, acc_right[ss]); + } + // bottom strip 4×16 (a∈16..19, b∈0..15) — LDS read + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 4*g + beta; + const double av = buf[ga * K2 + gj * K + kp]; + const double bv = B[gj * K + gb]; + acc_bottom[ss] = mfma_4x4x4_f64(av, bv, acc_bottom[ss]); + } + // corner 4×4 (a∈16..19, b∈16..19) — LDS read, group 0 only + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[ga * K2 + gj * K + kp] : 0.0; + const double bv = vg ?
B[gj * K + gb] : 0.0; + acc_corner[ss] = mfma_4x4x4_f64(av, bv, acc_corner[ss]); + } + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gb = tmod16; + buf[ga * K2 + gb * K + kp] = acc_main[ss][e]; + } + { + const int ga = 4*g + alpha; + const int gb = 16 + beta; + buf[ga * K2 + gb * K + kp] = acc_right[ss]; + } + { + const int ga = 16 + alpha; + const int gb = 4*g + beta; + buf[ga * K2 + gb * K + kp] = acc_bottom[ss]; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gb = 16 + beta; + buf[ga * K2 + gb * K + kp] = acc_corner[ss]; + } + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 3 : C[a, b, c] = Σ_k temp2[a, b, k] · B_filter[k, c] + // For each wave-owned a, do the hybrid GEMM over (b, c); inner k=k'. + // A_mfma[b, k'] = buf[a, b, k']; B_mfma[k', c] = B_filter[k', c] + // Output written directly to HBM.
+ // -------------------------------------------------------------------- + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; // wave's owned row + + v4f64 acc_main = v4f64{0.0, 0.0, 0.0, 0.0}; + double acc_right = 0.0; + double acc_bottom = 0.0; + double acc_corner = 0.0; + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + // main 16×16 (b∈0..15, c∈0..15) + { + const int gb = tmod16; + const int gkp = 4*ks + tdiv16; + const int gc = tmod16; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_main = mfma_16x16x4_f64_shared(av, bv, acc_main); + } + // right strip 16×4 (b∈0..15, c∈16..19) + { + const int gb = 4*g + beta; + const int gkp = 4*ks + alpha; + const int gc = 16 + beta; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (b∈16..19, c∈0..15) + { + const int gb = 16 + beta; + const int gkp = 4*ks + alpha; + const int gc = 4*g + beta; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (b∈16..19, c∈16..19). Only group 0. + { + const int gb = 16 + beta; + const int gkp = 4*ks + alpha; + const int gc = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[a * K2 + gb * K + gkp] : 0.0; + const double bv = vg ? B[gkp * K + gc] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + // Writeback directly to HBM C.
+ #pragma unroll + for (int e = 0; e < 4; ++e) { + const int gb = tdiv16 + 4*e; + const int gc = tmod16; + c_ptr[a * K2 + gb * K + gc] = acc_main[e]; + } + { + const int gb = 4*g + alpha; + const int gc = 16 + beta; + c_ptr[a * K2 + gb * K + gc] = acc_right; + } + { + const int gb = 16 + alpha; + const int gc = 4*g + beta; + c_ptr[a * K2 + gb * K + gc] = acc_bottom; + } + if (g == 0) { + const int gb = 16 + alpha; + const int gc = 16 + beta; + c_ptr[a * K2 + gb * K + gc] = acc_corner; + } + } +} + +// ============================================================================ +// 2-block-per-tensor SPLIT variant — padded LDS + B_lds +// +// Profile of the 1-block hybrid revealed 23-way LDS bank conflicts from the +// stride-K=20 access pattern (GCD(20, 32)=4). Adding a K+1 pad fixes the +// conflict but overflows 64 KB in a single block. Splitting the k'-axis +// across 2 blocks (each handles K/2=10 slabs) buys the padded layout *and* +// a B_lds cache: +// +// buf_padded : 10 · K · BK doubles = 10 · 20 · 21 = 4200 (33.6 KB) +// B_lds : K · K doubles = 400 ( 3.2 KB) +// total ( 36.8 KB) +// +// LDS layout: buf[kp_local · K · BK + i · BK + j], BK = 21. +// stride over i = BK (GCD(21,32)=1) -> conflict-free +// stride over j = 1 -> conflict-free +// stride over kp_local = K·BK=420 -> 8-way conflict (pass 3 contraction) +// +// Grid: (nfuncs, 2). Block 0 handles kp ∈ [0, 10); block 1 handles [10, 20). +// Each block computes a partial-sum over its k' range and atomic-adds to C. +// Caller must zero C before launching (submit helper does hipMemsetAsync). +// +// Pass 3 inner contraction has K_inner = 10 = 2 full 4-wide MFMA k-slices + 1 +// partial (α∈{0,1} valid, α∈{2,3} fed zero). Handled by a per-lane guard.
+// ============================================================================ + +template +__global__ +__launch_bounds__(640, 2) +void transform_kernel_blocked_k20_split(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 split: FP64 only"); + + constexpr int K = 20; + constexpr int K_HALF = K / 2; // 10 slabs per block + constexpr int BK = K + 1; // 21 - bank-conflict pad on j + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int K_SLICES_FULL = K / 4; // 5 (pass 1 & 2: full K) + // Pass 3 inner is only K_HALF=10 per block → 3 slices, last one partial. + constexpr int K_SLICES_P3 = (K_HALF + 3) / 4; // 3 + constexpr int WAVE_STRIDE = K * BK; // 20 * 21 = 420 + + extern __shared__ unsigned char _smem_split[]; + T* buf = reinterpret_cast(_smem_split); + T* B_lds = buf + K_HALF * WAVE_STRIDE; // after padded buf + + const int w = threadIdx.y; // 0..9 kp_local this wave owns + const int t = threadIdx.x; // 0..63 + const int tid = w * 64 + t; + const int nthr = blockDim.x * blockDim.y; // 640 + + // 4×4×4 lane decomposition + const int beta = t & 3; + const int g = (t >> 2) & 3; + const int alpha = (t >> 4) & 3; + // 16×16×4 lane decomposition + const int tmod16 = t & 15; + const int tdiv16 = t >> 4; + + const int cube = blockIdx.x; + const int block_half = blockIdx.y; // 0 or 1 + if (cube >= nfuncs) return; + const int kp_start = block_half * K_HALF; // 0 or 10 + + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Cache B into B_lds --- + for (int idx = tid; idx < K2; idx += nthr) B_lds[idx] = B[idx]; + + // --- Distribute: load A[:, :, kp ∈ block's range] into buf --- + // HBM layout: A[i*K² + j*K + kp]. LDS: buf[kp_local * WAVE_STRIDE + i * BK + j]. 
+ for (int idx = tid; idx < K_HALF * K2; idx += nthr) { + int kp_local = idx / K2; + int i = (idx / K) % K; + int j = idx % K; + int kp_g = kp_start + kp_local; + buf[kp_local * WAVE_STRIDE + i * BK + j] = a_ptr[i * K2 + j * K + kp_g]; + } + __syncthreads(); + + // ============ PASS 1 : D[a, j] = Σ_i B[i, a] · buf[i, j, kp] ============ + { + v4f64 acc_main[1]; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + acc_main[0] = v4f64{0.0, 0.0, 0.0, 0.0}; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_FULL; ++ks) { + // main 16×16 (a∈0..15, j∈0..15) + { + const int gi = 4*ks + tdiv16; + const int ga = tmod16; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + ga]; + acc_main[0] = mfma_16x16x4_f64_shared(av, bv, acc_main[0]); + } + // right strip 16×4 (a∈0..15, j∈16..19) + { + const int gi = 4*ks + alpha; + const int ga = 4*g + beta; + const int gj = 16 + beta; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + gj]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (a∈16..19, j∈0..15) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 4*g + beta; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + gj]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (a∈16..19, j∈16..19) group 0 only + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? B_lds[gi * K + ga] : 0.0; + const double bv = vg ? 
buf[w * WAVE_STRIDE + gi * BK + gj] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + __syncthreads(); + + // Writeback pass 1 → buf (same slab, reinterpret as (a, j, kp_local)) + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gj = tmod16; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_main[0][e]; + } + { + const int ga = 4*g + alpha; + const int gj = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_right; + } + { + const int ga = 16 + alpha; + const int gj = 4*g + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_bottom; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gj = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_corner; + } + } + __syncthreads(); + + // ============ PASS 2 : D[a, b] = Σ_j temp1[a, j, kp] · B[j, b] ============ + { + v4f64 acc_main[1]; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + acc_main[0] = v4f64{0.0, 0.0, 0.0, 0.0}; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_FULL; ++ks) { + // main 16×16 (a∈0..15, b∈0..15) + { + const int ga = tmod16; + const int gj = 4*ks + tdiv16; + const int gb = tmod16; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_main[0] = mfma_16x16x4_f64_shared(av, bv, acc_main[0]); + } + // right strip 16×4 (a∈0..15, b∈16..19) + { + const int ga = 4*g + beta; + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (a∈16..19, b∈0..15) + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 4*g + beta; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (a∈16..19, b∈16..19) group 0 only + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + 
const int gb = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[w * WAVE_STRIDE + ga * BK + gj] : 0.0; + const double bv = vg ? B_lds[gj * K + gb] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + __syncthreads(); + + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gb = tmod16; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_main[0][e]; + } + { + const int ga = 4*g + alpha; + const int gb = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_right; + } + { + const int ga = 16 + alpha; + const int gb = 4*g + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_bottom; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gb = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_corner; + } + } + __syncthreads(); + + // ============ PASS 3 : partial[a, b, c] = Σ_{kp_local ∈ [0, K_HALF)} ============ + // buf[kp_local, a, b] · B[kp_start+kp_local, c] + // Output atomic-added to HBM C (caller pre-zeros C). + // + // Wave w owns 2 a-values: a ∈ {w, w + K_HALF} = {w, w+10}. + // K_inner = K_HALF = 10 → 3 MFMA k-slices (last partial, 2 of 4 valid). + #pragma unroll + for (int ss = 0; ss < 2; ++ss) { + const int a = w + ss * K_HALF; // wave's owned a-row + + v4f64 acc_main = v4f64{0.0, 0.0, 0.0, 0.0}; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_P3; ++ks) { + // For 16×16×4 MFMA, the K=4 inner lanes are the ones with tdiv16 ∈ [0, 4). + // gkp = 4*ks + tdiv16. Valid iff gkp < K_HALF. 
+ const int gkp_main = 4*ks + tdiv16; + const bool v_main = (gkp_main < K_HALF); + + const int gkp_strip = 4*ks + alpha; + const bool v_strip = (gkp_strip < K_HALF); + + // main 16×16 (b∈0..15, c∈0..15) + // A_mfma[m=b][k=kp_local] = temp2[a_fixed, b, kp_local] + // = buf[kp_local, a_fixed, b] + // Lane t's (m, k) = (t%16, t/16) → b = t%16, kp_local offset = t/16 + { + const int gb = tmod16; + const int gc = tmod16; + const int kpg = kp_start + gkp_main; + const double av = v_main ? buf[gkp_main * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_main ? B_lds[kpg * K + gc] : 0.0; + acc_main = mfma_16x16x4_f64_shared(av, bv, acc_main); + } + // right strip 16×4 (b∈0..15, c∈16..19) + { + const int gb = 4*g + beta; + const int gc = 16 + beta; + const int kpg = kp_start + gkp_strip; + const double av = v_strip ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_strip ? B_lds[kpg * K + gc] : 0.0; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (b∈16..19, c∈0..15) + { + const int gb = 16 + beta; + const int gc = 4*g + beta; + const int kpg = kp_start + gkp_strip; + const double av = v_strip ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_strip ? B_lds[kpg * K + gc] : 0.0; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (b∈16..19, c∈16..19) + { + const int gb = 16 + beta; + const int gc = 16 + beta; + const int kpg = kp_start + gkp_strip; + const bool vg = (g == 0) && v_strip; + const double av = vg ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = vg ? 
B_lds[kpg * K + gc] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + // --- atomic add partial to HBM C --- + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int gb = tdiv16 + 4*e; + const int gc = tmod16; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], (double)acc_main[e]); + } + { + const int gb = 4*g + alpha; + const int gc = 16 + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_right); + } + { + const int gb = 16 + alpha; + const int gc = 4*g + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_bottom); + } + if (g == 0) { + const int gb = 16 + alpha; + const int gc = 16 + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_corner); + } + } +} + +// ============================================================================ +// Dispatch helpers +// ============================================================================ + +template +inline size_type blocked_k20_shmem_size() { + return (size_type)(20 * 20 * 20 * sizeof(T)); // 64 KB at FP64 (1-block variants) +} + +template +inline size_type blocked_k20_split_shmem_size() { + // buf[K_HALF · K · BK] + B_lds[K · K] = 10·20·21 + 20·20 = 4600 doubles + return (size_type)((10 * 20 * 21 + 20 * 20) * sizeof(T)); // 36.8 KB +} + +template +inline Dim3 blocked_k20_blockdim() { + return Dim3(64, 10, 1); +} + +// Default path for K=20: hybrid 16×16×4 + 4×4×4 MFMA (1 block per tensor). +// +// The 2-block _split variant (with padded LDS + B_lds) was implemented and +// measured to be ~1.9× SLOWER at N=2048 despite eliminating bank conflicts +// (23.9 → 0.2-way). The overheads that killed it: +// * atomicAdd read-modify-write on C (HBM traffic 258 → 752 MB) +// * pre-zero hipMemsetAsync on C (+128 MB) +// * extra VALU for 420-stride indexing + partial-k-slice guards (+2.6×) +// Net: MFMA busy fell 11 → 6 % because the extra VALU starved them further. +// Conclusion: bank conflicts were a visible counter but not the critical path. 
+// +// Alternative kernels kept for reference / future experimentation: +// transform_kernel_blocked_k20_split (2 blocks + padded LDS + atomics) +// transform_kernel_blocked_k20_mfma (pure 4×4×4) +// transform_kernel_blocked_k20_scalar (correctness baseline) +template +inline void submit_transform_bench_blocked_k20(int nfuncs, + const T* A, const T* B, T* C, T* workspace, + Stream stream) +{ + Dim3 td = blocked_k20_blockdim(); + size_type smem = blocked_k20_shmem_size(); + CONFIGURE_KERNEL((transform_kernel_blocked_k20_hybrid), smem); + CALL_KERNEL((transform_kernel_blocked_k20_hybrid), nfuncs, td, smem, stream, + (nfuncs, A, B, C, workspace)); +} + +#endif // __HIP__ diff --git a/transform_blocked_rocwmma.h b/transform_blocked_rocwmma.h new file mode 100644 index 0000000..f65ee26 --- /dev/null +++ b/transform_blocked_rocwmma.h @@ -0,0 +1,196 @@ +#pragma once +#include +#include "util.h" + +// HIP/AMD-only — uses rocWMMA fragments / mma_sync. +#if defined(__HIP__) + +#if defined(__HIP_DEVICE_COMPILE__) +#include +#endif + +// transform_blocked_rocwmma.h — block-distributed 3D transform, rocWMMA variant. +// +// Mirrors the algorithm in transform_blocked.h (L7) but replaces the manual +// MFMA calls with rocwmma::mma_sync. rocWMMA hides the MFMA fragment layouts +// and emits the same underlying v_mfma_f64_16x16x4f64 on GFX90A. +// +// Layout differences from L7: +// * buf uses the per-wave layout buf[s*WB + i*BK + j] (not L7's canonical +// (i, j, k) form) so rocWMMA's standard row-major loads work directly +// with ldm = BK. +// * Pass 1 → pass 2 is fused in registers: acc1's per-lane values are copied +// directly into pass 2's matrix_a fragment, so pass 1's output skips LDS. +// +// Thread block: K waves × 64 threads = 1024 at K=16. One block per tensor. +// LDS budget: buf (K × K × (K+1) padded) + B cache (K × K) = 36 KB at K=16. + +#if defined(__HIP__) && defined(__HIP_DEVICE_COMPILE__) +namespace rw = rocwmma; + +// Fragment types for f64 16×16×4 MFMA on GFX90A. 
+using FragA = rw::fragment; +using FragAT = rw::fragment; // used to load B as B^T +using FragB = rw::fragment; +using FragAcc = rw::fragment; +#endif + +template +__global__ +__launch_bounds__(K * 64, 1) +void transform_kernel_blocked_rocwmma(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*workspace unused*/) +{ +#if defined(__HIP__) && defined(__HIP_DEVICE_COMPILE__) + static_assert(std::is_same::value, "rocWMMA path is FP64 only"); + static_assert(K == 16, "rocWMMA path: K=16 only for now"); + + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int ELEMS_PER_LANE = K2 / 64; + constexpr int NMFMA = K / 4; + + // Bank-conflict pad identical to L7. + constexpr int BK = K + 1; + constexpr int WB = K * BK; + + extern __shared__ unsigned char _smem_rwmma[]; + T* buf = reinterpret_cast(_smem_rwmma); + T* B_lds = buf + K * WB; + + const int s = threadIdx.y; // wave index (0..K-1) + const int t = threadIdx.x; // lane 0..63 + const int tid = s * 64 + t; + + // --- Cache B into LDS (wide double4 loads, like L7) --- + static_assert(K2 % 4 == 0); + if (tid < K2 / 4) { + const double4* B_vec = reinterpret_cast(B); + double4* B_lds_vec = reinterpret_cast(B_lds); + B_lds_vec[tid] = B_vec[tid]; + } + __syncthreads(); + + // One block = one tensor. + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Distribute: coalesced HBM load into per-wave layout --- + // Each thread loads 4 contiguous HBM doubles as a double4, placing them + // across 4 WAVES (since consecutive HBM elements have consecutive k, and + // wave s owns k=s). Per-wave layout: buf[s*WB + i*BK + j] = A[i, j, s]. 
+ { + static_assert(K == 16); + const double4* a_ptr_vec = reinterpret_cast(a_ptr); + const int base_idx = 4 * tid; + const int i = base_idx >> 8; + const int j = (base_idx >> 4) & 15; + const int k_start = base_idx & 15; // 0, 4, 8, or 12 + + double4 v = a_ptr_vec[tid]; + // 4 cross-wave stores (destination wave = k_start + e for e=0..3). + buf[(k_start + 0) * WB + i * BK + j] = v.x; + buf[(k_start + 1) * WB + i * BK + j] = v.y; + buf[(k_start + 2) * WB + i * BK + j] = v.z; + buf[(k_start + 3) * WB + i * BK + j] = v.w; + } + __syncthreads(); + + // --- Pass 1 (SWAPPED operands): compute blk^T · B --- + // A_frag = blk^T: col_major load of the per-wave blk region. + // B_frag = B: row_major load. + // Output D[j][a] = temp1[a][j]. In the fragment, thread t's acc1.x[e] = + // temp1[t%16][(t/16) + 4e] = temp1[t%16][p*4 + t/16] (commutative). + // That is EXACTLY the value pass 2's matrix_a fragment wants at iter p=e, + // so we can feed acc1 straight into pass 2 without any LDS round-trip. + FragAcc acc1; + rw::fill_fragment(acc1, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragAT a_frag; // col_major → blk^T + FragB b_frag; // B + rw::load_matrix_sync(a_frag, &buf[s * WB + p * 4 * BK], BK); + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc1, a_frag, b_frag, acc1); + } + // No LDS writeback. acc1.x[p] is pass 2's a_frag value at iter p. + + // --- Pass 2: temp2 = temp1 · B (consumes acc1 directly) --- + FragAcc acc2; + rw::fill_fragment(acc2, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragA a_frag_p; + // Populate matrix_a's per-lane element from acc1's p-th acc value. + a_frag_p.x[0] = acc1.x[p]; + FragB b_frag; + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc2, a_frag_p, b_frag, acc2); + } + // acc2 holds pass 2 output (temp2) in the same per-lane layout as L7's + // acc2 — so acc2.x[e] IS the corner-turn stash for iter e. 
+ + // --- Corner turn: direct cross-wave write from acc2 (no stash read) --- + __syncthreads(); // all waves done with pass-1 reads from own region + #pragma unroll + for (int e = 0; e < ELEMS_PER_LANE; ++e) { + int idx = t + e * 64; + int a_ix = idx / K; + int b_ix = idx % K; + buf[a_ix * WB + b_ix * BK + s] = acc2.x[e]; + } + __syncthreads(); + + // --- Pass 3: result = temp2_reshuffled · B --- + // Wave s now represents a=s. buf[s*WB + b*BK + k'] = temp2[a=s, b, k']. + FragAcc acc3; + rw::fill_fragment(acc3, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragA a_frag; + FragB b_frag; + rw::load_matrix_sync(a_frag, &buf[s * WB + p * 4], BK); + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc3, a_frag, b_frag, acc3); + } + + // --- Store to HBM --- + // Canonical C[a*K² + b*K + c] with a=s. + rw::store_matrix_sync(&c_ptr[(size_t)s * K2], acc3, K, rw::mem_row_major); +#endif // __HIP__ && __HIP_DEVICE_COMPILE__ +} + +template +inline size_type blocked_rocwmma_shmem_size(int K) { + return (size_type)((K * K * (K + 1) + K * K) * sizeof(T)); +} + +template +inline Dim3 blocked_rocwmma_blockdim(int K) { + return Dim3(64, K, 1); +} + +template +inline void submit_transform_bench_blocked_rocwmma(int nfuncs, int /*nblocks*/, int K, + const T* A, const T* B, T* C, T* workspace, + Stream stream) +{ + if (K == 16) { + constexpr int Kv = 16; + Dim3 td = blocked_rocwmma_blockdim(Kv); + size_type smem = blocked_rocwmma_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_blocked_rocwmma), smem); + CALL_KERNEL((transform_kernel_blocked_rocwmma), nfuncs, td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else { + fprintf(stderr, "blocked rocwmma transform: K=%d not supported (K=16 only)\n", K); + assert(false); + } +} + +#endif // __HIP__ diff --git a/transformbench.cu b/transformbench.cu index 81854e0..d1cc222 100644 --- a/transformbench.cu +++ b/transformbench.cu @@ -12,16 +12,23 @@ #include "transform_level7.h" #include 
"transform_kron.h" #include "mxm_cublasdx.h" +#include "transform_blocked.h" /* L9, L11 (HIP-only, see __HIP__ guards) */ +#include "transform_blocked_rocwmma.h" /* L10 (HIP-only) */ #include "util.h" /** * Optimization levels: - * 1 - L1: thread-parallel over j, serial k-loop, all global memory (mxm.h fallback) - * 2 - L2: B in LDS, threads distributed over rows - * 3 - L3: B in LDS + register accumulation (acc[K] in VGPRs) - * 4 - L4: AMD MFMA (GFX90A/GFX940) for K=16,32; falls back to L3 elsewhere - * 5 - L5: cuBLASDx (NVIDIA only, double-buffered block GEMM with Tensor Cores) - * 6 - L6: Single GEMM via K³×K³ Kronecker product (B^T ⊗ B^T ⊗ B^T) + * 1 - L1: thread-parallel over j, serial k-loop, all global memory (mxm.h fallback) + * 2 - L2: B in LDS, threads distributed over rows + * 3 - L3: B in LDS + register accumulation (acc[K] in VGPRs) + * 4 - L4: AMD MFMA (GFX90A/GFX940) for K=16,32; falls back to L3 elsewhere + * 5 - L5: cuBLASDx (NVIDIA only, double-buffered block GEMM with Tensor Cores) + * 6 - L6: rocWMMA (HIP) + * 7 - L7: multi-wave MFMA, B resident in VGPRs (HIP) + * 8 - L8: Single GEMM via K³×K³ Kronecker product (B^T ⊗ B^T ⊗ B^T) + * 9 - L9: block-distributed MFMA transform, K=16 (HIP) -- "blocked" + * 10 - L10: block-distributed rocWMMA transform, K=16 (HIP) -- "blocked-rocwmma" + * 11 - L11: block-distributed K=20 hybrid MFMA (4x4x4 + 16x16x4) (HIP) -- "blocked-k20" */ template @@ -51,15 +58,18 @@ void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, int } const char* level_names[] = { - "", /* unused [0] */ - "L1-global", /* 1 */ - "L2-lds_b", /* 2 */ - "L3-regblk", /* 3 */ - "L4-mfma", /* 4 */ - "L5",/* 5 */ - "L6-rocwmma",/* 6 */ - "L7-builtins",/* 7 */ - "L8-kron" /* 8 */ + "", /* unused [0] */ + "L1-global", /* 1 */ + "L2-lds_b", /* 2 */ + "L3-regblk", /* 3 */ + "L4-mfma", /* 4 */ + "L5", /* 5 */ + "L6-rocwmma", /* 6 */ + "L7-builtins", /* 7 */ + "L8-kron", /* 8 */ + "L9-blocked", /* 9 */ + "L10-blocked-rocwmma", /* 
10 */ + "L11-blocked-k20", /* 11 */ }; /* Print shmem and thread dims for this level */ @@ -98,6 +108,20 @@ void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, int smem_size = kron_shmem_size(K); thread_dims = kron_blockdim(K); break; +#if defined(__HIP__) + case 9: + smem_size = (int)blocked_shmem_size(K); + thread_dims = blocked_blockdim(K); + break; + case 10: + smem_size = (int)blocked_rocwmma_shmem_size(K); + thread_dims = blocked_rocwmma_blockdim(K); + break; + case 11: + smem_size = (int)blocked_k20_shmem_size(); + thread_dims = blocked_k20_blockdim(); + break; +#endif // __HIP__ } /* Level 8: build Kronecker matrix once, before the timing loop */ @@ -144,6 +168,17 @@ void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, int case 8: submit_transform_kron_bench(nfuncs, K, A, KronMat, C, blas_handle, streams[t%num_streams]); break; +#if defined(__HIP__) + case 9: + submit_transform_bench_blocked(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); + break; + case 10: + submit_transform_bench_blocked_rocwmma(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); + break; + case 11: + submit_transform_bench_blocked_k20(nfuncs, A, B, C, workspace, streams[t%num_streams]); + break; +#endif // __HIP__ } } for (int t = 0; t < num_streams; ++t) { @@ -193,7 +228,7 @@ int main(int argc, char **argv) { int N = opt.parse("-N", 2048); /* number of functions */ int K = opt.parse("-K", 16); /* number of coefficients */ int M = opt.parse("-M", 512); /* max number of blocks */ - int level = opt.parse("-l", 0); /* 0 = auto, 1-5 = explicit */ + int level = opt.parse("-l", 0); /* 0 = auto, 1-11 = explicit (9-11 are HIP-only blocked variants) */ int num_streams = opt.parse("-s", 4); /* number of concurrent streams to use */ /* Legacy -m flag: force level 1 */ diff --git a/validate.hip b/validate.hip new file mode 100644 index 0000000..c17cb39 --- /dev/null +++ b/validate.hip @@ -0,0 +1,263 @@ +/** + * 
validate.hip — compare GPU L1 output and the CPU block-distributed (corner-turn) variant against the CPU reference transform. + * + * CPU reference mirrors madness/src/madness/tensor/transform3d.cc :: transform3d(): + * + * result(a,b,c) = sum(i',j',k') A(i',j',k') B(i',a) B(j',b) B(k',c) + * + * implemented as three sequential mTxm passes: + * mTxm(K², K, K, r, A, B) pass 1: contract dim-0 + * mTxm(K², K, K, tmp, r, B) pass 2: contract dim-1 + * mTxm(K², K, K, r, tmp, B) pass 3: contract dim-2 + * + * where mTxm(dimi, dimj, dimk, c, a, b): + * c(i,j) += sum(k) a(k,i) * b(k,j) + */ + +#include +#include +#include +#include +#include + +#include "transform.h" +#include "util.h" + +// --------------------------------------------------------------------------- +// CPU reference +// --------------------------------------------------------------------------- + +// c(i,j) += sum(k) a(k,i) * b(k,j) (from MADNESS mTxm) +static void cpu_mTxm(int dimi, int dimj, int dimk, + double* c, const double* a, const double* b) { + for (int i = 0; i < dimi; ++i) { + double* ci = c + i * dimj; + const double* aik = a + i; // a(k,i) strides by dimi + for (int k = 0; k < dimk; ++k, aik += dimi) { + double aki = *aik; + for (int j = 0; j < dimj; ++j) + ci[j] += aki * b[k * dimj + j]; + } + } +} + +// Three-pass transform matching transform3d.cc +static void cpu_transform3d(int K, const double* A, const double* B, double* result) { + int K3 = K * K * K; + std::vector tmp(K3, 0.0); + std::fill(result, result + K3, 0.0); + + cpu_mTxm(K * K, K, K, result, A, B); // pass 1 + cpu_mTxm(K * K, K, K, tmp.data(), result, B); // pass 2 + std::fill(result, result + K3, 0.0); + cpu_mTxm(K * K, K, K, result, tmp.data(), B); // pass 3 +} + +// --------------------------------------------------------------------------- +// Block-distributed transform (corner-turn variant). +// +// Block A along the fastest tensor index k': block s owns A[:,:,k'=s] as a +// K x K matrix indexed by (i', j'). +// +// Pass 1 (contract i'): local. blk_s <- B^T . 
blk_s -> [a, j'] +// Pass 2 (contract j'): local. blk_s <- blk_s . B -> [a, b] +// Pass 3 (contract k'): all-to-all "corner turn" where block t pulls row +// a=t from every other block, storing into the +// (b, k') layout. Then local right-multiply: +// blk_t <- blk_t . B -> [b, c] +// +// Applying B on the right in pass 3 makes the output land in canonical (b, c) +// order directly; the in-block transpose is absorbed into the corner turn for +// free (same number of loads/stores, different destination cell). +// --------------------------------------------------------------------------- + +static void print_blocks(const char* label, int K, + const std::vector>& blk) { + std::cout << " " << label << ":\n"; + for (int s = 0; s < (int)blk.size(); ++s) { + std::cout << " block " << s << ":\n"; + for (int i = 0; i < K; ++i) { + std::cout << " "; + for (int j = 0; j < K; ++j) + std::cout << " " << std::setw(8) << blk[s][i*K + j]; + std::cout << "\n"; + } + } +} + +static void cpu_transform3d_blocked(int K, const double* A, const double* B, + double* out, bool trace = false) { + const int K2 = K * K; + std::vector> blk(K, std::vector(K2)); + + // distribute: block s := A[:,:,k=s] (strided read from the flat layout) + for (int s = 0; s < K; ++s) + for (int i = 0; i < K; ++i) + for (int j = 0; j < K; ++j) + blk[s][i*K + j] = A[i*K2 + j*K + s]; + if (trace) print_blocks("after distribute (block_s[i,j] = A[i,j,s])", K, blk); + + // Pass 1: contract i'. blk_s <- B^T . blk_s -> indexed [a, j] + for (int s = 0; s < K; ++s) { + std::vector t(K2, 0.0); + for (int a = 0; a < K; ++a) + for (int j = 0; j < K; ++j) + for (int i = 0; i < K; ++i) + t[a*K + j] += B[i*K + a] * blk[s][i*K + j]; + blk[s] = std::move(t); + } + if (trace) print_blocks("after pass 1 (block_s[a,j])", K, blk); + + // Pass 2: contract j'. blk_s <- blk_s . 
B -> indexed [a, b] + for (int s = 0; s < K; ++s) { + std::vector t(K2, 0.0); + for (int a = 0; a < K; ++a) + for (int b = 0; b < K; ++b) + for (int j = 0; j < K; ++j) + t[a*K + b] += blk[s][a*K + j] * B[j*K + b]; + blk[s] = std::move(t); + } + if (trace) print_blocks("after pass 2 (block_s[a,b], fixed k'=s)", K, blk); + + // Corner turn: block t pulls row a=t from every block, stored as (b, k'). + // Before: blk[s][a, b] (block-index = k' = s) + // After : blk[t][b, k] (block-index = a = t) + { + std::vector> exch(K, std::vector(K2)); + for (int t = 0; t < K; ++t) + for (int b = 0; b < K; ++b) + for (int k = 0; k < K; ++k) + exch[t][b*K + k] = blk[k][t*K + b]; + blk = std::move(exch); + } + if (trace) print_blocks("after corner turn (block_t[b,k], fixed a=t)", K, blk); + + // Pass 3: contract k'. blk_t <- blk_t . B -> indexed [b, c] + for (int t = 0; t < K; ++t) { + std::vector r(K2, 0.0); + for (int b = 0; b < K; ++b) + for (int c = 0; c < K; ++c) + for (int k = 0; k < K; ++k) + r[b*K + c] += blk[t][b*K + k] * B[k*K + c]; + blk[t] = std::move(r); + } + if (trace) print_blocks("after pass 3 (block_t[b,c] = result[t,b,c])", K, blk); + + // Canonical write-out: block_a holds result[a, :, :] in (b, c) order. + for (int a = 0; a < K; ++a) + for (int b = 0; b < K; ++b) + for (int c = 0; c < K; ++c) + out[a*K2 + b*K + c] = blk[a][b*K + c]; +} + +// K=2 walkthrough with the counting tensor A[i,j,k] = 100i + 10j + k and B = I. +// With B = I the GEMMs become no-ops, so every stage just shows the current +// distribution of elements across blocks -- the shuffle pattern is laid bare. +// Expected final output == A (since B = I acts as identity in all three passes). 
+static void debug_blocked_K2() { + const int K = 2, K2 = 4, K3 = 8; + + std::vector A(K3); + for (int i = 0; i < K; ++i) + for (int j = 0; j < K; ++j) + for (int k = 0; k < K; ++k) + A[i*K2 + j*K + k] = 100.0*i + 10.0*j + k; + + std::vector B(K2, 0.0); + for (int i = 0; i < K; ++i) B[i*K + i] = 1.0; + + std::cout << "=== debug_blocked_K2 (A[i,j,k] = 100i+10j+k, B = I) ===\n"; + std::cout << " B = I makes every GEMM a no-op; the only motion is the shuffle.\n"; + std::cout << " expected final output == A.\n"; + + std::vector out(K3, 0.0); + cpu_transform3d_blocked(K, A.data(), B.data(), out.data(), /*trace=*/true); + + double max_err = 0.0; + for (int i = 0; i < K3; ++i) + max_err = std::max(max_err, std::abs(out[i] - A[i])); + std::cout << " final max_abs_err vs A = " << max_err + << (max_err < 1e-10 ? " MATCH" : " DIFFER") << "\n\n"; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char** argv) { + auto opt = OptionParser(argc, argv); + int K = opt.parse("-K", 16); + int N = opt.parse("-N", 8); + int debug = opt.parse("-debug", 0); + + if (debug) { + debug_blocked_K2(); + return 0; + } + + const int K3 = K * K * K; + + // Deterministic input + std::vector h_A(N * K3); + std::vector h_B(K * K); + for (int i = 0; i < N * K3; ++i) h_A[i] = (double)(i % 13) * 0.1 - 0.6; + for (int i = 0; i < K * K; ++i) h_B[i] = (double)(i % 7) * 0.1 - 0.3; + + // CPU reference (3D tensor transform) + std::vector h_ref(N * K3); + for (int n = 0; n < N; ++n) + cpu_transform3d(K, h_A.data() + n * K3, h_B.data(), h_ref.data() + n * K3); + + // Blocked algorithm: per-tensor distribute along k', corner-turn for pass 3 + std::vector h_blk(N * K3, 0.0); + for (int n = 0; n < N; ++n) + cpu_transform3d_blocked(K, h_A.data() + n * K3, h_B.data(), h_blk.data() + n * K3); + + double blk_max_err = 0.0; + for (int i = 0; i < N * K3; ++i) + blk_max_err = 
std::max(blk_max_err, std::abs(h_blk[i] - h_ref[i])); + std::cout << "K=" << K << " N=" << N + << " blk_vs_ref max_abs_err=" << blk_max_err + << (blk_max_err < 1e-10 ? " MATCH" : " DIFFER") + << std::endl; + + // GPU L1 + double *d_A, *d_B, *d_C, *d_workspace; + MALLOC(&d_A, N * K3 * sizeof(double)); + MALLOC(&d_B, K * K * sizeof(double)); + MALLOC(&d_C, N * K3 * sizeof(double)); + MALLOC(&d_workspace, K3 * sizeof(double)); // 1 block + + MEMCPY_H2D(d_A, h_A.data(), N * K3 * sizeof(double)); + MEMCPY_H2D(d_B, h_B.data(), K * K * sizeof(double)); + + Stream stream; + CREATE_STREAM(&stream); + submit_transform_bench(N, /*nblocks=*/1, K, d_A, d_B, d_C, d_workspace, stream); + SYNC_STREAM(stream); + + std::vector h_gpu(N * K3); + MEMCPY_D2H(h_gpu.data(), d_C, N * K3 * sizeof(double)); + + FREE(d_A); FREE(d_B); FREE(d_C); FREE(d_workspace); + + // Compare + double max_abs_err = 0.0, max_ref_val = 0.0; + for (int i = 0; i < N * K3; ++i) { + double err = std::abs(h_gpu[i] - h_ref[i]); + double ref = std::abs(h_ref[i]); + if (err > max_abs_err) max_abs_err = err; + if (ref > max_ref_val) max_ref_val = ref; + } + double rel_err = (max_ref_val > 0.0) ? max_abs_err / max_ref_val : max_abs_err; + + bool pass = rel_err < 1e-10; + std::cout << "K=" << K + << " N=" << N + << " max_abs_err=" << max_abs_err + << " max_rel_err=" << rel_err + << (pass ? " PASS" : " FAIL") + << std::endl; + + return pass ? 
0 : 1; +} diff --git a/validate_levels.hip b/validate_levels.hip index de93f3e..21e260a 100644 --- a/validate_levels.hip +++ b/validate_levels.hip @@ -4,13 +4,17 @@ * Usage: * ./validate_levels [-l ] [-K ] [-N ] * - * -l level to validate (2-7, default 3) - * 2 L2: B cached in LDS - * 3 L3: register blocking (K-templated) - * 4 L4: AMD MFMA + L3 fallback - * 5 L5: wave-specialised double-buffering MFMA (HIP only) - * 6 L6: Kronecker product GEMM (hipBLAS) - * 7 L7: multi-wave MFMA, B resident in VGPRs (HIP only) + * -l level to validate (2-11, default 3) + * 2 L2: B cached in LDS + * 3 L3: register blocking (K-templated) + * 4 L4: AMD MFMA + L3 fallback + * 5 L5: wave-specialised double-buffering MFMA (HIP only) + * 6 L6: rocWMMA + * 7 L7: multi-wave MFMA, B resident in VGPRs (HIP only) + * 8 L8: Kronecker product GEMM (hipBLAS) + * 9 L9: block-distributed MFMA, K=16 (HIP only) + * 10 L10: block-distributed rocWMMA, K=16 (HIP only) + * 11 L11: block-distributed K=20 hybrid MFMA (HIP only) * -K single K value; if omitted sweeps K in {4,6,8,10} * -N batch size (default 16) */ @@ -32,7 +36,9 @@ # include "transform_cublasdx.h" // L5 — cuBLASDx (NVIDIA only) #endif #include "transform_rocwmma.h" -#include "transform_kron.h" // L6 +#include "transform_kron.h" // L8 +#include "transform_blocked.h" // L9, L11 (HIP-only) +#include "transform_blocked_rocwmma.h" // L10 (HIP-only) template void test_level(int level, int K, int nfuncs) { @@ -107,8 +113,22 @@ void test_level(int level, int K, int nfuncs) { blasDestroy(blas_handle); break; } +#if defined(__HIP__) + case 9: + submit_transform_bench_blocked(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); + SYNC_STREAM(stream); + break; + case 10: + submit_transform_bench_blocked_rocwmma(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); + SYNC_STREAM(stream); + break; + case 11: + submit_transform_bench_blocked_k20(nfuncs, d_A, d_B, d_Ctest, d_workspace, stream); + SYNC_STREAM(stream); + break; +#endif // 
__HIP__ default: - std::cerr << "Unknown level " << level << " (valid: 2-7)\n"; + std::cerr << "Unknown level " << level << " (valid: 2-11)\n"; FREE(d_A); FREE(d_B); FREE(d_Cref); FREE(d_Ctest); FREE(d_workspace_ref); FREE(d_workspace); if (d_KronMat) FREE(d_KronMat);