diff --git a/.clangd b/.clangd index 77e3d1a..a27f84d 100644 --- a/.clangd +++ b/.clangd @@ -1,7 +1,12 @@ CompileFlags: - # compile_commands.json already has all flags; nothing to add/remove Remove: - - -O3 # clangd doesn't need optimization; speeds up indexing + - -O3 # clangd doesn't need optimization; speeds up indexing + - --offload-arch=* # offload flags confuse clangd's host-side indexer + Add: + - -x + - hip # ensure .hip/.h files are parsed as HIP (C++) + - -DMRA_HAVE_HIP=1 + - -I/opt/rocm-7.2.2/include Index: Background: Build @@ -10,3 +15,4 @@ Diagnostics: Suppress: - pp_including_mainfile_in_preamble - unknown_builtin + - err_implicit_function_declaration # HIP device builtins not visible to host parser diff --git a/.gitignore b/.gitignore index 7f3e00b..67a5fbe 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,4 @@ clangd* compile_commands.json rocroof/ +prof/ diff --git a/BLOCKED_TRANSFORM.md b/BLOCKED_TRANSFORM.md new file mode 100644 index 0000000..c4dbf21 --- /dev/null +++ b/BLOCKED_TRANSFORM.md @@ -0,0 +1,508 @@ +# Blocked 3D Transform — Corner-Turn Algorithm + +A data-movement-conscious restructuring of the MRA 3D transform, designed so +each of K thread-block-equivalents ("wavefronts" on AMD) owns one K×K slab of +the tensor throughout the computation. Passes 1 and 2 are fully local per +block; pass 3 requires a single all-to-all "corner turn" that transposes the +block-level ownership from `k'` to `a`. Pass 3 right-multiplies by B so the +output lands in canonical (b, c) order with no post-transpose. + +Reference implementations +- CPU (C++): `validate.hip` → `cpu_transform3d_blocked` +- NumPy: `mra_python/algorithms.py` → `transform_nd_blocked` + +Both match `cpu_transform3d` / `transform_nd` bit-exactly (CPU) or to +floating-point noise (NumPy). 
+ +--- + +## The canonical operation + +We compute the 3D transform + +``` +result[a, b, c] = Σ_{i, j, k} A[i, j, k] · B[i, a] · B[j, b] · B[k, c] +``` + +via three sequential contractions, one per tensor axis. Standard formulations +apply this as three K²×K GEMMs over the whole tensor, reloading the input +through HBM on every pass. The blocked algorithm keeps each K×K slab +resident in registers/LDS of a single wavefront for all three passes, paying +one inter-wave exchange instead of three HBM round-trips. + +--- + +## Visual language + +A consistent set of conventions used across every frame: + +- **Axis colors.** `i'/a` red, `j'/b` green, `k'/c` blue. +- **Blocks are K×K squares.** Row-bar and column-bar outside each square + name the axis currently playing that role. +- **Solid box = data, dashed box = operation.** +- **Arrows carry data.** + +--- + +## Frame 1 — canonical operation + +> ![Frame 1](frames/frame1_canonical.png) + +A K×K×K input cube, a K×K matrix `B`, and a K×K×K output cube, with three +curved arrows labeled "contract `i'`", "contract `j'`", "contract `k'`". +Provides grounding; this is what all later frames compute. + +--- + +## Frame 2 — distribute (slice along `k'`) + +> ![Frame 2](frames/frame2_distribute.png) + +The input cube is sliced along the fastest axis into K independent K×K +squares. Block `s` owns the slice `A[:, :, k'=s]`, with row = `i'` and +col = `j'`. + +This establishes the row-of-K-squares motif used in every subsequent frame. + +--- + +## Frame 3 — passes 1 and 2 (local GEMMs) + +> ![Frame 3](frames/frame3_local_gemms.png) + +Two stages stacked vertically on one frame: + +``` +row of K squares ──(× B^T on left)──► row of K squares [Pass 1] + rows: i', cols: j' rows: a, cols: j' + +row of K squares ──(× B on right)──► row of K squares [Pass 2] + rows: a, cols: j' rows: a, cols: b +``` + +Both passes are **local** — no wavefront touches another wavefront's data. +A small "✓ local" badge on each block emphasizes this. 
+ +--- + +## Frame 4 — the corner turn (the money slide) + +> ![Frame 4](frames/frame4_corner_turn.png) + +The critical data-movement step. Shown as two rows of K blocks with colored +rows flowing between them. + +**Before** (top): K blocks, each with rows colored red/green/blue/... +(= `a = 0, 1, 2, ...`). Every block has the same row-color pattern because +`a` is the within-block row index. + +**After** (bottom): K blocks, each entirely one color. Block 0 all red, +block 1 all green, block 2 all blue. `a` is now the **block** index. +**Each arriving row is stored as a column**, so within each block the row +index is `b` and the column index is `k'`. This orientation sets pass 3 +up to right-multiply by B and land in canonical order directly. + +The operation is a **block-level transpose of a K×K super-matrix** whose +cells are K-vectors indexed by `b`, with an in-block transpose rolled into +the receive address (free — same data movement, different destination cell). + +Callout on the frame: *"block-level transpose: block index ↔ within-block +row."* This is the single sentence the audience should leave with. + +The main visual shows only the flow of one color (e.g. red rows, all +destined for block `a=0`) to keep the arrow count manageable; an inset +shows the full K=2 picture with all four arrows. + +--- + +## Frame 5 — pass 3 (no un-shuffle needed) + +> ![Frame 5](frames/frame5_pass3.png) + +``` +row of K squares ──(× B on right)──► canonical result + rows: b, cols: k' rows: b, cols: c +``` + +Points to emphasize: + +- Block index throughout this frame is `a` (inherited from the corner turn). +- **Right-multiply by B contracts `k'` → `c` in the column slot**, so the + output is already in canonical (b, c) order. No per-block transpose on + store — it was absorbed into the corner turn. +- Passes 2 and 3 now share the same GEMM orientation (B on the right); + only pass 1 is the odd one out. One fewer MFMA microkernel variant to + maintain on GPU. 
+ +--- + +## Frame 6 — the bookkeeping table (takeaway slide) + +> ![Frame 6](frames/frame6_table.png) + +| stage | block idx | within-block row | within-block col | +|--------------------|-----------|------------------|------------------| +| distribute | k' | i' | j' | +| after pass 1 | k' | **a** | j' | +| after pass 2 | k' | a | **b** | +| after corner turn | **a** | **b** | **k'** | +| after pass 3 | a | b | **c** | + +Bold cells mark which slot changed on each row. Two annotations: + +- Rows 2–3 (passes 1–2) and row 5 (pass 3): *"local GEMM: one slot updates."* +- Row 4: *"corner turn: block slot ↔ row slot, and row ↔ col inside the block."* + +--- + +## Cost summary + +| item | cost per transform | +|--------------------------------------|---------------------------------------| +| compute (mathematical minimum) | 3 · 2 · K⁴ FLOPs | +| HBM reads of A | K³ doubles (once, at distribute) | +| HBM writes of result | K³ doubles (once, at pass-3 store) | +| HBM reads of B | K² doubles (once, cached in LDS) | +| inter-wave exchange (corner turn) | K³ doubles (via LDS, one pass only) | + +Comparison against the naive 3-GEMM implementation is the whole point: HBM +traffic for A drops from 3·K³ to 1·K³, and the compute path is exactly the +same three K²×K contractions — just re-organized. + +--- + +## K=2 walkthrough (self-checking) + +Run `./validate -debug 1` in the build directory. Uses the counting tensor +`A[i, j, k] = 100i + 10j + k` and `B = I`, so every GEMM is a no-op and only +the corner turn produces visible motion. Each stage dumps all K blocks; the +final line is self-checking (`MATCH` / `DIFFER`). + +Particularly useful as a "does this still work" smoke test when porting to +GPU. + +--- + +## GPU implementation — measured perf on MI250X (K=16, one GCD) + +Wired up as level 7. One block per tensor; grid = `nfuncs`. +Thread block = 64 × K threads = 1024 at K=16 (one wave per K×K slab). 
+ +Step-by-step results through the optimization arc (N=2048, FP64): + +| version | GF/s | vs. prev | +|-----------------------------------------------|-------:|---------:| +| First cut (double-buffer LDS, B from HBM) | 720 | — | +| Single-buffer LDS + B cached in LDS | 2270 | 3.15× | +| + LDS row pad (kill 16-way bank conflict) | 3320 | 1.46× | +| + drop same-wave-only barriers (8 → 3) | 3460 | 1.04× | +| + fuse pass-2 store with corner-turn write | 3475 | 1.00× | +| + swap pass-1 operands → pass-2 reads acc | 3475 | 1.00× | +| + coalesced cooperative distribute | 7400 | 2.13× | +| + one block per tensor (grid = nfuncs) | 8880 | 1.20× | +| + `double4` loads for distribute + B cache | **9340** | 1.05× | + +Final result: **9340 GF/s ≈ 20 % of scalar FP64 peak ≈ 10 % of MFMA peak**. +More importantly, **86 % of the HBM-bandwidth roofline** at the kernel's +arithmetic intensity (AI = 6.14 FLOPs/byte). See `roofline.py` / `roofline.png`. + +### What we learned (not all optimizations paid off) + +A few commits are worth flagging because their lesson matters more than +their number: + +- **LDS bank-conflict pad — big win (1.46×)**. Classic 16-way conflict pattern + in the corner turn. Simple row-stride pad from K to K+1. +- **LDS instruction cuts gave ~nothing**. Dropping barriers (-60 % LDS wait), + fusing pass-2 into corner turn (-8 LDS ops/lane), register-fused pass-1→ + pass-2 (-8 more LDS ops/lane) all looked big in counters (40 % fewer LDS + insts total) but moved wall clock by <5 %. Interpretation: LDS was never + on the critical path in this kernel; those ops were fully hidden in the + shadow of MFMA/HBM latency. +- **Coalesced distribute — biggest single win (2.1× at N=2048)**. The prior + per-wave stride-K reads were being served as individual cache lines by the + hardware coalescer — 16× amplification on L2 request count. `rocprof` showed + HBM BW "only 26 % of peak" but actually the cache-line layer was saturated. 
+ Counter-intuitive: the HBM BW counter was misleading. +- **Widening the pass-3 store via LDS staging — regression (-17 %)**. The 4× + narrow stores were already coalescing into 4 cache-line writes at the + hardware level; adding a barrier + LDS round-trip lost more than the + saved instructions gained. +- **MFMA swap trick (pass 1) — enables register fusion but no perf gain**. + Swapping operands in pass 1 makes its output layout match pass 2's input + exactly, so pass 2 reads acc directly. Theoretically saves LDS traffic; + wall-clock impact hidden by memory bubbles (see point 2 above). + +### Compiler / ISA notes (GFX90A) + +- MFMA `v_mfma_f64_16x16x4f64` layouts confirmed empirically via + `test_mfma_layout.hip` (the upstream L4 comment had them wrong): + - A_frag (16×4): lane t → A[t%16][t/16] + - B_frag (4×16): lane t → B[t/16][t%16] + - D output (16×16): lane t acc[e] → D[(t/16) + 4e][t%16] +- The widest global load/store on gfx90a is `global_{load,store}_dwordx4` + (128-bit, 2 doubles per lane). `double4` in code compiles to a pair + of these. +- LDS bank period: 32 banks × 4 bytes = 128 bytes. Stride-K doubles + (K = 16) hits the same bank row → 16-way conflict; stride-(K+1) shifts + by 2 banks → conflict-free. + +--- + +## K=20 / K=24 — use v_mfma_f64_4x4x4f64 (design note) + +K=20 and K=24 are the scientifically relevant odd-size cases (K=32 turns +out not to be; it was a suggestion just because the 16-divisibility +worked nicely with MFMA 16×16×4). Neither K=20 nor K=24 divides by 16, +so they can't use `v_mfma_f64_16x16x4f64` directly without padding, and +padding to K=32 blows LDS. + +The path forward is to use the smaller **`v_mfma_f64_4x4x4f64`** MFMA +variant — same hardware matrix core, 4×4 output tile instead of 16×16. +This avoids padding entirely because 4 divides both 20 and 24 cleanly. 
+ +### Tile-count cost + +Per pass, per wave: + +| | 16×16×4 path (K=16) | 4×4×4 path (K=20) | 4×4×4 path (K=24) | +|-------|---------------------|--------------------|--------------------| +| output tiles per K slice | 1 | 25 = 5×5 | 36 = 6×6 | +| K slices needed | 4 | 5 | 6 | +| MFMA calls per pass | 4 | 125 | 216 | + +The 4×4×4 approach uses ~30× more MFMA calls per pass than the K=16 +kernel. On gfx90a, matrix-core throughput for all f64 MFMA variants +is the same per-cycle-of-matrix-core, so more-but-smaller calls is a +wash on raw compute — but the **issue rate** may bottleneck at low +occupancy. + +### 4×4×4 fragment layout — CONFIRMED via probe + +`v_mfma_f64_4x4x4f64` on gfx90a returns **one double per lane**. +64 lanes × 1 = 64 output values = **four INDEPENDENT 4×4×4 GEMMs** +(NOT broadcast — earlier hypothesis disproved by `test_mfma_4x4x4_layout{,2,3}.hip`). + +With the lane decomposition `S = 16·α + 4·g + β` where `β = S%4`, +`g = (S/4)%4` and `α = S/16`, the **group index** is `g` (∈ {0,1,2,3}). +Each of the four groups computes its own 4×4×4 GEMM `D_g = A_g · B_g`: + +| fragment | lane S holds | +|-----------|------------------------------| +| A (4×4) | `A_g[m = β ][k = α]` | +| B (4×4) | `B_g[k = α ][n = β]` | +| D (4×4) | `D_g[m = α ][n = β]` | + +So A is column-major across 16 lanes, B row-major, D row-major, and +`t/16`/`(t/4)%4` each play different roles for the input vs. output +fragments — A's column index is where D's row index comes from, and +vice versa. The four `g`-groups share nothing: `g=1` reads none of +`g=0`'s lanes. + +This is strictly better for our K=20 case than 4×-broadcast would have +been: we can pack 4 independent 4×4 output tiles into a single +instruction, if we can lay out A/B so the four groups see different +tiles. 
+ +### Algorithm shape (once the 4×4×4 layout is understood) + +Assuming we find a CBSZ/BLGP combination that gives 4 independent tiles: + +- Replace the single 16×16 `mma_sync` in each of the three passes + with a nested pair of tile loops: + ``` + for (row_tile = 0; row_tile < K/4; ++row_tile) + for (col_tile = 0; col_tile < K/4; ++col_tile) + for (k_slice = 0; k_slice < K/4; ++k_slice) + MFMA_4x4x4(...) + ``` + At K=20: 5×5×5 = 125 calls per pass per wave. +- The per-wave output is K² = 400 (K=20) or 576 (K=24) doubles. + 400/64 = 6.25 per lane — **not integer**. For K=20 this means + some lanes will have 6 elements, others 7. Awkward. May want + to pad to K=21 internally (6.56/lane)? Or group 2 MFMA calls + per pair of lanes (?). Ugly. + 576/64 = 9 per lane — clean for K=24. +- LDS: K³ × 8 = 20³×8 = 64 KB (K=20), 24³×8 = 110 KB (K=24). + K=20 just fits; K=24 still over. So K=24 needs the same + streaming story as K=32 (see section below). + +### Recommended next steps (in order) + +1. **Measure L3 at K=20 and K=24 first.** If L3 hits 50%+ of the HBM + roofline at these K values, a custom MFMA kernel may not be worth + the complexity. +2. **Pin down the 4×4×4 fragment layout** via a diagnostic kernel. +3. **Start with K=20 using 4×4×4** — clean LDS budget, even if awkward + lane distribution. K=24 can follow after streaming is figured out. +4. Consider a hybrid **16×16×4 main + 4×4×4 remainder** tiling (e.g. + K=20 = 16+4 in each dim gives one 16×16 main tile + strips + + corner). Complex but avoids most of the 4×4×4 instruction overhead. + +### K=20 status: implemented (MFMA via 4 independent tiles per call) + +`transform_blocked_k20.h` ships both a scalar correctness baseline and +a 4×4×4 MFMA path. 
Design choices and constraints: + +| knob | K=20 choice | +|-------------------------|-------------------------------------------------------| +| threads/block | 64 × 10 = 640 (20 waves would exceed the 1024 cap) | +| slabs per wave | 2 (wave w owns k' ∈ {w, w+10}) | +| LDS tensor buf | K³ = 8000 doubles = 64 KB, **unpadded** | +| B caching | read from HBM (no room for B_lds at 64 KB cap) | +| tile pattern | 25 output tiles of 4×4 per slab, 7 MFMA rounds × 5 k-slices | +| wasted groups | last round has 3 unused groups (fed zero) | + +Measured @ N=2048, FP64, MI250X single GCD: + +| variant | GF/s | vs. L3 | kernel µs | MFMA busy | LDS wait/inst | +|---------------------------------------|-----:|--------:|----------:|----------:|--------------:| +| L3 (register-block) | 1967 | — | — | — | — | +| L7 scalar (K=20) | 1083 | 0.55× | — | — | — | +| L7 pure 4×4×4 | 2787 | 1.42× | 753 | 12.2 % | 5.5 cyc | +| L7 hybrid 16×16×4 + 4×4×4 | 3480 | 1.77× | 591 | 11.1 % | 19.4 cyc | +| **L7 hybrid + fusion + wide dist** | **6900** | **3.51×** | **370** | **17.8 %** | **4.1 cyc** | + +The final configuration pairs three changes: + +1. **Hybrid tiling** — 16×16×4 MFMA for the main 16×16 sub-tile, 4×4×4 MFMA + for the 16×4 / 4×16 strips and the 4×4 corner. 43 % fewer MFMA issues + than the pure 4×4×4 path. +2. **Pass 1 → Pass 2 register fusion** (operand-swap trick from the K=16 + kernel). Pass 1 computes `D = temp1^T` so the main accumulator `p1_main[e]` + at lane t directly equals `temp1[t%16][(t/16)+4e]` — exactly what pass 2's + A-frag wants. Pass 1 right strip gets the same treatment and feeds pass 2's + 5th k-slice from register `p1_right`. Bottom strip and corner can't be + fused cleanly (α/β transpose mismatch) and still round-trip through LDS. + LDS instructions dropped 32 %, LDS wait cycles dropped 79 %. +3. **double4 distribute** — K³ = 8000 doubles = 2000 `double4` loads. 
+ +Remaining headroom (all less impactful than the above): +- **MFMA busy still 18 %** — there's still ~2× to gain before matching K=16's + ~50 % MFMA-busy ceiling; further reductions in VALU overhead (mostly in + the bottom/corner 4×4×4 loops) and in bank conflicts (currently 16-way on + the stride-K=20 accesses for bottom/corner) would help. +- **No B_lds** — single-block-per-tensor has no room for a 3.2 KB B cache + after the K³ = 64 KB `buf`. Attempted 2-blocks-per-tensor with padded + layout + B_lds but atomicAdd + pre-zero overhead (+500 MB HBM) erased the + bank-conflict win — see notes below. +- **Pass 2 → Pass 3 fusion** isn't feasible: pass 3's cross-wave corner-turn + REQUIRES data in LDS. + +### Experiments that didn't pay off + +- **2 blocks per tensor + K+1 padded LDS + atomicAdd** (`transform_kernel_blocked_k20_split` + kept in the tree for reference). Bank conflicts did drop to 0.2-way, but + atomic read-modify-write on C and pre-zero memset inflated HBM traffic + 3× (258 → 752 MB), VALU went up 2.6×, and MFMA busy *fell* to 6 %. Net: + ~1.9× slower. Take-away: bank conflict counters were real but not the + critical path on the hybrid; the extra coordination costs of splitting + the k'-axis across blocks easily dominated the LDS wins. + +### 4×4×4 layout — CONFIRMED (overrides the "broadcast" hypothesis) + +From `test_mfma_4x4x4_layout{,2,3}.hip`: + +- The instruction is **4 INDEPENDENT 4×4×4 GEMMs**, not a 4-way + broadcast. Lane decomposition `S = 16α + 4g + β` with + `β = S%4, g = (S/4)%4, α = S/16`. `g` is the group id. +- A (4×4): lane S holds `A_g[m=β][k=α]` (col-major) +- B (4×4): lane S holds `B_g[k=α][n=β]` (row-major) +- D (4×4): lane S holds `D_g[m=α][n=β]` (row-major) + +Note α and β swap roles between input (A/B) and output (D) fragments — +same oddity as the 16×16×4 layout. + +### Obvious next optimisations (not pursued here) + +- **Fit B in LDS** via multi-block-per-tensor (e.g. 
2 blocks handling + 10 k'-slabs each, atomic-add in pass 3). LDS per block drops to + ~35 KB and every MFMA can use LDS B. +- **Hybrid 16×16×4 main + 4×4×4 strip/corner** tile pattern. Uses the + faster instruction on the bulk of the work. +- **Pass-1/pass-2 register fusion** (as in the K=16 path). + +--- + +## K=32 — design note (not yet implemented) + +At K=32 the K=16 algorithm's working set doesn't fit in LDS, even split +across multiple blocks. This is a real redesign, not a tweak. + +### The budget problem + +| thing | K=16 | K=32 | +|------------------------------------------------|--------------|--------------| +| full K³ tensor | 4 K dbl = 32 KB | 32 K dbl = **256 KB** | +| per-wave K×K slab | 256 dbl = 2 KB | 1024 dbl = 8 KB | +| 16 waves × slab (one block per tensor) | 32 KB | **128 KB** | +| 16 waves × slab with K+1 pad | 34 KB | **132 KB** | +| LDS cap (MI250X) | 64 KB | 64 KB | + +Even splitting into 2 blocks (each handling K/2 of k') gives 128 KB/block — +2× over cap. There's no simple split that makes the K=16 algorithm fit. + +### What must change + +The working set per wave has to shrink from K×K to something smaller. +Three candidate approaches: + +**(A) Each wave holds K × K/2 (half-slab).** + - 2 waves cooperate on a full K×K slab. + - Per-wave LDS: 512 dbl = 4 KB. 32 waves × 4 KB = 128 KB. Still over. + - 16 waves × 4 KB = 64 KB, halved coverage. Would need 2 blocks to cover. + +**(B) Stream the corner turn in k'-chunks.** + - Hold only a chunk of the tensor at a time: K × K × (K/4) = 8 KB · 4 = 32 KB. + - Pass 3 accumulates across 4 chunked corner turns into registers. + - Extra HBM passes for re-reading A, unless we also chunk passes 1/2. + - Complex but fits the budget. + +**(C) Atomic-add reduction across 4 blocks.** + - 4 blocks per tensor; each handles K/4 = 8 values of k'. + - Each block's pass 3 contributes a partial; atomicAdd to HBM C. + - 4 atomics per output element, 4K³ = 128K atomics per tensor. 
+ - Simpler code, slower arithmetic path. + +### Thread-block sizing + +At K=32, blockDim must be ≤ 1024 threads. With 64-thread waves, that's +≤ 16 waves per block. So we can't have 32 waves (one per k') in a single +block — each wave must handle ≥ 2 slabs, or blocks must handle fewer k' +values. + +### Recommended approach for the next iteration + +Start with **(C) — 4 blocks per tensor with atomic add**, because: +- Reuses the existing K=16 algorithm shape almost verbatim (per-block + pipeline is identical). +- Simplest to get correct; good baseline before optimizing. +- Perf will be bad (atomics, 4× HBM writes for reduction) but it + validates the design. + +Then iterate toward **(B)** — streaming — which preserves single-block- +per-tensor semantics and should be faster once correct. + +### Open questions worth probing first + +- What does the actual MADNESS workload look like at K=32? If `nfuncs` + is small, block-count explosion from approach (C) isn't a problem + but atomic contention might be. +- Does hipBLAS's batched DGEMM do better than anything we can write by + hand at K=32? Worth measuring L6 (kron) on real K=32 workloads as + a reference point. +- L4 / L5 (upstream MFMA + rocWMMA paths) likely have K=32 code; are + their designs worth stealing? 
+ +--- + +## References + +- `validate.hip` — CPU reference and GPU L1 correctness check +- `validate_levels.hip` — multi-level correctness test (select L7 with `-l 7`) +- `transformbench.hip` — throughput benchmark (`-l 7`) +- `mra_python/algorithms.py::transform_nd_blocked` — NumPy version +- `test_mfma_layout.hip` — diagnostic that verifies MFMA fragment layouts +- `counters.txt`, `counters_deep.txt` — rocprof counter sets +- `roofline.py`, `roofline.png` — roofline analysis +- `transform.h`, `transform_level{2,3}.h` — earlier GPU levels for comparison diff --git a/CMakeLists.txt b/CMakeLists.txt index 1935e9b..9a51054 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,114 +2,29 @@ cmake_minimum_required(VERSION 3.10) project(transformbench LANGUAGES CXX) -set(CMAKE_CUDA_ARCHITECTURES 80) - set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) include(CheckLanguage) -check_language(CUDA) -if(CMAKE_CUDA_COMPILER) - enable_language(CUDA) - set(TTG_ENABLE_CUDA ON) -else(CMAKE_CUDA_COMPILER) - message(WARNING "CUDA compiler not found") -endif(CMAKE_CUDA_COMPILER) -set(HAVE_CUDA ${CMAKE_CUDA_COMPILER} CACHE BOOL "True if we can compile .cu files") - check_language(HIP) if(CMAKE_HIP_COMPILER) enable_language(HIP) - set(TTG_ENABLE_HIP ON) -else(CMAKE_HIP_COMPILER) - message(WARNING "HIP compiler not found") -endif(CMAKE_HIP_COMPILER) -set(HAVE_HIP ${CMAKE_HIP_COMPILER} CACHE BOOL "True if we can compile .hip files") - -option(USE_SUGGEST_LAYOUT "Use suggested layout instead of get_layout" ON) -option(DEBUG_TENSOR_TYPE "Compile-time print cute tensor types (breaks build)" OFF) -set(USE_CUBLASDX_VERSION "25.06" CACHE STRING "Version of cublasDx to use") - - -# fetch cublasDx -if (CMAKE_CUDA_COMPILER) - include(FetchContent) - FetchContent_Declare( - cublasdx - URL https://developer.download.nvidia.com/compute/cublasdx/redist/cublasdx/nvidia-mathdx-${USE_CUBLASDX_VERSION}.0.tar.gz - ) - FetchContent_MakeAvailable(cublasdx) - 
FetchContent_GetProperties(cublasdx - SOURCE_DIR CUBLASDX_SOURCE_DIR - BINARY_DIR CUBLASDX_BINARY_DIR - ) - - # look for cublasDx - find_package(mathdx REQUIRED COMPONENTS cublasdx HINTS ${CUBLASDX_SOURCE_DIR}/nvidia/mathdx/25.06/) - if (TARGET mathdx::cublasdx) - message(STATUS "Found cublasDx at ${mathdx_CUBLASDX_DIR}") - else() - message(FATAL_ERROR "cublasDx not found") - endif() - -endif(CMAKE_CUDA_COMPILER) - -# Simple interface that holds cublasDx and CUDA settings -add_library(libmra INTERFACE) -if (CMAKE_CUDA_COMPILER) - # Link against cublasDx and CUDA - target_link_libraries(libmra INTERFACE mathdx::cublasdx) - # Set the CUDA architecture - target_compile_definitions(libmra INTERFACE MRA_CUDA_ARCH=${CMAKE_CUDA_ARCHITECTURES} MRA_HAVE_CUDA=1) - # Enable support for constexpr and extended lambdas - target_compile_options(libmra INTERFACE --expt-relaxed-constexpr --extended-lambda) - - # Add the transformbench executable - add_executable(transformbench_cuda transformbench.cu) - - # Link against the MRA interface - target_link_libraries(transformbench_cuda PUBLIC libmra) - - - if (USE_SUGGEST_LAYOUT) - # Enable using suggested layout instead of get_layout - target_compile_definitions(transformbench_cuda PUBLIC USE_SUGGEST_LAYOUT) - endif (USE_SUGGEST_LAYOUT) - - if (DEBUG_TENSOR_TYPE) - # Enable compile-time printing of cute tensor types (breaks build) - target_compile_definitions(transformbench_cuda PUBLIC DEBUG_TENSOR_TYPE) - endif (DEBUG_TENSOR_TYPE) +else() + message(FATAL_ERROR "HIP compiler not found") endif() -if (CMAKE_HIP_COMPILER) - # Ensure ROCm cmake configs (hip, hipblas, etc.) 
are findable - list(APPEND CMAKE_PREFIX_PATH /opt/rocm-6.4.3 /opt/rocm) - find_package(hipblas REQUIRED) - - #target_link_libraries(libmra INTERFACE mathdx::cublasdx) - # Set the CUDA architecture - target_compile_definitions(libmra INTERFACE MRA_HAVE_HIP=1) - # Enable support for constexpr and extended lambdas - #target_compile_options(libmra INTERFACE --expt-relaxed-constexpr --extended-lambda) - - # Add the transformbench executable - add_executable(transformbench_hip transformbench.hip) - - # Link against the MRA interface and hipBLAS (for level 6 Kronecker GEMM) - target_link_libraries(transformbench_hip PUBLIC libmra roc::hipblas) +# Simple interface that holds HIP settings +add_library(libmra INTERFACE) +target_compile_definitions(libmra INTERFACE MRA_HAVE_HIP=1) - # Correctness test: validate any optimization level against the L1 reference - add_executable(validate_levels validate_levels.hip) - target_link_libraries(validate_levels PUBLIC libmra roc::hipblas) +# Add the transformbench executable +add_executable(transformbench_hip transformbench.hip) +target_link_libraries(transformbench_hip PUBLIC libmra) - if (USE_SUGGEST_LAYOUT) - # Enable using suggested layout instead of get_layout - target_compile_definitions(transformbench_hip PUBLIC USE_SUGGEST_LAYOUT) - endif (USE_SUGGEST_LAYOUT) +# Correctness test: GPU L1 vs CPU reference (mirrors transform3d.cc) +add_executable(validate validate.hip) +target_link_libraries(validate PUBLIC libmra) - if (DEBUG_TENSOR_TYPE) - # Enable compile-time printing of cute tensor types (breaks build) - target_compile_definitions(transformbench_hip PUBLIC DEBUG_TENSOR_TYPE) - endif (DEBUG_TENSOR_TYPE) -endif () \ No newline at end of file +# Multi-level correctness test (-l selects level, -K and -N override defaults) +add_executable(validate_levels validate_levels.hip) +target_link_libraries(validate_levels PUBLIC libmra) diff --git a/README.md b/README.md index dcb7572..39c428a 100644 --- a/README.md +++ b/README.md @@ 
-242,7 +242,6 @@ ln -sf build/compile_commands.json compile_commands.json ```bash cmake .. -DMRA_HAVE_HIP=1 -DCMAKE_CXX_COMPILER=hipcc \ -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_PREFIX_PATH=/opt/rocm-6.4.3 \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON ``` diff --git a/counters.txt b/counters.txt new file mode 100644 index 0000000..93add41 --- /dev/null +++ b/counters.txt @@ -0,0 +1,8 @@ +pmc: SQ_WAVES GRBM_GUI_ACTIVE +pmc: SQ_INSTS_VALU SQ_INSTS_MFMA SQ_INSTS_LDS +pmc: SQ_WAIT_INST_LDS SQ_LDS_BANK_CONFLICT +pmc: SQ_VALU_MFMA_BUSY_CYCLES +pmc: SQ_INSTS_VALU_MFMA_MOPS_F64 SQ_INSTS_VALU_FMA_F64 +pmc: TCC_HIT_sum TCC_MISS_sum +pmc: TCC_EA_RDREQ_sum TCC_EA_WRREQ_sum +pmc: TCC_EA_RDREQ_DRAM_sum TCC_EA_WRREQ_DRAM_sum diff --git a/frames/.gitkeep b/frames/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/mxm_cublasdx.h b/mxm_cublasdx.h deleted file mode 100644 index fb80692..0000000 --- a/mxm_cublasdx.h +++ /dev/null @@ -1,391 +0,0 @@ -#ifndef MRA_OPS_MXM_CUBLASDX_H -#define MRA_OPS_MXM_CUBLASDX_H - -#include "util.h" - -/** - * An implementation of A^T x B using cublasdx. - * We assume that A is tall-and-skinny (K^2 x K) and B is square (K x K). - * There is some code to cover the case where A is square and B is wide-and-skinny - * but that is not yet implemented and we don't use it yet. 
- */ - -#define MRA_CUBLASDX_BLOCK_C 0 - -#if __has_include() - -#define MRA_HAVE_CUBLASDX 1 - -#if !defined(MRA_CUDA_ARCH) || MRA_CUDA_ARCH < 70 -#error "MRA_CUDA_ARCH must be defined and >= 70 to use cublasdx" -#endif - -#include - -#if MRA_CUDA_ARCH == 70 -#define MRA_CUBLASDX_SM 700 -#define MRA_CUBLASDX_MAX_SHM (30*1024) -#elif MRA_CUDA_ARCH == 80 -#define MRA_CUBLASDX_SM 800 -#define MRA_CUBLASDX_MAX_SHM (40*1024) -#elif MRA_CUDA_ARCH == 90 -#define MRA_CUBLASDX_SM 900 -#define MRA_CUBLASDX_MAX_SHM (110*1024) -#else -#warning "Unknown MRA_CUDA_ARCH for cublasdx, using 80" -#define MRA_CUBLASDX_SM 800 -#endif - -#ifdef DEBUG_TENSOR_TYPE -#define PRINT_TENSOR_TYPE(t) cute::print_type(t) -#else // DEBUG_TENSOR_TYPE -#define PRINT_TENSOR_TYPE(t) -#endif // DEBUG_TENSOR_TYPE - -// get the layout for tensor t in GEMM -#ifdef USE_SUGGEST_LAYOUT -#define GET_SHARED_LAYOUT(op, t) op::suggest_layout_smem_##t() -#else // USE_SUGGEST_LAYOUT -#define GET_SHARED_LAYOUT(op, t) op::get_layout_smem_##t() -#endif // USE_SUGGEST_LAYOUT - - -namespace mra { - - namespace detail { - - constexpr int CUBLAS_MIN_MN = 16; - - template - constexpr int cublasdx_max_mn() { - // K^2 for square B/A, double buffering for A/B and C - auto max_nm = ((MRA_CUBLASDX_MAX_SHM / sizeof(T)) - K*K) / ((3+MRA_CUBLASDX_BLOCK_C)*K); - // round down to the nearest power of 2 - // TODO: std::log2 is constexpr only since C++26 - //int p = std::pow(2, (int)std::log2(max_nm)); - int l = 1; - while ((l<<1) <= max_nm) l <<= 1; - return std::min(l, K*K); - } - - template - struct GEMMBuilder { - - private: - using BaseGEMM = decltype(cublasdx::Precision() - + cublasdx::Type() - + cublasdx::Function() - + cublasdx::SM() // TODO - + cublasdx::Block() - + cublasdx::MaxAlignment()); - using GEMM_ = decltype(BaseGEMM() + cublasdx::Size() - + cublasdx::Arrangement()); - using GEMM_suggested_ld = cublasdx::suggested_leading_dimension_of_t; - public: - using GEMM = decltype(GEMM_() + GEMM_suggested_ld()); - }; - - 
template - __forceinline__ - __device__ void mTxmq_cublasdx_core(auto&& a_shared_tensor, auto&& b_shared_tensor, - auto&& c_tensor, - auto&& load = [](){}, auto&& prefetch = [](){}) { - - using alignment = cublasdx::alignment_of; - - /* load data to shared memory */ - load(); - /* wait for load to complete */ - cublasdx::copy_wait(); - - /* prefetch data for next iteration */ - prefetch(); - - // Execute using register API - auto [c_register_fragment, partitioner] = GEMM().execute(a_shared_tensor, b_shared_tensor); - - // Store back to global memory using cublasdx::copy_fragment API - - cublasdx::copy_fragment(c_register_fragment, c_tensor, partitioner); - } - - /** - * Compute the shared memory requirements for a given GEMM. - * Takes into account double buffering of A (block_a) and B (block_b) as well as - * staging of results through shared memory (block_c). - */ - template - constexpr int cublasdx_shmem_size_for(bool block_a, bool block_b, bool block_c) { - auto calc = cublasdx::make_shared_storage_calculator() - .add(cublasdx::alignment_of_v_a, sizeof(typename GEMM::a_value_type), GET_SHARED_LAYOUT(GEMM, a)) - .add(cublasdx::alignment_of_v_b, sizeof(typename GEMM::b_value_type), GET_SHARED_LAYOUT(GEMM, b)); - if (block_a) { - calc.add(cublasdx::alignment_of_v_a, sizeof(typename GEMM::a_value_type), GET_SHARED_LAYOUT(GEMM, a)); - } - if (block_b) { - calc.add(cublasdx::alignment_of_v_b, sizeof(typename GEMM::b_value_type), GET_SHARED_LAYOUT(GEMM, b)); - } - if (block_c) { - // double buffering of C - calc.add(cublasdx::alignment_of_v_c, sizeof(typename GEMM::c_value_type), GET_SHARED_LAYOUT(GEMM, c)); - calc.add(cublasdx::alignment_of_v_c, sizeof(typename GEMM::c_value_type), GET_SHARED_LAYOUT(GEMM, c)); - } - - int shared_memory_size = calc.get(); - return shared_memory_size; - } - - template - constexpr int cublasdx_shmem_size_k() { - constexpr auto blockdims = max_thread_dims(K); - using BaseGEMM = decltype(cublasdx::Precision() - + cublasdx::Type() - + 
cublasdx::Function() - + cublasdx::SM() // TODO - + cublasdx::Block() - + cublasdx::BlockDim() - + cublasdx::MaxAlignment()); - constexpr auto max_mn = cublasdx_max_mn(); - using GEMMBlockA = typename GEMMBuilder::GEMM; - auto size = cublasdx_shmem_size_for(true, false, MRA_CUBLASDX_BLOCK_C); - return size; - } - - template - __forceinline__ - __device__ void mTxmq_cublasdx_block(T* c, const T* a, const T* b) { - constexpr auto blockdims = max_thread_dims(K); - extern __shared__ __align__(16) char smem[]; - constexpr auto max_mn = cublasdx_max_mn(); - /* assuming aT = bT = cT for now */ - using GEMM = typename GEMMBuilder::GEMM; - - using alignment = cublasdx::alignment_of; - - - if constexpr (M == K*K) { - constexpr auto num_iter = M/max_mn; - //if (is_team_lead()) printf("mTxmq_cublasdx_block: max_mn %d, shared_memory %u, smem %p, M = %d, N = %d, K = %d iter %d\n", max_mn, cublasdx_shmem_size_for(true, false, true), smem, M, N, K, num_iter); - //__syncthreads(); - - if constexpr (num_iter > 0) { - auto [smem_a, smem_b, smem_a_n, smem_c, smem_c_n] = - cublasdx::shared_memory::slice_into_pointers( - smem, - cublasdx::alignment_of_v_a, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, a)), - cublasdx::alignment_of_v_b, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, b)), - cublasdx::alignment_of_v_a, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, a)), - cublasdx::alignment_of_v_c, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, c)), - cublasdx::alignment_of_v_c, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, c))); - - /* copy b tensor into shared memory and leave there */ - auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b()); - auto b_shared_tensor = cublasdx::make_tensor(smem_b, GET_SHARED_LAYOUT(GEMM, b)); - cublasdx::copy(b_global_tensor, b_shared_tensor); - PRINT_TENSOR_TYPE(b_global_tensor); - PRINT_TENSOR_TYPE(b_shared_tensor); - - auto a_shared_tensor = cublasdx::make_tensor(smem_a, GET_SHARED_LAYOUT(GEMM, a)); - auto a_shared_tensor_n = cublasdx::make_tensor(smem_a_n, 
GET_SHARED_LAYOUT(GEMM, a)); - - auto c_shared_tensor = cublasdx::make_tensor(smem_c, GET_SHARED_LAYOUT(GEMM, c)); - auto c_shared_tensor_n = cublasdx::make_tensor(smem_c_n, GET_SHARED_LAYOUT(GEMM, c)); - - int i; // used past the for loop below - - auto make_c_global_tensor = [&](int i){ - return cublasdx::make_tensor(c+((i*max_mn)*N), GEMM::get_layout_gmem_c()); - }; - - auto store_c = [&]() { -#if MRA_CUBLASDX_BLOCK_C - auto c_shared_tensor = cublasdx::make_tensor(smem_c, GET_SHARED_LAYOUT(GEMM, c)); - __syncthreads(); // make sure prior computations are done - auto c_global_tensor = make_c_global_tensor(i-1); - cublasdx::copy(c_shared_tensor, c_global_tensor); -#endif // MRA_CUBLASDX_BLOCK_C - }; - for (i = 0; i < num_iter; i++) { - // Make global memory tensors - auto a_global_tensor = cublasdx::make_tensor(a+(i*max_mn), GEMM::get_layout_gmem_a(cute::Int{})); - auto a_shared_tensor = cublasdx::make_tensor(smem_a, GET_SHARED_LAYOUT(GEMM, a)); - auto a_shared_tensor_n = cublasdx::make_tensor(smem_a_n, GET_SHARED_LAYOUT(GEMM, a)); - - auto c_shared_tensor = cublasdx::make_tensor(smem_c, GET_SHARED_LAYOUT(GEMM, c)); - auto c_shared_tensor_n = cublasdx::make_tensor(smem_c_n, GET_SHARED_LAYOUT(GEMM, c)); - - PRINT_TENSOR_TYPE(a_global_tensor); - PRINT_TENSOR_TYPE(a_shared_tensor); - PRINT_TENSOR_TYPE(make_c_global_tensor(i)); - PRINT_TENSOR_TYPE(c_shared_tensor); - //auto c_global_tensor = cublasdx::make_tensor(c+((i*max_mn)*N), GEMM::get_layout_gmem_c()); - mTxmq_cublasdx_core(a_shared_tensor, b_shared_tensor, -#if MRA_CUBLASDX_BLOCK_C - c_shared_tensor, -#else // MRA_CUBLASDX_BLOCK_C - /* global tensor */ - make_c_global_tensor(i), -#endif // MRA_CUBLASDX_BLOCK_C - [&](){ - /* load only on first iteration, all others are prefetched */ - if (i == 0) { - //if (is_team_lead()) printf("Loading initial block %d\n", i); - cublasdx::copy(a_global_tensor, a_shared_tensor); - } - }, - [&](){ - /* store prior iteration's result */ - if (i > 0) { - //if (is_team_lead()) 
printf("Storing block %d\n", i-1); - store_c(); - } - /* prefetch into shared memory */ - if ((i+1) < num_iter) { - //if (is_team_lead()) printf("Prefetching block %d\n", i); - auto a_global_tensor = cublasdx::make_tensor(a+((i+1)*max_mn), GEMM::get_layout_gmem_a(cute::Int{})); - cublasdx::copy(a_global_tensor, a_shared_tensor_n); - } - }); - auto tmp_a = smem_a; - smem_a = smem_a_n; - smem_a_n = tmp_a; - auto tmp_c = smem_c; - smem_c = smem_c_n; - smem_c_n = tmp_c; - -#if 0 - auto tmp_a = a_shared_tensor; - a_shared_tensor = a_shared_tensor_n; - a_shared_tensor_n = tmp_a; - auto tmp_b = c_shared_tensor; - c_shared_tensor = c_shared_tensor_n; - c_shared_tensor_n = tmp_b; -#else - //std::swap(a_shared_tensor, a_shared_tensor_n); - //std::swap(c_shared_tensor, c_shared_tensor_n); -#endif // 0 - } - /* store the last block of C */ - store_c(); - } - - /* handle remainder */ - constexpr const auto R = M%max_mn; - if constexpr (0 < R) { - // Make global memory tensors - using GEMM = typename GEMMBuilder::GEMM; - auto [smem_a, smem_b, smem_c] = cublasdx::slice_shared_memory(smem, GET_SHARED_LAYOUT(GEMM, a), - GET_SHARED_LAYOUT(GEMM, b), - GET_SHARED_LAYOUT(GEMM, c)); - auto a_shared_tensor = cublasdx::make_tensor(smem_a, GET_SHARED_LAYOUT(GEMM, a)); - auto a_global_tensor = cublasdx::make_tensor(a+((M/max_mn)*max_mn), GEMM::get_layout_gmem_a(cute::Int{})); - auto b_global_tensor = cublasdx::make_tensor(b, GEMM::get_layout_gmem_b()); - auto b_shared_tensor = cublasdx::make_tensor(smem_b, GET_SHARED_LAYOUT(GEMM, b)); - auto c_global_tensor = cublasdx::make_tensor(c+((M/max_mn)*max_mn*N), GEMM::get_layout_gmem_c()); - auto c_shared_tensor = cublasdx::make_tensor(smem_c, GET_SHARED_LAYOUT(GEMM, c)); - mTxmq_cublasdx_core(a_shared_tensor, b_shared_tensor, -#if MRA_CUBLASDX_BLOCK_C - c_shared_tensor, -#else // MRA_CUBLASDX_BLOCK_C - c_global_tensor, -#endif // MRA_CUBLASDX_BLOCK_C - [&](){ - cublasdx::copy(a_global_tensor, a_shared_tensor); - cublasdx::copy(b_global_tensor, 
b_shared_tensor); - }, - [](){}); - /* move the C block back to global memory */ - cublasdx::copy(c_shared_tensor, c_global_tensor); - } - } else { - // TODO: implement! - static_assert(M == K*K, "N equal to K*K currently not supported"); - } - /* final sync */ - cublasdx::copy_wait(); - } - - } // namespace detail - - template - __forceinline__ - __device__ void mTxmq(long dimi, long dimj, long dimk, - cT* c, const aT* a, const bT* b) { - int M = dimi; - int N = dimj; - int K = dimk; - if (M == K*K) { - // A is tall and skinny, B is square - if (K == 6) { - detail::mTxmq_cublasdx_block<36, 6, 6>(c, a, b); - } else if (K == 8) { - detail::mTxmq_cublasdx_block<64, 8, 8>(c, a, b); - } else if (K == 10) { - detail::mTxmq_cublasdx_block<100, 10, 10>(c, a, b); - } else if (K == 12) { - detail::mTxmq_cublasdx_block<12*12, 12, 12>(c, a, b); - } else if (K == 16) { - detail::mTxmq_cublasdx_block<16*16, 16, 16>(c, a, b); - } else if (K == 20) { - detail::mTxmq_cublasdx_block<400, 20, 20>(c, a, b); - } else if (K == 32) { - detail::mTxmq_cublasdx_block<32*32, 32, 32>(c, a, b); - } else { - if (is_team_lead()) printf("mTxmq: Unsupport K = %d\n", K); - } - } else { - printf("mTxmq: Unknown configuration with M = %d, N = %d, K = %d\n", M, N, K); - } - /* make sure all is done */ - __syncthreads(); - } - - template - constexpr int mTxmq_shmem_size(int K) { - switch (K) { - case 6: return detail::cublasdx_shmem_size_k(); - case 8: return detail::cublasdx_shmem_size_k(); - case 10: return detail::cublasdx_shmem_size_k(); - case 12: return detail::cublasdx_shmem_size_k(); - case 16: return detail::cublasdx_shmem_size_k(); - case 20: return detail::cublasdx_shmem_size_k(); - case 32: return detail::cublasdx_shmem_size_k(); - default: THROW("CUBLASdx: Unsupported K"); - } - } - - - namespace detail { - template - constexpr Dim3 cublasdx_blockdim_k() { - - return Dim3(MAX_THREADS_PER_BLOCK, 1, 1); - constexpr auto max_mn = cublasdx_max_mn(); - using GEMM = typename GEMMBuilder::GEMM; 
- return GEMM::suggested_block_dim; - } - - } // namespace detail - template - constexpr Dim3 mTxmq_blockdim(int K) { - switch (K) { - case 6: return detail::cublasdx_blockdim_k(); - case 8: return detail::cublasdx_blockdim_k(); - case 10: return detail::cublasdx_blockdim_k(); - case 12: return detail::cublasdx_blockdim_k(); - case 16: return detail::cublasdx_blockdim_k(); - case 20: return detail::cublasdx_blockdim_k(); - case 32: return detail::cublasdx_blockdim_k(); - default: THROW("CUBLASdx: Unsupported K"); - } - } - -} // namespace mra - -#endif // __has_include() - -#endif // MRA_OPS_MXM_CUBLASDX_H diff --git a/mxm_level2.h b/mxm_level2.h index 278445e..0e237e7 100644 --- a/mxm_level2.h +++ b/mxm_level2.h @@ -1,58 +1,19 @@ #pragma once - #include "util.h" -/** - * Level 2: B matrix loaded into LDS (shared memory) once per mTxmq call. - * A is streamed from global memory. Threads are distributed over - * rows (i) rather than columns (j), so all 128 threads stay busy - * even for small K. - * - * c(i,j) = sum_k a(k,i)*b(k,j) - * A: K^2 x K col-major a[k,i] = a[k*dimi + i] - * B: K x K row-major b[k,j] = b[k*dimj + j] - * C: K^2 x K row-major c[i,j] = c[i*dimj + j] - */ +// mxm_level2.h — L2 metadata: B staged in LDS by the transform kernel. +// The actual mTxmq computation reuses mra::mTxmq from mxm.h (B pointer is LDS). 
namespace mra { -/* Public entry-point: always clears C (mTxmq semantics, equivalent to Q=true) */ -template -__device__ void mTxmq_level2(size_type dimi, size_type dimj, size_type dimk, - cT* __restrict__ c, const aT* a, const bT* b) { - extern __shared__ char smem_level2[]; - bT* b_shmem = reinterpret_cast(smem_level2); - - /* Cooperatively load B (dimk * dimj elements) into LDS */ - for (int idx = threadIdx.x; idx < dimk * dimj; idx += blockDim.x) { - b_shmem[idx] = b[idx]; - } - __syncthreads(); - - /* Each thread handles a stripe of rows; full j and k loops are sequential */ - for (size_type i = (size_type)threadIdx.x; i < dimi; i += (size_type)blockDim.x) { - const aT* a_col_i = a + i; /* pointer to a[0,i] in col-major layout */ - cT* ci = c + i * dimj; - for (size_type j = 0; j < dimj; ++j) { - cT sum = cT(0); /* always clear: mTxmq semantics */ - const aT* aik = a_col_i; - for (size_type k = 0; k < dimk; ++k, aik += dimi) { - sum += (*aik) * b_shmem[k * dimj + j]; - } - ci[j] = sum; - } - } - __syncthreads(); -} - -template -constexpr size_type mTxmq_level2_shmem_size(size_type K) { - return K * K * sizeof(T); +template +constexpr size_type mTxmq_L2_shmem_size(size_type K) { + return (size_type)(K * K * sizeof(T)); } -template -constexpr Dim3 mTxmq_level2_blockdim(int /*K*/) { - return Dim3(MAX_THREADS_PER_BLOCK, 1, 1); +template +constexpr Dim3 mTxmq_L2_blockdim(int K) { + return max_thread_dims(K); } } // namespace mra diff --git a/mxm_level3.h b/mxm_level3.h index 86b5cc2..959feb7 100644 --- a/mxm_level3.h +++ b/mxm_level3.h @@ -1,86 +1,59 @@ #pragma once - #include "util.h" -/** - * Level 3: B in LDS + register accumulation. - * Each thread owns a full row of the output tile held in a compile-time - * register array T acc[K]. The k-loop loads a[k,i] once and FMAs it - * against all K columns of B (from LDS), eliminating redundant global - * loads and keeping the hot loop inside the register file. 
- * - * Use mTxmq_level3_k (K known at compile time) so that each K value - * gets its own kernel binary with isolated register pressure. - * - * c(i,j) = sum_k a(k,i)*b(k,j) - * A: K^2 x K col-major a[k,i] = a[k*dimi + i] - * B: K x K row-major b[k,j] = b[k*dimj + j] - * C: K^2 x K row-major c[i,j] = c[i*dimj + j] - */ +// mxm_level3.h — L3: register-blocked mTxmq, K compile-time template parameter. +// +// Algorithm (per thread, linear tid over K² rows): +// acc[K] = 0 // K doubles in VGPRs +// for k in 0..K-1: +// aki = a[k * K² + i] // one global load per k +// for j in 0..K-1: [unrolled] +// acc[j] += aki * b_shm[k*K+j] // LDS reads +// for j in 0..K-1: [unrolled] +// c[i*K + j] = acc[j] // one store per j +// +// c(i,j) = sum_k a(k,i)*b(k,j) [Q-mode: c zeroed via acc initialisation] namespace mra { -namespace detail { - -/** - * Inner kernel: B is already in b_shmem, register array acc[K] accumulates - * the dot product. Compile-time K keeps acc[] in VGPRs. - */ -template -__device__ void mTxmq_level3_impl(T* __restrict__ c, const T* a, const T* b_shmem) { - constexpr int DIMI = K * K; - - for (int i = (int)threadIdx.x; i < DIMI; i += (int)blockDim.x) { - T acc[K]; - - if constexpr (Q) { - for (int j = 0; j < K; ++j) acc[j] = T(0); - } else { - for (int j = 0; j < K; ++j) acc[j] = c[i * K + j]; +template +__device__ void mTxmq_L3(T* __restrict__ c, const T* __restrict__ a, const T* __restrict__ b_shm) +{ + constexpr int dimi = K * K; // number of output rows + constexpr int dimj = K; // number of output cols + + int nthr = blockDim.x * blockDim.y; + int tid = blockDim.x * threadIdx.y + threadIdx.x; + + for (int i = tid; i < dimi; i += nthr) { + T acc[K]; + #pragma unroll + for (int j = 0; j < K; ++j) acc[j] = T(0); + + for (int k = 0; k < K; ++k) { + T aki = a[k * dimi + i]; // global load: a(k, i) + #pragma unroll + for (int j = 0; j < K; ++j) { + acc[j] += aki * b_shm[k * dimj + j]; // LDS: b(k, j) + } + } + + #pragma unroll + for (int j = 0; j < K; ++j) { 
+ c[i * dimj + j] = acc[j]; + } } - - /* k-loop: load a[k,i] once, FMA with all K entries of row k of B */ - const T* aik = a + i; /* a[0,i] in col-major */ - for (int k = 0; k < K; ++k, aik += DIMI) { - T aki = *aik; - for (int j = 0; j < K; ++j) { - acc[j] += aki * b_shmem[k * K + j]; - } - } - - for (int j = 0; j < K; ++j) c[i * K + j] = acc[j]; - } -} - -} // namespace detail - - -/** - * K-templated entry point — one binary per K value. - * Each instantiation sees only acc[K] for its specific K, - * keeping register pressure proportional to K rather than max(K). - */ -template -__device__ void mTxmq_level3_k(T* __restrict__ c, const T* a, const T* b) { - extern __shared__ char smem_level3[]; - T* b_shmem = reinterpret_cast(smem_level3); - - for (int idx = (int)threadIdx.x; idx < K * K; idx += (int)blockDim.x) - b_shmem[idx] = b[idx]; - __syncthreads(); - - detail::mTxmq_level3_impl(c, a, b_shmem); - __syncthreads(); + __syncthreads(); // output fully written before caller reads it } -template -constexpr size_type mTxmq_level3_shmem_size(size_type K) { - return K * K * sizeof(T); +template +constexpr size_type L3_shmem_size(int K) { + return K * K * sizeof(T); // same as L2 — one K×K matrix in LDS } -template -constexpr Dim3 mTxmq_level3_blockdim(int /*K*/) { - return Dim3(MAX_THREADS_PER_BLOCK, 1, 1); +template +constexpr Dim3 mTxmq_L3_blockdim(int K) { + return max_thread_dims(K); } } // namespace mra diff --git a/mxm_level4.h b/mxm_level4.h deleted file mode 100644 index 1133b21..0000000 --- a/mxm_level4.h +++ /dev/null @@ -1,185 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level3.h" /* for the Level-3 fallback */ - -/** - * Level 4: AMD MFMA (Matrix Fused Multiply-Accumulate) for FP64. - * - * On GFX90A / GFX940 the v_mfma_f64_16x16x4f64 instruction computes a - * 16x16 output tile with a 4-deep contraction in one wavefront step. 
- * Thread layout (64-thread wavefront, cbsz=0, abid=0, blgp=0): - * - * A input (16 rows x 4 cols = 64 elements): thread t -> A[t/4][t%4] - * B input (4 rows x 16 cols = 64 elements): thread t -> B[t/16][t%16] - * C/D out (16 rows x 16 cols = 256 elements, 4 per thread): - * thread t -> D[{(t/16)*4 + 0..3}][t%16] - * - * The blockdim for Level 4 is {64,1,1} (one wavefront). - * - * For K values not handled by MFMA (or on non-AMD targets), the implementation - * transparently falls back to the Level-3 register-blocking kernel. - * - * c(i,j) = sum_k a(k,i)*b(k,j) - * A: K^2 x K col-major a[k,i] = a[k*dimi + i] - * B: K x K row-major b[k,j] = b[k*dimj + j] - * C: K^2 x K row-major c[i,j] = c[i*dimj + j] - */ - -namespace mra { - -namespace detail { - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - -/* AMD vector type for 4 doubles — the return type of the MFMA builtin */ -typedef double mfma_d4 __attribute__((ext_vector_type(4))); - -/** - * MFMA kernel for compile-time K. - * Requires blockDim.x == 64 (one wavefront). - * B must already be loaded into b_shmem. 
- */ -template -__device__ void mTxmq_level4_mfma(T* __restrict__ c, const T* a, const T* b_shmem) { - static_assert(K % 16 == 0, - "mTxmq_level4_mfma: K must be a multiple of 16 for 16x16 MFMA"); - static_assert(K * K % 16 == 0, - "mTxmq_level4_mfma: K^2 must be a multiple of 16"); - - constexpr int DIMI = K * K; - constexpr int ROW_TILES = DIMI / 16; /* number of 16-row tiles */ - constexpr int COL_TILES = K / 16; /* number of 16-col tiles */ - - const int tid = (int)threadIdx.x; /* 0..63 */ - - /* Thread's contribution indices within one MFMA tile */ - const int a_row_in_tile = tid / 4; /* 0..15 */ - const int a_col_in_tile = tid % 4; /* 0..3 */ - const int b_row_in_tile = tid / 16; /* 0..3 */ - const int b_col_in_tile = tid % 16; /* 0..15 */ - const int d_col_in_tile = tid % 16; - const int d_row_grp = (tid / 16) * 4; /* first of 4 output rows this thread owns */ - - for (int r = 0; r < ROW_TILES; ++r) { - for (int ct = 0; ct < COL_TILES; ++ct) { - mfma_d4 acc = {0.0, 0.0, 0.0, 0.0}; - - /* loop over 4-deep contraction blocks */ - for (int k_block = 0; k_block < K; k_block += 4) { - /* A element: A^T[r*16 + a_row_in_tile, k_block + a_col_in_tile] - * A is col-major K^2 x K: a[k,i] = a[k*DIMI + i] - * so A^T[i, k] = a[k*DIMI + i] */ - int a_i = r * 16 + a_row_in_tile; - int a_k = k_block + a_col_in_tile; - double a_elem = (double)a[a_k * DIMI + a_i]; - - /* B element: B[k_block + b_row_in_tile, ct*16 + b_col_in_tile] - * B is row-major K x K: b[k,j] = b[k*K + j] */ - int b_k = k_block + b_row_in_tile; - int b_j = ct * 16 + b_col_in_tile; - double b_elem = (double)b_shmem[b_k * K + b_j]; - - acc = (mfma_d4)__builtin_amdgcn_mfma_f64_16x16x4f64( - a_elem, b_elem, (mfma_d4)acc, 0, 0, 0); - } - - /* Store 4 output elements owned by this thread: - * rows r*16 + d_row_grp + 0..3, col ct*16 + d_col_in_tile */ - int c_col = ct * 16 + d_col_in_tile; - int c_row_base = r * 16 + d_row_grp; - c[(c_row_base + 0) * K + c_col] = (T)acc[0]; - c[(c_row_base + 1) * K + c_col] = 
(T)acc[1]; - c[(c_row_base + 2) * K + c_col] = (T)acc[2]; - c[(c_row_base + 3) * K + c_col] = (T)acc[3]; - } - } -} - -#endif /* AMD MFMA guard */ - -} // namespace detail - - -/* Public entry-point: always clears C (mTxmq semantics, Q=true). */ -template -__device__ void mTxmq_level4(size_type dimi, size_type dimj, size_type dimk, - cT* __restrict__ c, const aT* a, const bT* b) { - extern __shared__ char smem_level4[]; - bT* b_shmem = reinterpret_cast(smem_level4); - - /* Load B into LDS */ - for (int idx = (int)threadIdx.x; idx < dimk * dimj; idx += (int)blockDim.x) { - b_shmem[idx] = b[idx]; - } - __syncthreads(); - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - /* MFMA path: only for K divisible by 16 */ - if (dimi == dimj * dimj) { - if (dimj == 16) { - detail::mTxmq_level4_mfma(c, a, b_shmem); - __syncthreads(); - return; - } else if (dimj == 32) { - detail::mTxmq_level4_mfma(c, a, b_shmem); - __syncthreads(); - return; - } - } - /* Fall through to Level-3 register blocking for other K values */ -#endif - - /* Level-3 fallback (also the path on CUDA / non-GFX90A AMD) */ - if (dimi == dimj * dimj) { - if (dimj == 6) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 8) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 10) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 12) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 16) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 20) detail::mTxmq_level3_impl(c, a, b_shmem); - else if (dimj == 32) detail::mTxmq_level3_impl(c, a, b_shmem); - else { - if (is_team_lead()) printf("mTxmq_level4: unsupported K=%d\n", (int)dimj); - } - } - __syncthreads(); -} - -/** - * K-templated entry point — one binary per K value. - * Loads B into LDS, then dispatches to MFMA (if available) or Level-3 fallback. - * Requires blockDim.x == 64 (one wavefront) on MFMA path. 
- */ -template -__device__ void mTxmq_level4_k(T* __restrict__ c, const T* a, const T* b) { - extern __shared__ char smem_level4[]; - T* b_shmem = reinterpret_cast(smem_level4); - - for (int idx = (int)threadIdx.x; idx < K * K; idx += (int)blockDim.x) - b_shmem[idx] = b[idx]; - __syncthreads(); - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - if constexpr (K % 16 == 0) { - detail::mTxmq_level4_mfma(c, a, b_shmem); - __syncthreads(); - return; - } -#endif - /* Level-3 register-blocking fallback */ - detail::mTxmq_level3_impl(c, a, b_shmem); - __syncthreads(); -} - -template -constexpr size_type mTxmq_level4_shmem_size(size_type K) { - return K * K * sizeof(T); -} - -template -constexpr Dim3 mTxmq_level4_blockdim(int /*K*/) { - return Dim3(64, 1, 1); /* one wavefront */ -} - -} // namespace mra diff --git a/mxm_level5.h b/mxm_level5.h deleted file mode 100644 index bb330a6..0000000 --- a/mxm_level5.h +++ /dev/null @@ -1,191 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level3.h" - -/** - * Level 5: LDS-staged A with v_mfma_f64_16x16x4f64, 4-wavefront block. - * - * For K=16, A is 256×16 (col-major). A is loaded in NCHUNKS strips of - * CHUNK_ROWS rows each. All 256 threads cooperate to load one strip into - * LDS, then each of the 4 wavefronts takes a disjoint 16×16 subtile of that - * strip and runs v_mfma_f64_16x16x4f64 against B (always resident in LDS). - * Once all 4 wavefronts finish, the next strip is loaded. - * - * For K=16 (primary target): - * CHUNK_ROWS = 64 (K²/4) — loads all of A in 4 strips - * NCHUNKS = 4 - * Each wavefront: 1 tile of 16×16 per chunk → 4 tiles total = 64 rows each - * Total C coverage: 4 chunks × (4 WFs × 16 rows) = 256 rows = K² ✓ - * - * CHUNK_ROWS is chosen at compile time as the largest fraction of DIMI - * (DIMI/4, DIMI/8, DIMI/16) whose A strip fits alongside B in 64 KB LDS. - * CHUNK_ROWS is always a multiple of NWARPS×16=64 so tiles divide evenly. 
- * For K=32: CHUNK_ROWS=128, NCHUNKS=8, TILES_PER_WF=2 per chunk. - * - * LDS layout: - * [0 ]: B K×K row-major (K*K doubles, loaded once) - * [K*K ]: A_strip K×(CHUNK_ROWS+1) col-major with +1 padding per k-column - * a_lds[k*(CHUNK_ROWS+1) + row_local] = A^T[row_base+row_local][k] - * Padding shifts LDS banks between k-groups, avoiding 4-way conflicts. - * - * Global memory load pattern (K=16, CHUNK_ROWS=64): - * 256 threads load 1024 elements in 4 passes; each pass loads 64 consecutive - * doubles from one k-column of A → full 512-byte coalesced transaction. - * - * v_mfma_f64_16x16x4f64 thread layout (tid = lane in wavefront, 0..63): - * a_row = tid/4 (0..15): row in the 16×4 A operand - * a_k = tid%4 (0..3 ): k-depth offset within the 4-wide block - * b_k = tid/16 (0..3 ): k-depth offset in the 4×16 B operand - * b_col = tid%16 (0..15): column - * d_row_grp = (tid/16)*4 : first of the 4 output rows owned by this thread - * d_col = tid%16 - * - * c(i,j) = sum_k a(k,i)*b(k,j) - * A: K²×K col-major a[k,i] = a[k*K²+i] - * B: K×K row-major b[k,j] = b[k*K +j] - * C: K²×K row-major c[i,j] = c[i*K +j] - */ - -namespace mra { - -namespace detail { - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - -typedef double mfma_d4 __attribute__((ext_vector_type(4))); - -template -__device__ void mTxmq_level5_mfma(T* __restrict__ c, const T* a, T* b_shmem) { - static_assert(K % 16 == 0, "mTxmq_level5_mfma: K must be a multiple of 16"); - - constexpr int DIMI = K * K; - constexpr int NWARPS = 4; - - /* --- Compile-time chunk sizing ------------------------------------------ */ - constexpr int LDS_BUDGET = 64 * 1024; - constexpr int B_BYTES = DIMI * (int)sizeof(T); - /* A strip with +1 padding: K*(CHUNK_ROWS+1)*sizeof(T) bytes */ - /* CHUNK_ROWS+1 <= (LDS_BUDGET - B_BYTES) / (K*sizeof(T)) */ - constexpr int MAX_CHUNK_ROWS = (LDS_BUDGET - B_BYTES) / (K * (int)sizeof(T)) - 1; - /* Pick largest DIMI/N that fits AND is a multiple of NWARPS*16=64 
*/ - constexpr int CHUNK_ROWS = (DIMI / 4 <= MAX_CHUNK_ROWS) ? DIMI / 4 : - (DIMI / 8 <= MAX_CHUNK_ROWS) ? DIMI / 8 : DIMI / 16; - constexpr int NCHUNKS = DIMI / CHUNK_ROWS; - constexpr int A_STRIDE = CHUNK_ROWS + 1; /* padded LDS column stride */ - constexpr int ROWS_PER_WF = CHUNK_ROWS / NWARPS; /* rows per WF per chunk */ - constexpr int TILES_PER_WF = ROWS_PER_WF / 16; /* 16-row tiles per WF */ - - /* --- Thread indices ----------------------------------------------------- */ - const int tid_block = (int)threadIdx.x; /* 0..255 */ - const int warp_id = tid_block / 64; /* 0..3 — selects SIMD unit */ - const int tid = tid_block % 64; /* 0..63 — lane within wavefront */ - - /* v_mfma_f64_16x16x4f64 lane mapping */ - const int a_row = tid / 4; /* 0..15: row in A operand */ - const int a_k = tid % 4; /* 0..3: k-depth */ - const int b_k = tid / 16; /* 0..3: k-depth in B operand */ - const int b_col = tid % 16; /* 0..15: column */ - const int d_row_grp = (tid / 16) * 4; /* first of 4 output rows */ - const int d_col = tid % 16; - - /* A strip buffer sits directly after B in LDS */ - T* a_lds = b_shmem + DIMI; - - /* ========================================================================= - * Outer loop: one A strip per iteration. - * ========================================================================= */ - for (int chunk = 0; chunk < NCHUNKS; ++chunk) { - const int row_base = chunk * CHUNK_ROWS; /* first global A^T row in strip */ - - /* --- Cooperative load of A strip (all 256 threads) ------------------- */ - /* a_lds[k * A_STRIDE + row_local] = a[k*DIMI + row_base + row_local] */ - /* */ - /* Load order: idx steps over [0, K*CHUNK_ROWS) in strides of 256. */ - /* For K=16, CHUNK_ROWS=64: each pass covers one complete k-column */ - /* (64 consecutive doubles), giving a fully-coalesced 512-byte burst. 
*/ - for (int idx = tid_block; idx < K * CHUNK_ROWS; idx += 256) { - const int row_local = idx % CHUNK_ROWS; /* row within strip */ - const int k = idx / CHUNK_ROWS; /* k-column of A */ - a_lds[k * A_STRIDE + row_local] = a[k * DIMI + row_base + row_local]; - } - __syncthreads(); /* strip fully in LDS before any MFMA begins */ - - /* --- MFMA: each wavefront owns TILES_PER_WF consecutive 16-row tiles - */ - const int wf_local_row_start = warp_id * ROWS_PER_WF; - - for (int t = 0; t < TILES_PER_WF; ++t) { - const int local_row = wf_local_row_start + t * 16; /* tile start in strip */ - - mfma_d4 acc = {0.0, 0.0, 0.0, 0.0}; - - /* K/4 steps of 4-deep contraction */ - for (int kb = 0; kb < K; kb += 4) { - /* A^T[row_base + local_row + a_row][kb + a_k] from LDS */ - const double a_elem = - (double)a_lds[(kb + a_k) * A_STRIDE + local_row + a_row]; - - /* B[kb + b_k][b_col] from LDS (row-major) */ - const double b_elem = - (double)b_shmem[(kb + b_k) * K + b_col]; - - acc = (mfma_d4)__builtin_amdgcn_mfma_f64_16x16x4f64( - a_elem, b_elem, (mfma_d4)acc, 0, 0, 0); - } - - /* Write 4 output elements to C (row-major) */ - const int c_row = row_base + local_row + d_row_grp; - c[(c_row + 0) * K + d_col] = (T)acc[0]; - c[(c_row + 1) * K + d_col] = (T)acc[1]; - c[(c_row + 2) * K + d_col] = (T)acc[2]; - c[(c_row + 3) * K + d_col] = (T)acc[3]; - } - - __syncthreads(); /* all WFs done before strip is overwritten */ - } -} - -#endif /* AMD MFMA guard */ - -} // namespace detail - - -template -__device__ void mTxmq_level5_k(T* __restrict__ c, const T* a, const T* b) { - extern __shared__ char smem_level5[]; - T* b_shmem = reinterpret_cast(smem_level5); - - /* All threads cooperate to load B once — stays resident throughout */ - for (int idx = (int)threadIdx.x; idx < K * K; idx += (int)blockDim.x) - b_shmem[idx] = b[idx]; - __syncthreads(); - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - if constexpr (K % 16 == 0) { - detail::mTxmq_level5_mfma(c, 
a, b_shmem); - __syncthreads(); - return; - } -#endif - detail::mTxmq_level3_impl(c, a, b_shmem); - __syncthreads(); -} - -template -inline size_type mTxmq_level5_shmem_size(int K) { - const int DIMI = K * K; - const int b_bytes = DIMI * (int)sizeof(T); - const int max_chunk = (64 * 1024 - b_bytes) / (K * (int)sizeof(T)) - 1; - const int chunk_rows = (DIMI/4 <= max_chunk) ? DIMI/4 : - (DIMI/8 <= max_chunk) ? DIMI/8 : DIMI/16; - const int a_stride = chunk_rows + 1; /* padded */ - return static_cast((DIMI + K * a_stride) * (int)sizeof(T)); -} - -template -constexpr Dim3 mTxmq_level5_blockdim(int /*K*/) { - return Dim3(256, 1, 1); /* 4 wavefronts, one per SIMD unit */ -} - -} // namespace mra - diff --git a/mxm_level7.h b/mxm_level7.h deleted file mode 100644 index d10571e..0000000 --- a/mxm_level7.h +++ /dev/null @@ -1,282 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level3.h" - -/** - * Level 7: B resident in VGPRs, A loaded directly from HBM/LDS into VGPRs (transposed), - * no LDS used for A or B. - * - * Computes C[K²×K] = A^T[K²×K] × B[K×K] using v_mfma_f64_16x16x4f64. - * - * Block = 256 threads (4 wavefronts, one per SIMD/Matrix core). - * - * --- B register layout --- - * B [K×K] row-major. For the MFMA lane mapping: b_k = tid/16 (0..3), b_col = tid%16 (0..15). - * Each lane pre-loads its K/4 relevant elements into b_reg[NSTEPS] and holds them in VGPRs - * across all three GEMMs. - * - * --- A register layout --- - * A [K×K²] col-major in source memory (global for GEMM 1, LDS for GEMMs 2&3). - * Wave w handles rows [w*ROWS_PER_WF .. w*ROWS_PER_WF+ROWS_PER_WF) of A^T. - * MFMA A-operand lane mapping: a_row = tid/4 (0..15), a_k = tid%4 (0..3). - * The transposed read is achieved by the MFMA lane mapping — no explicit LDS transpose. 
- * - * --- Pointer trick --- - * After each GEMM, C [K²×K] written row-major to LDS is reinterpreted as A [K×K²] - * col-major for the next GEMM via identical flat indices: - * Write: buf[i*K + j] (C row-major, i ∈ [0,K²), j ∈ [0,K)) - * Read: buf[k*K² + i] (A col-major, k ∈ [0,K), i ∈ [0,K²)) - * Since K*K² = K³ = K²*K, both index the same flat buffer — just different shapes. - * This is the standard MADNESS mTxmq pointer trick enabling 3D separable transforms. - * - * --- XOR swizzle for LDS bank conflicts --- - * gfx90a: 32 banks × 4 bytes. bank_of_double = (addr_in_doubles × 2) % 32. - * K²=256 doubles → stride (256×2)%32 = 0 → all k-values alias → 16-way conflict. - * - * Fix: apply a consistent XOR swizzle to all LDS addresses (both write and read): - * lds_swizzle(flat) = flat ^ (((flat >> 8) & 3) << 3) - * - * Write side (flat = i*K + j): - * flat >> 8 = i*K/256 = i/16 → for K=16, i/16 is exactly the k-group index - * (rows i ∈ [0,16) belong to k=0, rows i ∈ [16,32) to k=1, etc.) - * - * Read side (flat = k*K² + i): - * flat >> 8 = k*K²/256 = k → same swizzle key as write side - * - * XOR value = (k & 3) * 8 ∈ {0, 8, 16, 24}: scatters k=0..3 to banks {0,8,16,24}, - * reducing 16-way conflicts to 2-way (theoretical minimum for 64 lanes / 32 banks). - * - * --- Single-buffer reuse --- - * Within gemm7_pass the full A pre-load into a_reg completes before any write to dst. - * Combined with __syncthreads() between GEMMs, the same LDS buffer is safely reused: - * GEMM 1: global A → buf (LDS, row-major + swizzle) - * GEMM 2: buf (LDS) → buf (LDS, in-place, row-major + swizzle) - * GEMM 3: buf (LDS) → C (global, plain row-major, no swizzle) - * - * LDS size = K³ * sizeof(T). For K=16: 16³ × 8 = 32,768 bytes (32 KB). - * - * --- GEMM chain: global memory traffic --- - * GEMM 1: reads A from global. - * GEMMs 2&3: no global memory traffic for A. - * GEMM 3: writes final C to global. - * B remains in VGPRs throughout. 
- * - * Supported: K multiple of 16, gfx90a / gfx940. Falls back to L3 otherwise. - */ - -namespace mra { - -namespace detail { - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - -typedef double mfma_d4 __attribute__((ext_vector_type(4))); - -/** - * XOR swizzle for LDS bank conflict reduction. - * - * Applied consistently on both LDS write (flat = i*K + j) and LDS read - * (flat = k*K² + i): in both cases flat>>8 == k (for K=16), giving the same - * per-k-group XOR offset of (k & 3) * 8. - * - * Not applied to global memory accesses (GEMM 1 read, GEMM 3 write). - */ -__device__ __forceinline__ int lds_swizzle(int flat) { - return flat ^ (((flat >> 8) & 3) << 3); -} - -/** - * One GEMM pass: dst = src^T × B (B already in b_reg[]). - * - * SWIZZLE_SRC — true → apply lds_swizzle to src read address (LDS source) - * false → use flat col-major address as-is (global source) - * SWIZZLE_DST — true → apply lds_swizzle to dst write address (LDS dest) - * false → use flat row-major address as-is (global dest) - * - * Read path (A [K×K²] col-major, standard pointer-trick flat layout): - * flat_src = (s*4 + a_k) * K² + wave_row_offset + t*16 + a_row - * src[SWIZZLE_SRC ? lds_swizzle(flat_src) : flat_src] - * - * Write path (C [K²×K] row-major): - * flat_dst = (base_row + r) * K + d_col - * dst[SWIZZLE_DST ? 
lds_swizzle(flat_dst) : flat_dst] - * - * GEMM 1: — global A → LDS (swizzled) - * GEMM 2: — LDS → LDS (in-place, swizzled) - * GEMM 3: — LDS → global C (plain row-major) - */ -template -__device__ __forceinline__ void gemm7_pass( - const T* __restrict__ src, - T* __restrict__ dst, - const double b_reg[K / 4], - const int wave_row_offset, - const int a_row, - const int a_k, - const int d_row_grp, - const int d_col) -{ - constexpr int K2 = K * K; - constexpr int ROWS_PER_WF = K2 / 4; - constexpr int TILES_PER_WF = ROWS_PER_WF / 16; - constexpr int NSTEPS = K / 4; - - /* --- Pre-load A partition into registers (transposed in-flight) ----------- - * A [K×K²] col-major: src[k * K² + i] where k = s*4+a_k, i = wave_row_offset+t*16+a_row. - * MFMA lane mapping (a_row=tid/4, a_k=tid%4) achieves the in-flight transpose: - * no extra LDS step needed. XOR swizzle applied when reading from LDS so that - * swizzled-write addresses are matched exactly by swizzled-read addresses. */ - double a_reg[TILES_PER_WF][NSTEPS]; - #pragma unroll - for (int t = 0; t < TILES_PER_WF; ++t) - #pragma unroll - for (int s = 0; s < NSTEPS; ++s) { - const int flat_src = (s * 4 + a_k) * K2 - + wave_row_offset + t * 16 + a_row; - a_reg[t][s] = (double)src[SWIZZLE_SRC ? lds_swizzle(flat_src) : flat_src]; - } - - /* --- Issue all MFMAs for all tiles before draining any accumulator -------- - * Separate AGPR sets per tile allow the matrix core to pipeline tiles - * while the VALU/LDS units write completed tile results. */ - mfma_d4 acc[TILES_PER_WF]; - #pragma unroll - for (int t = 0; t < TILES_PER_WF; ++t) - acc[t] = {0.0, 0.0, 0.0, 0.0}; - - #pragma unroll - for (int t = 0; t < TILES_PER_WF; ++t) - #pragma unroll - for (int s = 0; s < NSTEPS; ++s) - acc[t] = (mfma_d4)__builtin_amdgcn_mfma_f64_16x16x4f64( - a_reg[t][s], b_reg[s], (mfma_d4)acc[t], 0, 0, 0); - - /* --- Write results to dst ------------------------------------------------- - * Row-major C [K²×K]: dst[i*K + j]. 
- * Pointer trick: the same flat buffer is later read as col-major A [K×K²]: - * A[k][i] = buf[k*K² + i] (flat index unchanged, shape reinterpreted). - * XOR swizzle applied on LDS writes matches the swizzle on subsequent reads. - * No swizzle on the final global write (GEMM 3). */ - #pragma unroll - for (int t = 0; t < TILES_PER_WF; ++t) { - const int base_row = wave_row_offset + t * 16 + d_row_grp; - #pragma unroll - for (int r = 0; r < 4; ++r) { - const int flat_dst = (base_row + r) * K + d_col; - dst[SWIZZLE_DST ? lds_swizzle(flat_dst) : flat_dst] = (T)acc[t][r]; - } - } -} - -/** - * Three-GEMM chain for level 7. - * - * B loaded once into VGPRs, resident through all three GEMMs. - * Single LDS buffer (K³ doubles) reused in-place; XOR swizzle on all LDS I/O. - * - * GEMM 1: gemm7_pass — global A → LDS - * GEMM 2: gemm7_pass — LDS → LDS (in-place) - * GEMM 3: gemm7_pass — LDS → global C - */ -template -__device__ void mTxmq_level7_mfma( - T* __restrict__ c, /* output [K²×K] row-major, global */ - const T* __restrict__ a, /* input [K×K²] col-major, global */ - const T* __restrict__ b, /* B [K×K] row-major, global */ - T* buf) /* LDS scratch: K³ doubles */ -{ - constexpr int K2 = K * K; - constexpr int NSTEPS = K / 4; - - const int tid_block = (int)threadIdx.x; - const int warp_id = tid_block / 64; - const int tid = tid_block % 64; - - /* MFMA lane indices — fixed for all three GEMMs. */ - const int a_row = tid / 4; - const int a_k = tid % 4; - const int b_k = tid / 16; - const int b_col = tid % 16; - const int d_row_grp = (tid / 16) * 4; - const int d_col = tid % 16; - const int wave_row_offset = warp_id * (K2 / 4); - - /* ----------------------------------------------------------------------- - * Load B into VGPRs once. - * Lane tid holds b_reg[s] = B[(s*4 + b_k)*K + b_col] for s=0..NSTEPS-1. 
- * ----------------------------------------------------------------------- */ - double b_reg[NSTEPS]; - #pragma unroll - for (int s = 0; s < NSTEPS; ++s) - b_reg[s] = (double)b[(s * 4 + b_k) * K + b_col]; - - /* ----------------------------------------------------------------------- - * GEMM 1: A (global, unpadded K² stride) → buf (LDS, row-major + swizzle) - * ----------------------------------------------------------------------- */ - gemm7_pass( - a, buf, b_reg, wave_row_offset, a_row, a_k, d_row_grp, d_col); - __syncthreads(); - - /* ----------------------------------------------------------------------- - * GEMM 2: buf (LDS, swizzled) reread as col-major via pointer trick - * → buf (LDS, in-place, row-major + swizzle) - * Full A pre-load into a_reg completes before any write, so same buffer - * is safe to overwrite. - * ----------------------------------------------------------------------- */ - gemm7_pass( - buf, buf, b_reg, wave_row_offset, a_row, a_k, d_row_grp, d_col); - __syncthreads(); - - /* ----------------------------------------------------------------------- - * GEMM 3: buf (LDS, swizzled) reread as col-major → c (global, row-major) - * ----------------------------------------------------------------------- */ - gemm7_pass( - buf, c, b_reg, wave_row_offset, a_row, a_k, d_row_grp, d_col); -} - -#endif /* AMD MFMA guard */ - -} // namespace detail - - -/** - * Public interface: executes the full three-GEMM transform chain. - * Dispatches to mTxmq_level7_mfma on gfx90a/gfx940; falls back to L3 elsewhere. - */ -template -__device__ void mTxmq_level7_k( - T* __restrict__ c, - const T* __restrict__ a, - const T* __restrict__ b) -{ - extern __shared__ char smem_level7[]; - T* buf = reinterpret_cast(smem_level7); - -#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx90a__) || defined(__gfx940__)) - if constexpr (K % 16 == 0) { - detail::mTxmq_level7_mfma(c, a, b, buf); - return; - } -#endif - /* Fallback: L3. 
Load B into the LDS buffer and run one GEMM at a time. */ - for (int idx = (int)threadIdx.x; idx < K * K; idx += (int)blockDim.x) - buf[idx] = b[idx]; - __syncthreads(); - detail::mTxmq_level3_impl(c, a, buf); - __syncthreads(); -} - -template -inline size_type mTxmq_level7_shmem_size(int K) { - /* Flat K³ buffer: C [K²×K] row-major reinterpreted as A [K×K²] col-major - * via the pointer trick. No padding needed — swizzle handles bank conflicts. */ - return static_cast(K * K * K * (int)sizeof(T)); -} - -template -constexpr Dim3 mTxmq_level7_blockdim(int /*K*/) { - return Dim3(256, 1, 1); -} - -} // namespace mra diff --git a/mxm_rocwmma.h b/mxm_rocwmma.h deleted file mode 100644 index 472cabe..0000000 --- a/mxm_rocwmma.h +++ /dev/null @@ -1,282 +0,0 @@ -#ifndef MRA_OPS_MXM_ROCWMMA_H -#define MRA_OPS_MXM_ROCWMMA_H - -/** - * rocWMMA implementation of mTxmq: c(i,j) = sum_k a(k,i) * b(k,j) - * - * Matrices (all row-major): - * A : [K_ord × K_ord²] (used transposed; leading dimension = K_ord²) - * B : [K_ord × K_ord ] (leading dimension = K_ord) - * C : [K_ord² × K_ord] (leading dimension = K_ord) - * - * Supported K_ord values: multiples of 4 with K_ord ≤ 16 - * K=4, 8, 12, 16 (fp32/fp64). Falls back for K=6, 10, 20+. - * - * Rationale: MFMA K-tile is 4 (MFMA_F*_16x16x4), so K%4=0 is required. - * K_ord%4=0 implies K_ord²%16=0, so A/C have no partial M-tiles. - * K_ord≤16 keeps total tiles = K_ord²/16 ≤ 16, fitting in a 1024-thread block. - * - * Required block configuration: - * blockDim.x = mTxmq_rocwmma_nthreads(K_ord) - * = (K_ord² / 16) * warpSize (warpSize = 64 on AMD) - * - * Required shared memory: - * mTxmq_shmem_size(K_ord) * sizeof(T) bytes - * (= mTxmq_rocwmma_shmem_bytes()) - * - * Integration: include this header before "mra/ops/mxm.h"; the macro - * MRA_HAVE_MTXMQ defined here suppresses the fallback definition in mxm.h. 
- */ - -#include -#include "util.h" - -#ifdef __HIP_DEVICE_COMPILE__ - - -namespace mra { - -namespace detail { - - // WMMA fragment tile size (M × N × K_TILE) on AMD hardware. - // M and N are always 16; K_TILE depends on precision. - constexpr int ROCWMMA_TILE = 16; // fragment M and N dimension - // MFMA K-tile: 4 for both MFMA_F32_16x16x4F32 and MFMA_F64_16x16x4F64 - template constexpr int rocwmma_k_tile = 4; - - /** - * Core device function: C[M×N] = A^T[M×K] × B[K×N] - * where M = K_ord², N = K = K_ord. - * - * Shared memory layout (smem must hold mTxmq_rocwmma_shmem_bytes bytes): - * smem_b[K_ord × N_PAD] zero-padded B; N_PAD = ceil(K_ord/16)*16 = 16 - * smem_c[M × N_PAD] staged output (only allocated when N < N_PAD) - * - * When N == N_PAD (K_ord == 16): smem_c is omitted and each WMMA tile stores - * directly to global C, keeping shared memory usage minimal (~2 KB for fp64). - * - * When N < N_PAD (K_ord < 16): B is padded with zeros so the WMMA tiles - * produce correct results in the valid columns; the full N_PAD-wide output is - * staged in smem_c, and only the N valid columns are written back to global C. - */ - template - __device__ void mTxmq_rocwmma_core(T* __restrict__ c, const T* a, const T* b, T* smem) - { - static_assert(K_ord % rocwmma_k_tile == 0, - "K_ord must be divisible by the MFMA K tile size (4 for fp32/fp64)"); - - constexpr int M = K_ord * K_ord; - constexpr int N = K_ord; - constexpr int K = K_ord; - constexpr int K_TILE = rocwmma_k_tile; - - // N_PAD: round K_ord up to the next multiple of 16. - // For K_ord ≤ 16 this is always 16. - constexpr int N_PAD = ((N + ROCWMMA_TILE - 1) / ROCWMMA_TILE) * ROCWMMA_TILE; - - // M is already a multiple of 16 when K_ord % 4 == 0 (K_ord² = (4n)² = 16n²). 
- static_assert(M % ROCWMMA_TILE == 0, - "K_ord² must be 16-aligned; this is guaranteed when K_ord % 4 == 0"); - - constexpr int M_TILES = M / ROCWMMA_TILE; // output row tiles - constexpr int N_TILES = N_PAD / ROCWMMA_TILE; // = 1 for K_ord ≤ 16 - constexpr int TOTAL_TILES = M_TILES * N_TILES; - - // True when B/C need column zero-padding (K_ord not a multiple of 16). - constexpr bool NEEDS_PAD = (N != N_PAD); - - // ── Shared memory layout ─────────────────────────────────────────────── - T* smem_b = smem; // [K × N_PAD] - T* smem_c = smem_b + K * N_PAD; // [M × N_PAD] (used only when NEEDS_PAD) - - // ── Phase 1: load B into smem_b with zero-padding ───────────────────── - // smem_b[ki][ni] = B[ki][ni] if ni < N, else 0. - for (int idx = threadIdx.x; idx < K * N_PAD; idx += blockDim.x) { - const int ki = idx / N_PAD; - const int ni = idx % N_PAD; - smem_b[idx] = (ni < N) ? b[ki * N + ni] : T(0); - } - __syncthreads(); - - // ── Phase 2: WMMA computation (one wavefront per output tile) ───────── - const int warp_id = threadIdx.x / warpSize; - - rocwmma::fragment c_frag; - rocwmma::fill_fragment(c_frag, T(0)); - - if (warp_id < TOTAL_TILES) { - // tile_m : row tile index [0, M_TILES) - // tile_n : col tile index [0, N_TILES) — always 0 for K_ord ≤ 16 - const int tile_m = warp_id % M_TILES; - const int tile_n = warp_id / M_TILES; - const int m_start = tile_m * ROCWMMA_TILE; - const int n_start = tile_n * ROCWMMA_TILE; - - for (int k = 0; k < K; k += K_TILE) { - // ── A fragment ──────────────────────────────────────────────────── - // We want A^T[m_start : m_start+16, k : k+K_TILE]. - // A is stored row-major as [K rows × M cols]. - // Using col_major layout for matrix_a: element [m_i][k_j] is read from - // ptr[k_j * ld + m_i] where ptr = a + k*M + m_start, ld = M. - // This gives a[k*M + m_start + k_j*M + m_i] - // = A[k + k_j][m_start + m_i] - // = A^T[m_start + m_i][k + k_j]. 
✓ - rocwmma::fragment a_frag; - - rocwmma::load_matrix_sync(a_frag, - a + k * M + m_start, - static_cast(M)); - - // ── B fragment ──────────────────────────────────────────────────── - // B tile: smem_b[k : k+K_TILE, n_start : n_start+16] - // Stored row-major in smem_b with stride N_PAD. - rocwmma::fragment b_frag; - - rocwmma::load_matrix_sync(b_frag, - smem_b + k * N_PAD + n_start, - static_cast(N_PAD)); - - rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - } - - if constexpr (!NEEDS_PAD) { - // ── K_ord == 16: N == N_PAD, no column padding ─────────────────── - // Each tile's 16 output columns are all valid. Store directly to C - // (row-major, leading dimension N). Warps write to non-overlapping - // row ranges so there are no global-memory conflicts. - rocwmma::store_matrix_sync(c + m_start * N + n_start, - c_frag, - static_cast(N), - rocwmma::mem_row_major); - } else { - // ── K_ord < 16: N < N_PAD, stage through smem_c ────────────────── - // Each WMMA tile covers N_PAD columns, but only the first N are valid. - // Store the full N_PAD-wide tile to smem_c so that the copy phase can - // extract just the N valid columns without corrupting adjacent data. - rocwmma::store_matrix_sync( - smem_c + tile_m * ROCWMMA_TILE * N_PAD + tile_n * ROCWMMA_TILE, - c_frag, - static_cast(N_PAD), - rocwmma::mem_row_major); - } - } - - // ── Phase 3: copy valid columns from smem_c to global C ─────────────── - // Only needed when N < N_PAD. - if constexpr (NEEDS_PAD) { - __syncthreads(); // wait for all tiles to finish storing to smem_c - // smem_c is [M × N_PAD] row-major; copy the valid [M × N] sub-block. 
- for (int idx = threadIdx.x; idx < M * N; idx += blockDim.x) { - const int mi = idx / N; - const int ni = idx % N; - c[idx] = smem_c[mi * N_PAD + ni]; - } - } - } - - // ── Shared-memory size helpers ───────────────────────────────────────────── - - template - constexpr size_t mTxmq_rocwmma_shmem_bytes() { - constexpr int N_PAD = ((K_ord + ROCWMMA_TILE - 1) / ROCWMMA_TILE) * ROCWMMA_TILE; - constexpr bool NEEDS_PAD = (K_ord != N_PAD); - constexpr size_t smem_b = static_cast(K_ord) * N_PAD; - constexpr size_t smem_c = NEEDS_PAD - ? static_cast(K_ord) * K_ord * N_PAD - : 0; - return (smem_b + smem_c) * sizeof(T); - } - -} // namespace detail - - -// ── Public interface ─────────────────────────────────────────────────────────── - -/** - * rocWMMA-accelerated mTxmq: c(i,j) = sum_k a(k,i) * b(k,j) - * - * A : [dimk × dimi] row-major (transposed in the multiply) - * B : [dimk × dimj] row-major - * C : [dimi × dimj] row-major - * - * Requires dimi == dimk² and dimj == dimk. - * Dispatches to rocWMMA for dimk ∈ {4, 8, 12, 16}; prints a diagnostic for - * other values (callers should fall back to a reference implementation for those). 
- */ -template -__device__ void mTxmq(size_type dimi, size_type dimj, size_type dimk, - cT* __restrict__ c, const aT* a, const bT* b) -{ - static_assert(std::is_same_v && std::is_same_v, - "rocWMMA mTxmq requires identical input and output types"); - - extern __shared__ char smem_raw[]; - cT* smem = reinterpret_cast(smem_raw); - - if (dimi == dimk * dimk && dimj == dimk) { - switch (dimk) { - case 4: detail::mTxmq_rocwmma_core< 4, cT>(c, a, b, smem); break; - case 8: detail::mTxmq_rocwmma_core< 8, cT>(c, a, b, smem); break; - case 12: detail::mTxmq_rocwmma_core<12, cT>(c, a, b, smem); break; - case 16: detail::mTxmq_rocwmma_core<16, cT>(c, a, b, smem); break; - default: - if (threadIdx.x == 0) - printf("mTxmq_rocwmma: unsupported dimk=%u " - "(supported: 4, 8, 12, 16 when dimi=dimk² and dimj=dimk)\n", dimk); - } - } else { - if (threadIdx.x == 0) - printf("mTxmq_rocwmma: unsupported dimi=%u dimj=%u dimk=%u\n", dimi, dimj, dimk); - } -} - -/** - * Required shared memory, in units of T elements. - * Pass mTxmq_shmem_size(K) * sizeof(T) to the kernel's shared memory - * allocation, e.g. via hipLaunchKernelGGL's sharedMemBytes argument. - * - * Shared memory requirements (fp64): - * K= 4 : 80 elements = 640 B - * K= 8 : 1152 elements = 9.0 KB - * K=12 : 2496 elements = 19.5 KB - * K=16 : 256 elements = 2.0 KB (direct store path; no smem_c needed) - */ -template -constexpr size_type mTxmq_shmem_size(size_type K) { - switch (K) { - case 4: return static_cast( - detail::mTxmq_rocwmma_shmem_bytes() / sizeof(T)); - case 8: return static_cast( - detail::mTxmq_rocwmma_shmem_bytes() / sizeof(T)); - case 12: return static_cast( - detail::mTxmq_rocwmma_shmem_bytes() / sizeof(T)); - case 16: return static_cast( - detail::mTxmq_rocwmma_shmem_bytes() / sizeof(T)); - default: return 0; - } -} - -/** - * Required block thread count for K_ord (AMD wavefront = 64 threads). 
- * K= 4 : 64 threads ( 1 wavefront) - * K= 8 : 256 threads ( 4 wavefronts) - * K=12 : 576 threads ( 9 wavefronts) - * K=16 : 1024 threads ( 16 wavefronts) - */ -template -constexpr size_type mTxmq_rocwmma_nthreads(size_type K) { - // (K² / 16) wavefronts × 64 threads/wavefront - return static_cast((K * K / 16) * 64); -} - -} // namespace mra - -#define MRA_HAVE_MTXMQ 1 - -#endif // __HIP_DEVICE_COMPILE__ -#endif // MRA_OPS_MXM_ROCWMMA_H diff --git a/test_mfma_4x4x4_layout.hip b/test_mfma_4x4x4_layout.hip new file mode 100644 index 0000000..4a578cc --- /dev/null +++ b/test_mfma_4x4x4_layout.hip @@ -0,0 +1,127 @@ +// Diagnostic: probe the v_mfma_f64_4x4x4f64 layout on GFX90A. +// +// Output type is a single double per lane (64 lanes * 1 = 64 outputs total = +// four 4x4 tiles, or four replicas of one 4x4 tile). Question: is it +// 4-broadcast (single A, single B, four D replicas) or 4-independent +// (four separate 4x4 GEMMs)? +// +// By analogy with v_mfma_f64_16x16x4f64, the natural layout hypothesis is: +// A (M=4, K=4): lane t -> A[t%4][(t/4)%4] (col-major within 16 lanes) +// B (K=4, N=4): lane t -> B[(t/4)%4][t%4] (row-major within 16 lanes) +// D (M=4, N=4): lane t -> D[(t/4)%4][t%4] (row-major within 16 lanes) +// with the block/replica index = t/16. +// +// Tests below feed distinguishing patterns to confirm (or refute) each piece. + +#include +#include + +__device__ inline double mfma_4x4x4_f64(double a, double b, double c) { + return __builtin_amdgcn_mfma_f64_4x4x4f64(a, b, c, 0, 0, 0); +} + +// Test 1: all-ones. Expected D[m][n] = 4 everywhere. Any deviation +// (including zeros) exposes subtle lane participation rules. +__global__ void probe_ones(double* out) { + int t = threadIdx.x; + double acc = mfma_4x4x4_f64(1.0, 1.0, 0.0); + out[t] = acc; +} + +// Test 2: broadcast vs independent. Zero out A from lanes 16..63. +// If broadcast: D = 4 in ALL four blocks (t/16 = 0,1,2,3) because the +// hardware reads A only from lanes 0..15 and replicates. 
+// If independent: D = 4 in block 0 only; D = 0 in blocks 1..3 because +// those blocks' A was zeroed. +__global__ void probe_broadcast_A(double* out) { + int t = threadIdx.x; + double a = (t < 16) ? 1.0 : 0.0; + double b = 1.0; + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +// Test 3: same but zero B for lanes 16..63. +__global__ void probe_broadcast_B(double* out) { + int t = threadIdx.x; + double a = 1.0; + double b = (t < 16) ? 1.0 : 0.0; + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +// Test 4: reveal n-dim of D. +// Hypothesis: B[k][n] under lane t<16 is B[(t/4)][t%4]. +// Feed b = t%4 (so B[k][n] = n, independent of k). +// Then D[m][n] = sum_k 1 * n = 4n, independent of m. +// The output tells us which lanes correspond to which n. +__global__ void probe_reveal_n(double* out) { + int t = threadIdx.x; + double a = 1.0; + double b = (double)(t & 3); + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +// Test 5: reveal m-dim of D. +// Hypothesis: A[m][k] under lane t<16 is A[t%4][(t/4)]. +// Feed a = t%4 (so A[m][k] = m, independent of k). +// Then D[m][n] = sum_k m * 1 = 4m. +// (Note under broadcast, lanes 16..63 also need to provide sensible values; +// feeding a = t%4 keeps A[m][k] = m consistent across replicas.) +__global__ void probe_reveal_m(double* out) { + int t = threadIdx.x; + double a = (double)(t & 3); + double b = 1.0; + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +// Test 6: full (m, n) probe. +// A[m][k] = m, B[k][n] = n => D[m][n] = sum_k m*n = 4mn. +// Encodes 4m*n = 0..36 into each lane. Combined with test 4/5 this +// uniquely identifies (m, n) per lane. +__global__ void probe_mn(double* out) { + int t = threadIdx.x; + double a = (double)(t & 3); + double b = (double)(t & 3); + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +// Test 7: reveal A's k-dim. +// A[m][k] = 10k (so rows don't matter). +// Under hypothesis lane t<16 -> A[t%4][t/4], feed a = 10*(t/4). +// Then D[m][n] = sum_k 10k * 1 = 10*(0+1+2+3) = 60, uniform. 
+// If we feed a = 10*((t/4)%4), lanes 0..63 still feed A[m][k]=10k consistently. +// Uniform 60 across all lanes means A's k-dim is loaded correctly from 4 lanes. +__global__ void probe_reveal_Ak(double* out) { + int t = threadIdx.x; + double a = 10.0 * (double)((t >> 2) & 3); + double b = 1.0; + out[t] = mfma_4x4x4_f64(a, b, 0.0); +} + +static void run(const char* name, void (*kernel)(double*)) { + double *d_out, h[64]; + hipMalloc(&d_out, 64 * sizeof(double)); + kernel<<<1, 64>>>(d_out); + hipDeviceSynchronize(); + hipMemcpy(h, d_out, 64 * sizeof(double), hipMemcpyDeviceToHost); + hipFree(d_out); + + printf("\n=== %s ===\n", name); + printf("lane (t/16,t%%16) : value\n"); + for (int t = 0; t < 64; ++t) { + int block = t >> 4; + int sub = t & 15; + printf("%2d (%d,%2d): %6.1f", t, block, sub, h[t]); + if ((t & 3) == 3) printf("\n"); else printf(" "); + } +} + +int main() { + run("Test 1: A=1, B=1 (expect D=4 if broadcast, all lanes)", probe_ones); + run("Test 2: A nonzero only for t<16, B=1 (broadcast? all D=4 vs tile-0 only)", probe_broadcast_A); + run("Test 3: A=1, B nonzero only for t<16 (mirror of Test 2)", probe_broadcast_B); + run("Test 4: reveal n-dim (D[m][n] = 4n; expect 0,4,8,12 pattern vs n)", probe_reveal_n); + run("Test 5: reveal m-dim (D[m][n] = 4m; expect 0,4,8,12 pattern vs m)", probe_reveal_m); + run("Test 6: D[m][n] = 4mn (uniquely identifies (m,n) per lane)", probe_mn); + run("Test 7: A_k probe (expect 60 everywhere if A's k-dim loaded from 4 lanes)", probe_reveal_Ak); + return 0; +} diff --git a/test_mfma_4x4x4_layout2.hip b/test_mfma_4x4x4_layout2.hip new file mode 100644 index 0000000..2197688 --- /dev/null +++ b/test_mfma_4x4x4_layout2.hip @@ -0,0 +1,83 @@ +// Additional probes to pin down v_mfma_f64_4x4x4f64 fragment layout. +// Strategy: feed values that uniquely encode (row, col) per lane to reveal +// which subset of lanes contributes to each output cell. 
+ +#include +#include + +__device__ inline double mfma(double a, double b, double c) { + return __builtin_amdgcn_mfma_f64_4x4x4f64(a, b, c, 0, 0, 0); +} + +// P1: a = lane_id, b = 1. Output tells us sum of A[m][0..3] where +// A[m][k] is loaded from "some" lanes; the sum reveals the identity +// of contributing lanes at row m. +__global__ void p1(double* out) { + int t = threadIdx.x; + out[t] = mfma((double)t, 1.0, 0.0); +} + +// P2: a = 1, b = lane_id. +__global__ void p2(double* out) { + int t = threadIdx.x; + out[t] = mfma(1.0, (double)t, 0.0); +} + +// P3: a = t&3, b = 1 (variant of prior test 5 with accumulator verification). +__global__ void p3(double* out) { + int t = threadIdx.x; + out[t] = mfma((double)(t & 3), 1.0, 0.0); +} + +// P4: a = 1, b = (t>>4) (reveal B's batch/block dim if any). +__global__ void p4(double* out) { + int t = threadIdx.x; + out[t] = mfma(1.0, (double)((t >> 4) & 3), 0.0); +} + +// P5: a = (t>>4), b = 1 (reveal A's batch/block dim if any). +__global__ void p5(double* out) { + int t = threadIdx.x; + out[t] = mfma((double)((t >> 4) & 3), 1.0, 0.0); +} + +// P6: a = lane_id only for a selected lane (t == TARGET); zero elsewhere. +// Shows how much that single a-value contributes to each output lane. +template +__global__ void p_single(double* out) { + int t = threadIdx.x; + double a = (t == TARGET) ? 
1.0 : 0.0; + double b = 1.0; + out[t] = mfma(a, b, 0.0); +} + +static void run(const char* name, void (*kernel)(double*)) { + double *d_out, h[64]; + hipMalloc(&d_out, 64 * sizeof(double)); + kernel<<<1, 64>>>(d_out); + hipDeviceSynchronize(); + hipMemcpy(h, d_out, 64 * sizeof(double), hipMemcpyDeviceToHost); + hipFree(d_out); + + printf("\n=== %s ===\n", name); + for (int t = 0; t < 64; ++t) { + printf("%2d:%5.1f ", t, h[t]); + if ((t & 7) == 7) printf("\n"); + } +} + +int main() { + run("P1: a=lane_id, b=1", p1); + run("P2: a=1, b=lane_id", p2); + run("P3: a=t&3, b=1 (sanity, repeats prior test 5)", p3); + run("P4: a=1, b=(t>>4)&3 (does B depend on t/16?)", p4); + run("P5: a=(t>>4)&3, b=1 (does A depend on t/16?)", p5); + run("P-lane0 only A", p_single<0>); + run("P-lane1 only A", p_single<1>); + run("P-lane4 only A", p_single<4>); + run("P-lane15 only A", p_single<15>); + run("P-lane16 only A", p_single<16>); + run("P-lane32 only A", p_single<32>); + run("P-lane48 only A", p_single<48>); + return 0; +} diff --git a/test_mfma_4x4x4_layout3.hip b/test_mfma_4x4x4_layout3.hip new file mode 100644 index 0000000..c235227 --- /dev/null +++ b/test_mfma_4x4x4_layout3.hip @@ -0,0 +1,61 @@ +// Third probe pass for v_mfma_f64_4x4x4f64. +// Goal: nail down the B fragment layout. +// +// From the single-A-lane probes we learned: +// A at lane S (a=1, rest 0) --> output 1 at 4 lanes = 4*m + {0,1,2,3} +// where m = 4*(S%4) + (S/4)%4, k = S/16. +// (i.e. instruction does D_{16x4} = A_{16x4} * B_{4x4}, +// A[m][k] at lane (16k + (m%4)*4 + m/4).) +// +// Now: isolate B. Feed a=1 everywhere, b=1 only at one specific lane X. +// D[m][n] = sum_k A[m][k] * B[k][n] * 1_{lane has that value} +// = A[m][k_X] * 1 * [n == n_X] (since A=1, only k_X contributes) +// Output=1 wherever t%4 == n_X. Reveals n_X per lane. 
+ +#include +#include + +__device__ inline double mfma(double a, double b, double c) { + return __builtin_amdgcn_mfma_f64_4x4x4f64(a, b, c, 0, 0, 0); +} + +template +__global__ void probe_b(double* out) { + int t = threadIdx.x; + double b = (t == X) ? 1.0 : 0.0; + out[t] = mfma(1.0, b, 0.0); +} + +static void run(const char* name, void (*kernel)(double*)) { + double *d_out, h[64]; + hipMalloc(&d_out, 64 * sizeof(double)); + kernel<<<1, 64>>>(d_out); + hipDeviceSynchronize(); + hipMemcpy(h, d_out, 64 * sizeof(double), hipMemcpyDeviceToHost); + hipFree(d_out); + + printf("\n=== %s ===\n", name); + // Only show where output != 0 and mark block/row structure. + int nnz = 0; + for (int t = 0; t < 64; ++t) if (h[t] != 0.0) ++nnz; + printf(" nonzero lanes: %d\n", nnz); + for (int t = 0; t < 64; ++t) { + if (h[t] != 0.0) { + int m = t / 4; + int n = t % 4; + printf(" lane %2d -> D[m=%2d][n=%d] = %.1f\n", t, m, n, h[t]); + } + } +} + +int main() { + run("B at lane 0 only", probe_b<0>); + run("B at lane 1 only", probe_b<1>); + run("B at lane 2 only", probe_b<2>); + run("B at lane 3 only", probe_b<3>); + run("B at lane 4 only", probe_b<4>); + run("B at lane 15 only", probe_b<15>); + run("B at lane 16 only", probe_b<16>); + run("B at lane 17 only", probe_b<17>); + return 0; +} diff --git a/test_mfma_layout.hip b/test_mfma_layout.hip new file mode 100644 index 0000000..b4dcc29 --- /dev/null +++ b/test_mfma_layout.hip @@ -0,0 +1,48 @@ +// Diagnostic: probe the MFMA output layout for v_mfma_f64_16x16x4f64. +// A = all ones (A[m][k] = 1); B[k][n] = n (every lane feeds b = t % 16, assuming B lane t -> B[t/16][t%16]). +// Expected D[m][n] = sum_k A[m][k] * B[k][n] = 4n, independent of m. +// +// Write acc[0..3] per-thread to host-readable memory, inspect layout.
+ +#include +#include + +typedef double v4f64 __attribute__((ext_vector_type(4))); + +__global__ void probe_mfma(double* out) { + int t = threadIdx.x; + + // A: col-major (lane t -> A[t%16][t/16]); value A[m][k] = 1 + // B: hypothesis lane t -> B[t/16][t%16]; value B[k][n] = n + // -> D[m][n] = sum_k 1 * n = 4n (for K=4) + // Distinct D per column -> reveals col mapping per lane. + double a_val = 1.0; + double b_val = (double)(t & 15); + + v4f64 acc = v4f64{0.0, 0.0, 0.0, 0.0}; + acc = __builtin_amdgcn_mfma_f64_16x16x4f64(a_val, b_val, acc, 0, 0, 0); + + for (int e = 0; e < 4; ++e) { + out[t * 4 + e] = acc[e]; + } +} + +int main() { + double *d_out, h_out[256]; + hipMalloc(&d_out, 256 * sizeof(double)); + probe_mfma<<<1, 64>>>(d_out); + hipDeviceSynchronize(); + hipMemcpy(h_out, d_out, 256 * sizeof(double), hipMemcpyDeviceToHost); + hipFree(d_out); + + // The kernel feeds A = all ones and B[k][n] = n (b = t % 16 on every lane). + // Expected D[m][n] = 4n for every m (n in 0..15), so each value names its column. + // + // Print acc[0..3] per lane so we can work out which (m, n) each acc[e] corresponds to. + printf("lane | acc[0] acc[1] acc[2] acc[3]\n"); + for (int t = 0; t < 64; ++t) { + printf("%3d | %4.1f %4.1f %4.1f %4.1f\n", + t, h_out[t*4], h_out[t*4+1], h_out[t*4+2], h_out[t*4+3]); + } + return 0; +} diff --git a/transform.h b/transform.h index 365506f..45b776d 100644 --- a/transform.h +++ b/transform.h @@ -1,8 +1,8 @@ #ifndef HAVE_TRANSFORM_H #define HAVE_TRANSFORM_H +#include #include "util.h" -#include "mxm_cublasdx.h" #include "mxm.h" /***************************************** @@ -13,7 +13,7 @@ /** - * T'is the version for HIP! Hop! + * L1: B read from global memory each pass (no LDS staging).
*/ template @@ -26,12 +26,9 @@ __device__ void transform( { constexpr const int ndim = 3; // fixed for benchmark - extern __shared__ T b_shm[]; - - for (int i = thread_id(); i < K*K; i += block_size()) b_shm[i] = b[i]; - const T* pc = b_shm; + /* L1: B stays in global memory — pc points directly to device B */ + const T* pc = b; T *t0=workspace, *t1=c; - //std::swap(t0,t1); auto tmp = t0; t0 = t1; t1 = tmp; @@ -43,7 +40,6 @@ __device__ void transform( auto tmp = t0; t0 = t1; t1 = tmp; - //std::swap(t0,t1); } /* no need to synchronize here, mTxmq synchronizes */ } @@ -113,11 +109,8 @@ inline void submit_transform_bench(int nfuncs, int nblocks, int K, { Dim3 thread_dims = mra::mTxmq_blockdim(K); assert(block_size(thread_dims) <= MAX_THREADS_PER_BLOCK); - auto smem_size = mra::mTxmq_shmem_size(K); - size_type K2 = K*K; - if (smem_size < K2*sizeof(T)) { - smem_size = K2*sizeof(T); - } + /* L1: no LDS used — smem = 0 */ + size_type smem_size = 0; CONFIGURE_KERNEL(transform_kernel, smem_size); CALL_KERNEL(transform_kernel, std::min(nfuncs, nblocks), thread_dims, smem_size, stream, (nfuncs, K, A, B, C, workspace)); } diff --git a/transform_blocked.h b/transform_blocked.h new file mode 100644 index 0000000..11143aa --- /dev/null +++ b/transform_blocked.h @@ -0,0 +1,235 @@ +#pragma once +#include +#include "util.h" + +#if defined(__HIP__) +typedef double v4f64 __attribute__((ext_vector_type(4))); + +__device__ inline v4f64 mfma_16x16x4_f64_shared(double a, double b, v4f64 c) { + return __builtin_amdgcn_mfma_f64_16x16x4f64(a, b, c, 0, 0, 0); +} +#endif + +#include "transform_blocked_k20.h" + +// transform_blocked.h — block-distributed 3D transform with AMD MFMA. +// +// One wavefront owns one K×K block of the tensor throughout the computation: +// wave s initially holds A[:, :, k'=s]; +// after the corner turn, wave s holds result[a=s, :, :]. 
+// +// Flow: +// distribute wave s loads A[:, :, k'=s] (strided read from HBM) +// Pass 1 (local) blk_s <- B^T · blk_s (a, j') +// Pass 2 (local) blk_s <- blk_s · B (a, b) +// corner turn LDS all-to-all: wave t ends up holding +// temp2[a=t, b, k'] stored as (b, k') +// Pass 3 (local) blk_t <- blk_t · B (b, c) — canonical +// store wave s writes result[s, :, :] to HBM +// +// Local GEMMs use v_mfma_f64_16x16x4f64 (GFX90A). One MFMA does a 16×16 +// output with K=4 contraction; for K=16 we chain 4 MFMAs per pass. +// +// LDS layout (36 KB total at K=16: 34 KB buf + 2 KB B_lds): +// buf[K²·(K+1)] single K³ scratch, each row padded from K to K+1 doubles, in-place across passes via register stash +// B_lds[K²] cached B matrix, placed right after buf, shared across all waves +// +// Thread-block size: 64 × K = 1024 threads at K=16. + +#if defined(__HIP__) +__device__ inline v4f64 mfma_16x16x4_f64(double a, double b, v4f64 c) { + return __builtin_amdgcn_mfma_f64_16x16x4f64(a, b, c, 0, 0, 0); +} +#endif + +template +__global__ +__launch_bounds__(K * 64, 1) +void transform_kernel_blocked(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*workspace unused*/) +{ + static_assert(std::is_same::value, "MFMA path is FP64 only"); + static_assert(K == 16, "MFMA path: K=16 only for now"); + static_assert(K * K % 64 == 0, "K^2 must be a multiple of the wavefront size (64)"); + + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int ELEMS_PER_LANE = K2 / 64; // 4 at K=16 + constexpr int NMFMA = K / 4; // 4 MFMAs per K=16 pass + + // LDS bank-conflict pad: give each "row" inside a wave's K×K region an + // extra column of dead storage so within-wave stride-K accesses (pass 2/3 + // A_frag loads, corner-turn cross-wave writes) don't collide on the same + // 128-byte bank row. Per-wave region becomes K × (K+1) instead of K × K.
+ constexpr int BK = K + 1; // padded inner (row) stride + constexpr int WB = K * BK; // per-wave region size (padded) + + extern __shared__ unsigned char _smem_blk[]; + T* buf = reinterpret_cast(_smem_blk); + T* B_lds = buf + K * WB; + + const int s = threadIdx.y; // wave index (0..K-1) + const int t = threadIdx.x; // lane within wave (0..63) + const int tid = s * 64 + t; + const int nthr = blockDim.x * blockDim.y; + + // Cache B into LDS once, using wide double4 loads. + // K²=256 doubles; with double4 (4-wide) we need K²/4 = 64 threads -- the + // first wave does it, the rest wait at the barrier. Each active thread + // issues a single global_load_dwordx4 (vs global_load_dwordx2 before), + // halving the load instruction count. B pointer is cudaMalloc'd and + // B_lds is 16-byte aligned (offset K*WB*8 = 34816, multiple of 16). + static_assert(K2 % 4 == 0, "B cache double4 path needs K² divisible by 4"); + if (tid < K2 / 4) { + const double4* B_vec = reinterpret_cast(B); + double4* B_lds_vec = reinterpret_cast(B_lds); + B_lds_vec[tid] = B_vec[tid]; + } + __syncthreads(); + + // One block = one tensor. Grid is sized to nfuncs at launch. + { + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Distribute: cooperative coalesced load of the K^3 tensor --- + // Each thread loads 4 CONTIGUOUS doubles from HBM as a single double4 + // (global_load_dwordx8 = 32 bytes per lane on GFX90A). Per-thread + // contiguity is what lets the compiler emit a wide load; across 16 + // consecutive lanes this still touches just 4 cache lines (same HBM + // traffic as before, 4× fewer load instructions). + // + // Per-thread HBM indices: a_ptr[4*tid .. 4*tid+3]. 
+ // For K=16, all 4 values share the same (i, j) with consecutive k: + // k_start = (4*tid) % 16 ∈ {0, 4, 8, 12} + // + // buf layout is canonical (i, j, k), k fastest: + // buf[i*WB + j*BK + k] = A[i, j, k] + { + static_assert(K == 16, "wide-load distribute tuned for K=16"); + const double4* a_ptr_vec = reinterpret_cast(a_ptr); + const int base_idx = 4 * tid; // 0..K^3-4 + const int i = base_idx >> 8; // / K^2 + const int j = (base_idx >> 4) & 15; + const int k_start = base_idx & 15; // 0, 4, 8, or 12 + + double4 v = a_ptr_vec[tid]; // one global_load_dwordx8 + T* row = &buf[i*WB + j*BK + k_start]; + row[0] = v.x; + row[1] = v.y; + row[2] = v.z; + row[3] = v.w; + } + __syncthreads(); // cross-wave writes visible before pass-1 reads + + // --- Pass 1 MFMA (SWAPPED OPERANDS): compute blk^T · B instead of B^T · blk --- + // GFX90A v_mfma_f64_16x16x4f64 confirmed layouts: + // A_frag (M=16, K=4): lane t -> A[t%16][t/16] (col-major) + // B_frag (K=4, N=16): lane t -> B[t/16][t%16] (row-major) + // D (M=16, N=16): lane t acc[e] -> D[(t/16)+4*e][t%16] + // + // By feeding A_frag = blk^T (from buf, treating (i,j) as if transposed) + // and B_frag = B (from B_lds), MFMA produces D where + // D[j][a] = sum_i blk[i][j] * B[i][a] = temp1[a][j] + // i.e. temp1 stored with its axes swapped in the register file: + // lane t, acc[e] == temp1[t%16][(t/16) + 4*e] + // + // This is exactly the layout pass 2's A_frag wants at iter p = e, + // so pass 2 can consume pass 1's acc directly -- no LDS round trip. + // With canonical buf layout, thread t's A_frag contribution at iter p + // lives at buf[i*WB + j*BK + k] with i=p*4+t/16, j=t%16, k=s. + v4f64 acc1 = v4f64{0.0, 0.0, 0.0, 0.0}; + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + double a_val = buf[(p*4 + (t >> 4)) * WB + (t & 15) * BK + s]; + double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)]; + acc1 = mfma_16x16x4_f64(a_val, b_val, acc1); + } + // No LDS writeback: acc1[p] is pass 2's a_val at iter p directly. 
+ + // --- Pass 2 MFMA: out = temp1 · B (cols: j' -> b) --- + // A_frag = temp1 (from acc1), B_frag = B. + // thread t at iter p uses acc1[p] = temp1[t%16][p*4 + t/16] + v4f64 acc = v4f64{0.0, 0.0, 0.0, 0.0}; + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + double a_val = acc1[p]; + double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)]; + acc = mfma_16x16x4_f64(a_val, b_val, acc); + } + // Pass 2 output is in acc. The corner-turn stash step used to read + // pass 2's buf at positions (a = t/16 + 4e, b = t%16) -- exactly the + // positions MFMA already deposited in acc. So acc[e] IS the stash; + // we can skip both the pass-2 store and the stash read, and write + // acc directly into the cross-wave destination below. + + __syncthreads(); // all waves done with pass-1 reads of buf (pass 2 reads only registers + B_lds) before the corner turn overwrites it + + // --- Corner turn (cross-wave write, fused with pass-2 store) --- + // buf[a*WB + b*BK + s] <- acc[e] where (a, b) = (t/16 + 4e, t%16) + #pragma unroll + for (int e = 0; e < ELEMS_PER_LANE; ++e) { + int idx = t + e * 64; + int a = idx / K; + int b = idx % K; + buf[a*WB + b*BK + s] = acc[e]; + } + __syncthreads(); + + // --- Pass 3 MFMA + store: out = blk · B, write directly to HBM --- + // Same orientation as pass 2 (blk has (b, k') layout). + acc = v4f64{0.0, 0.0, 0.0, 0.0}; + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + double a_val = buf[s*WB + (t & 15) * BK + (p*4 + (t >> 4))]; + double b_val = B_lds[(p*4 + (t >> 4)) * K + (t & 15)]; + acc = mfma_16x16x4_f64(a_val, b_val, acc); + } + // No sync: acc is in registers, about to write to HBM. 
+ + #pragma unroll + for (int e = 0; e < 4; ++e) { + int row = (t >> 4) + 4 * e; // b + int col = t & 15; // c + c_ptr[s*K2 + row*K + col] = acc[e]; + } + } +} + +template +inline size_type blocked_shmem_size(int K) { + // Padded buf (K × K × (K+1)) + B cache (K × K) + return (size_type)((K * K * (K + 1) + K * K) * sizeof(T)); +} + +template +inline Dim3 blocked_blockdim(int K) { + return Dim3(64, K, 1); +} + +template +inline void submit_transform_bench_blocked(int nfuncs, int /*nblocks*/, int K, + const T* A, const T* B, T* C, T* workspace, + Stream stream) +{ + if (K == 16) { + constexpr int Kv = 16; + Dim3 td = blocked_blockdim(Kv); + size_type smem = blocked_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_blocked), smem); + // Grid = nfuncs: one block per tensor. nblocks arg is ignored here + // (matches upstream convention for single-function-per-kernel kernels). + CALL_KERNEL((transform_kernel_blocked), nfuncs, td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else if (K == 20) { + submit_transform_bench_blocked_k20(nfuncs, A, B, C, workspace, stream); + } else { + fprintf(stderr, "blocked transform: K=%d not supported\n", K); + assert(false); + } +} diff --git a/transform_blocked_k20.h b/transform_blocked_k20.h new file mode 100644 index 0000000..24f37dd --- /dev/null +++ b/transform_blocked_k20.h @@ -0,0 +1,1119 @@ +#pragma once +#include +#include "util.h" + +// transform_blocked_k20.h — corner-turn 3D transform at K=20. +// +// Two variants live here: +// * scalar — `transform_kernel_blocked_k20_scalar` (correctness baseline) +// * MFMA — `transform_kernel_blocked_k20_mfma` (v_mfma_f64_4x4x4f64) +// +// Structural notes (shared): +// +// * K=20 needs 20 slabs but one wave (64 lanes) per slab would require +// 20 × 64 = 1280 threads, past the 1024/block limit. We use 10 waves +// and give each wave two k' slabs (w, w+10). +// +// * LDS budget is tight. K³ = 8000 doubles = 64 KB. 
No room for a +// B_lds, and bank-conflict padding (K → K+1) would push us to 8400 +// doubles. B lives in HBM and goes through the L1 cache; at +// K²=400 it's 3.2 KB and fits L1 comfortably. +// +// Logical flow (matches CPU reference cpu_transform3d_blocked): +// +// distribute : buf[i, j, k] ← A[i, j, k] canonical layout +// Pass 1 : buf[a, j, k] ← Σ_i B[i, a] · buf[i, j, k] +// Pass 2 : buf[a, b, k] ← Σ_j buf[a, j, k] · B[j, b] +// "corner turn": implicit -- buf is already (a, b, k); wave w +// simply reads rows a ∈ {w, w+10} in pass 3 instead of +// the k' it wrote. No LDS shuffle needed. +// Pass 3 : C[a, b, c] ← Σ_k buf[a, b, k] · B[k, c] +// (write directly to HBM) + +// ============================================================================ +// Scalar variant (correctness baseline) +// ============================================================================ + +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_scalar(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; // 400 + constexpr int K3 = K * K * K; // 8000 + constexpr int NWAVES = 10; + constexpr int SPW = 2; // slabs per wave + constexpr int LANES = 64; + constexpr int MAX_SLOTS = 7; // ceil(K²/64) + + extern __shared__ unsigned char _smem_k20[]; + T* buf = reinterpret_cast(_smem_k20); + + const int w = threadIdx.y; + const int t = threadIdx.x; + const int tid = w * LANES + t; + const int nthr = blockDim.x * blockDim.y; + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + for (int idx = tid; idx < K3; idx += nthr) buf[idx] = a_ptr[idx]; + __syncthreads(); + + // Pass 1 + { + double acc[SPW][MAX_SLOTS]; + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss 
* NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) { acc[ss][slot] = 0.0; continue; } + const int a = cell_idx / K, j = cell_idx % K; + double s = 0.0; + #pragma unroll + for (int i = 0; i < K; ++i) s += B[i * K + a] * buf[i * K2 + j * K + kp]; + acc[ss][slot] = s; + } + } + __syncthreads(); + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int a = cell_idx / K, j = cell_idx % K; + buf[a * K2 + j * K + kp] = acc[ss][slot]; + } + } + } + __syncthreads(); + + // Pass 2 + { + double acc[SPW][MAX_SLOTS]; + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) { acc[ss][slot] = 0.0; continue; } + const int a = cell_idx / K, b = cell_idx % K; + double s = 0.0; + #pragma unroll + for (int j = 0; j < K; ++j) s += buf[a * K2 + j * K + kp] * B[j * K + b]; + acc[ss][slot] = s; + } + } + __syncthreads(); + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int a = cell_idx / K, b = cell_idx % K; + buf[a * K2 + b * K + kp] = acc[ss][slot]; + } + } + } + __syncthreads(); + + // Pass 3 (direct to HBM) + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; + #pragma unroll + for (int slot = 0; slot < MAX_SLOTS; ++slot) { + const int cell_idx = slot * LANES + t; + if (cell_idx >= K2) continue; + const int b = cell_idx / K, c = cell_idx % K; + double s = 0.0; + #pragma unroll + for (int kp = 0; kp < K; ++kp) s += buf[a * K2 + b * K + kp] * B[kp * 
K + c]; + c_ptr[a * K2 + b * K + c] = s; + } + } +} + +// ============================================================================ +// MFMA variant — v_mfma_f64_4x4x4f64 +// +// Layout (empirically confirmed in test_mfma_4x4x4_layout*.hip): +// The instruction computes FOUR INDEPENDENT 4×4×4 GEMMs per call. +// With lane S = 16·α + 4·g + β (α = S/16, g = (S/4)%4 is the group, β = S%4), +// each of the four groups g ∈ {0,1,2,3} operates on its own 4×4 A, +// 4×4 B and 4×4 D: +// A_g[m = β ][k = α] at lane S +// B_g[k = α ][n = β] at lane S +// D_g[m = α ][n = β] at lane S +// +// GEMM plan per pass, per slab: +// Output is 20×20 = 5×5 tiles of 4×4. Inner contraction K_inner=20=5×4. +// Per output tile: 5 MFMA calls to accumulate the K-slices. +// Per MFMA call: 4 tiles computed simultaneously (one per g). +// 25 tiles ⇒ 7 MFMA rounds (25/4 rounds = 6 full + 1 partial, last +// round leaves groups 1,2,3 idle -- fed zeros so their output is +// discarded cleanly). +// ============================================================================ + +#if defined(__HIP__) +__device__ inline double mfma_4x4x4_f64(double a, double b, double c) { + return __builtin_amdgcn_mfma_f64_4x4x4f64(a, b, c, 0, 0, 0); +} +#endif + +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_mfma(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int NWAVES = 10; + constexpr int SPW = 2; + constexpr int LANES = 64; + constexpr int TILES_DIM = 5; // K / 4 + constexpr int K_SLICES = 5; // K / 4 + constexpr int TOTAL_TILES = TILES_DIM * TILES_DIM; // 25 + constexpr int MFMA_ROUNDS = 7; // ceil(25/4) + + extern __shared__ unsigned char _smem_k20m[]; + T* buf = reinterpret_cast(_smem_k20m); + + const int w = threadIdx.y; + const int t = 
threadIdx.x; + const int tid = w * LANES + t; + const int nthr = blockDim.x * blockDim.y; + + // Lane decomposition: S = 16α + 4g + β with β = S%4, g = (S/4)%4, α = S/16. + const int beta = t & 3; // β: low 2 bits + const int g = (t >> 2) & 3; // g: middle 2 bits -- MFMA group id + const int alpha = (t >> 4) & 3; // α: high 2 bits + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // ----- Distribute ----- + for (int idx = tid; idx < K3; idx += nthr) buf[idx] = a_ptr[idx]; + __syncthreads(); + + // ----- Pass 1: buf[a, j, kp] = Σ_i B[i, a] · buf[i, j, kp] ----- + // MFMA A_g[m=β][k=α] ← B[i=4·k_tile+α, a=4·a_tile_g+β] + // MFMA B_g[k=α][n=β] ← buf[i=4·k_tile+α, j=4·j_tile_g+β, kp] + // D_g[m=α][n=β] ↦ buf[a=4·a_tile_g+α, j=4·j_tile_g+β, kp] + { + double tile_acc[SPW][MFMA_ROUNDS]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + const bool valid = tile_idx < TOTAL_TILES; + const int a_tile = valid ? tile_idx / TILES_DIM : 0; + const int j_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int k_tile = 0; k_tile < K_SLICES; ++k_tile) { + const int gi = 4 * k_tile + alpha; + const int ga = 4 * a_tile + beta; + const int gj = 4 * j_tile + beta; + const double aval = valid ? B[gi * K + ga] : 0.0; + const double bval = valid ? 
buf[gi * K2 + gj * K + kp] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[ss][round] = acc; + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int a_tile = tile_idx / TILES_DIM; + const int j_tile = tile_idx % TILES_DIM; + const int ga = 4 * a_tile + alpha; + const int gj = 4 * j_tile + beta; + buf[ga * K2 + gj * K + kp] = tile_acc[ss][round]; + } + } + } + __syncthreads(); + + // ----- Pass 2: buf[a, b, kp] = Σ_j buf[a, j, kp] · B[j, b] ----- + // MFMA A_g[m=β][k=α] ← buf[a=4·a_tile_g+β, j=4·j_tile+α, kp] + // (here "m" of the mfma = "a" of our GEMM; "k" of the mfma = "j") + // MFMA B_g[k=α][n=β] ← B[j=4·j_tile+α, b=4·b_tile_g+β] + // D_g[m=α][n=β] ↦ buf[a=4·a_tile_g+α, b=4·b_tile_g+β, kp] + { + double tile_acc[SPW][MFMA_ROUNDS]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + const bool valid = tile_idx < TOTAL_TILES; + const int a_tile = valid ? tile_idx / TILES_DIM : 0; + const int b_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int j_tile = 0; j_tile < K_SLICES; ++j_tile) { + const int ga = 4 * a_tile + beta; + const int gj = 4 * j_tile + alpha; + const int gb = 4 * b_tile + beta; + const double aval = valid ? buf[ga * K2 + gj * K + kp] : 0.0; + const double bval = valid ? 
B[gj * K + gb] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[ss][round] = acc; + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int a_tile = tile_idx / TILES_DIM; + const int b_tile = tile_idx % TILES_DIM; + const int ga = 4 * a_tile + alpha; + const int gb = 4 * b_tile + beta; + buf[ga * K2 + gb * K + kp] = tile_acc[ss][round]; + } + } + } + __syncthreads(); + + // ----- Pass 3: C[a, b, c] = Σ_k buf[a, b, k] · B[k, c] ----- + // After pass 2 the buf layout is (a, b, k'). Wave w now "owns" + // rows a ∈ {w, w+10} -- it reads from cross-wave contributions. + // + // For pass 3, fix a = the wave's owned row. The per-slab GEMM is + // over (b, c) with inner k'. + // + // MFMA A_g[m=β][k=α] ← buf[a=fixed, b=4·b_tile_g+β, k'=4·k_tile+α] + // ("m" = "b" of the GEMM, "k" = "k'") + // MFMA B_g[k=α][n=β] ← B[k'=4·k_tile+α, c=4·c_tile_g+β] + // D_g[m=α][n=β] ↦ C[a=fixed, b=4·b_tile_g+α, c=4·c_tile_g+β] + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; + + double tile_acc[MFMA_ROUNDS]; + + #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + const bool valid = tile_idx < TOTAL_TILES; + const int b_tile = valid ? tile_idx / TILES_DIM : 0; + const int c_tile = valid ? tile_idx % TILES_DIM : 0; + + double acc = 0.0; + #pragma unroll + for (int k_tile = 0; k_tile < K_SLICES; ++k_tile) { + const int gb = 4 * b_tile + beta; + const int gk = 4 * k_tile + alpha; + const int gc = 4 * c_tile + beta; + const double aval = valid ? buf[a * K2 + gb * K + gk] : 0.0; + const double bval = valid ? B[gk * K + gc] : 0.0; + acc = mfma_4x4x4_f64(aval, bval, acc); + } + tile_acc[round] = acc; + } + + // No __syncthreads needed: writes go to HBM, not LDS. 
+ #pragma unroll + for (int round = 0; round < MFMA_ROUNDS; ++round) { + const int tile_idx = round * 4 + g; + if (tile_idx >= TOTAL_TILES) continue; + const int b_tile = tile_idx / TILES_DIM; + const int c_tile = tile_idx % TILES_DIM; + const int gb = 4 * b_tile + alpha; + const int gc = 4 * c_tile + beta; + c_ptr[a * K2 + gb * K + gc] = tile_acc[round]; + } + } +} + +// ============================================================================ +// Hybrid MFMA variant — 16×16×4 main + 4×4×4 strips + 4×4 corner +// +// Profile of the pure-4×4×4 path showed MFMA units at 12 % busy (starved by +// per-instruction VALU/LDS overhead), not HBM-bound. The fix is to push +// more FMAs per instruction by using v_mfma_f64_16x16x4f64 for the bulk +// 16×16 sub-tile; the remaining 16×4, 4×16 and 4×4 pieces are handled by +// v_mfma_f64_4x4x4f64 with 4, 4 and 1 groups respectively. +// +// Per slab per pass: 5 k-slices × 4 MFMA kinds = 20 MFMAs (vs. 35 for the +// pure-4×4×4 path -- a 43 % reduction in instruction issues). +// +// Tile partition of a 20×20 output slab: +// ┌────────────────┬──────┐ +// │ 16×16 main │16×4 │ right strip (4 tiles, 4 MFMA groups) +// │ │right │ +// ├────────────────┼──────┤ +// │ 4×16 bottom │ 4×4 │ corner (1 tile, 1 MFMA group, 3 wasted) +// │ │corner│ +// └────────────────┴──────┘ +// +// Lane roles (inherit both layouts): +// 16×16×4 A (16×4): lane t → A[t%16][t/16] col-major +// B ( 4×16): lane t → B[t/16][t%16] row-major +// D (16×16): lane t acc[e] → D[(t/16)+4e][t%16] +// 4×4×4 (per group g): +// A_g[m=β][k=α] at S = 16α + 4g + β +// B_g[k=α][n=β] +// D_g[m=α][n=β] +// +// Optimisations on top of the vanilla hybrid: +// (a) Wide double4 distribute from HBM. +// (b) Pass 1 → Pass 2 register fusion for the main + right strip tiles. 
+// The K=16 kernel already uses this trick; we swap pass 1's main/right +// operands so the accumulator layout matches pass 2's A-frag directly: +// feed A_frag = buf, B_frag = B ⇒ D[m][n] = temp1[n][m] +// at lane t acc_main[e] = temp1[t%16][(t/16) + 4e] (main) +// at lane S acc_right = temp1[4g+β][16+α] (right strip) +// Pass 2 main slices 0..3 (j∈0..15) read acc_main[0..3] directly; slice 4 +// (j∈16..19) reads acc_right. Pass 2 right strip does the same — its +// A-frag layout at lane S is temp1[4g+β][4p+α] which matches acc_main +// for p=0..3 and acc_right for p=4. Bottom strip and corner can't be +// fused cleanly (α/β transpose mismatch) and still use LDS. +// ============================================================================ + +template +__global__ +__launch_bounds__(640, 1) +void transform_kernel_blocked_k20_hybrid(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 blocked: FP64 only"); + + constexpr int K = 20; + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int NWAVES = 10; + constexpr int SPW = 2; + constexpr int K_SLICES = K / 4; // 5 + + extern __shared__ unsigned char _smem_hyb[]; + T* buf = reinterpret_cast(_smem_hyb); + + const int w = threadIdx.y; + const int t = threadIdx.x; + const int tid = w * 64 + t; + const int nthr = blockDim.x * blockDim.y; + + // 4×4×4 lane decomposition: S = 16α + 4g + β + const int beta = t & 3; + const int g = (t >> 2) & 3; + const int alpha = (t >> 4) & 3; + // 16×16×4 lane decomposition + const int tmod16 = t & 15; + const int tdiv16 = t >> 4; + + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Wide (double4) distribute. One load per 4 doubles = 2000 loads total. + // a_ptr[] is hipMalloc'd (16-byte aligned) and K³ = 8000 is divisible by 4. 
+ // Each thread issues at most ceil(2000/640) = 4 double4 loads. + { + const double4* a_ptr_vec = reinterpret_cast(a_ptr); + constexpr int NVEC = K3 / 4; // 2000 + for (int idx = tid; idx < NVEC; idx += nthr) { + double4 v = a_ptr_vec[idx]; + const int base = idx * 4; + buf[base + 0] = v.x; + buf[base + 1] = v.y; + buf[base + 2] = v.z; + buf[base + 3] = v.w; + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 1 (swapped operands) : keep acc_main and acc_right in registers + // across the sync so pass 2 can read them directly instead of going + // through LDS. See header comment for the layout derivation. + // + // With A_frag = buf and B_frag = B, the MFMA produces D = temp1^T. + // Concretely: at lane t acc_main[e] = temp1[t%16][(t/16) + 4e]. + // + // Bottom strip and corner are NOT swapped (their α/β transpose doesn't + // align with pass 2's bottom/corner A-frag layout) and still go to LDS. + // -------------------------------------------------------------------- + v4f64 p1_main[SPW]; // pass-1 main accumulator, survives to pass 2 + double p1_right[SPW]; // pass-1 right accumulator, survives to pass 2 + + { + double acc_bottom[SPW], acc_corner[SPW]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + p1_main[ss] = v4f64{0.0, 0.0, 0.0, 0.0}; + p1_right[ss] = 0.0; + acc_bottom[ss] = 0.0; + acc_corner[ss] = 0.0; + } + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + // main 16×16 (SWAPPED: A_frag=buf, B_frag=B) + // feeds D[m=(t/16)+4e][n=t%16] = temp1[t%16][(t/16)+4e] + { + const int gi = 4*ks + tdiv16; // global i (0..19) + const int gj = tmod16; // pass-1 "m" from buf = j + const double av = buf[gi * K2 + gj * K + kp]; + const double bv = B[gi * K + gj]; // pass-1 "n" from B = a + p1_main[ss] = mfma_16x16x4_f64_shared(av, bv, p1_main[ss]); + } + // right strip 16×4 (SWAPPED) + // at lane S 
acc_right = temp1[4g+β][16+α] + { + const int gi = 4*ks + alpha; + const int gj = 16 + beta; // j from buf = 16..19 + const int ga = 4*g + beta; // a from B = 0..15 + const double av = buf[gi * K2 + gj * K + kp]; + const double bv = B[gi * K + ga]; + p1_right[ss] = mfma_4x4x4_f64(av, bv, p1_right[ss]); + } + // bottom strip 4×16 (NOT swapped — goes to LDS) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 4*g + beta; + const double av = B[gi * K + ga]; + const double bv = buf[gi * K2 + gj * K + kp]; + acc_bottom[ss] = mfma_4x4x4_f64(av, bv, acc_bottom[ss]); + } + // corner 4×4 (NOT swapped — goes to LDS, group 0 only) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? B[gi * K + ga] : 0.0; + const double bv = vg ? buf[gi * K2 + gj * K + kp] : 0.0; + acc_corner[ss] = mfma_4x4x4_f64(av, bv, acc_corner[ss]); + } + } + } + + __syncthreads(); // all pass-1 LDS reads done; about to overwrite buf + + // Write only bottom + corner to LDS. Main and right stay in + // (p1_main, p1_right) registers for the pass-2 fused reads. + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + { + const int ga = 16 + alpha; + const int gj = 4*g + beta; + buf[ga * K2 + gj * K + kp] = acc_bottom[ss]; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gj = 16 + beta; + buf[ga * K2 + gj * K + kp] = acc_corner[ss]; + } + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 2 : D[a, b] = Σ_j temp1[a, j, kp] · B_filter[j, b] + // Main & right strip read A-frag from (p1_main, p1_right) registers: + // ks=0..3 → A-frag = p1_main[ss][ks] (j ∈ 0..15) + // ks=4 → A-frag = p1_right[ss] (j ∈ 16..19) + // Bottom strip and corner still read from LDS. 
+ // -------------------------------------------------------------------- + { + v4f64 acc_main[SPW]; + double acc_right[SPW], acc_bottom[SPW], acc_corner[SPW]; + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + acc_main[ss] = v4f64{0.0, 0.0, 0.0, 0.0}; + acc_right[ss] = 0.0; + acc_bottom[ss] = 0.0; + acc_corner[ss] = 0.0; + } + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + + // A-frag for main & right: from registers (fused). + // ks=0..3 → p1_main[ss][ks] at this lane (4g+β, α encoding). + // ks=4 → p1_right[ss] at this lane. + const double a_main_av = (ks < 4) ? p1_main[ss][ks] : p1_right[ss]; + // For the 4x4x4 right strip, the A-frag value at lane S is the + // same register value (the encoding matches — see header). + const double a_right_av = (ks < 4) ? p1_main[ss][ks] : p1_right[ss]; + + // main 16×16 (a∈0..15, b∈0..15) + { + const int gj = 4*ks + tdiv16; + const int gb = tmod16; + const double bv = B[gj * K + gb]; + acc_main[ss] = mfma_16x16x4_f64_shared(a_main_av, bv, acc_main[ss]); + } + // right strip 16×4 (a∈0..15, b∈16..19) + { + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const double bv = B[gj * K + gb]; + acc_right[ss] = mfma_4x4x4_f64(a_right_av, bv, acc_right[ss]); + } + // bottom strip 4×16 (a∈16..19, b∈0..15) — LDS read + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 4*g + beta; + const double av = buf[ga * K2 + gj * K + kp]; + const double bv = B[gj * K + gb]; + acc_bottom[ss] = mfma_4x4x4_f64(av, bv, acc_bottom[ss]); + } + // corner 4×4 (a∈16..19, b∈16..19) — LDS read, group 0 only + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[ga * K2 + gj * K + kp] : 0.0; + const double bv = vg ? 
B[gj * K + gb] : 0.0; + acc_corner[ss] = mfma_4x4x4_f64(av, bv, acc_corner[ss]); + } + } + } + + __syncthreads(); + + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int kp = w + ss * NWAVES; + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gb = tmod16; + buf[ga * K2 + gb * K + kp] = acc_main[ss][e]; + } + { + const int ga = 4*g + alpha; + const int gb = 16 + beta; + buf[ga * K2 + gb * K + kp] = acc_right[ss]; + } + { + const int ga = 16 + alpha; + const int gb = 4*g + beta; + buf[ga * K2 + gb * K + kp] = acc_bottom[ss]; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gb = 16 + beta; + buf[ga * K2 + gb * K + kp] = acc_corner[ss]; + } + } + } + __syncthreads(); + + // -------------------------------------------------------------------- + // PASS 3 : C[a, b, c] = Σ_k temp2[a, b, k] · B_filter[k, c] + // For each wave-owned a, do the hybrid GEMM over (b, c); inner k=k'. + // A_mfma[b, k'] = buf[a, b, k']; B_mfma[k', c] = B_filter[k', c] + // Output written directly to HBM. 
+ // -------------------------------------------------------------------- + #pragma unroll + for (int ss = 0; ss < SPW; ++ss) { + const int a = w + ss * NWAVES; // wave's owned row + + v4f64 acc_main = v4f64{0.0, 0.0, 0.0, 0.0}; + double acc_right = 0.0; + double acc_bottom = 0.0; + double acc_corner = 0.0; + + #pragma unroll + for (int ks = 0; ks < K_SLICES; ++ks) { + // main 16×16 (b∈0..15, c∈0..15) + { + const int gb = tmod16; + const int gkp = 4*ks + tdiv16; + const int gc = tmod16; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_main = mfma_16x16x4_f64_shared(av, bv, acc_main); + } + // right strip 16×4 (b∈0..15, c∈16..19) + { + const int gb = 4*g + beta; + const int gkp = 4*ks + alpha; + const int gc = 16 + beta; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (b∈16..19, c∈0..15) + { + const int gb = 16 + beta; + const int gkp = 4*ks + alpha; + const int gc = 4*g + beta; + const double av = buf[a * K2 + gb * K + gkp]; + const double bv = B[gkp * K + gc]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (b∈16..19, c∈16..19). Only group 0. + { + const int gb = 16 + beta; + const int gkp = 4*ks + alpha; + const int gc = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[a * K2 + gb * K + gkp] : 0.0; + const double bv = vg ? B[gkp * K + gc] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + // Writeback directly to HBM C. 
+ #pragma unroll + for (int e = 0; e < 4; ++e) { + const int gb = tdiv16 + 4*e; + const int gc = tmod16; + c_ptr[a * K2 + gb * K + gc] = acc_main[e]; + } + { + const int gb = 4*g + alpha; + const int gc = 16 + beta; + c_ptr[a * K2 + gb * K + gc] = acc_right; + } + { + const int gb = 16 + alpha; + const int gc = 4*g + beta; + c_ptr[a * K2 + gb * K + gc] = acc_bottom; + } + if (g == 0) { + const int gb = 16 + alpha; + const int gc = 16 + beta; + c_ptr[a * K2 + gb * K + gc] = acc_corner; + } + } +} + +// ============================================================================ +// 2-block-per-tensor SPLIT variant — padded LDS + B_lds +// +// Profile of the 1-block hybrid revealed 23-way LDS bank conflicts from the +// stride-K=20 access pattern (GCD(20, 32)=4). Adding a K+1 pad fixes the +// conflict but overflows 64 KB in a single block. Splitting the k'-axis +// across 2 blocks (each handles K/2=10 slabs) buys the padded layout *and* +// a B_lds cache: +// +// buf_padded : 10 · K · BK doubles = 10 · 20 · 21 = 4200 (33.6 KB) +// B_lds : K · K doubles = 400 ( 3.2 KB) +// total ( 36.8 KB) +// +// LDS layout: buf[kp_local · K · BK + i · BK + j], BK = 21. +// stride over i = BK (GCD(21,32)=1) -> conflict-free +// stride over j = 1 -> conflict-free +// stride over kp_local = K·BK=420 -> 8-way conflict (pass 3 contraction) +// +// Grid: (nfuncs, 2). Block 0 handles kp ∈ [0, 10); block 1 handles [10, 20). +// Each block computes a partial-sum over its k' range and atomic-adds to C. +// Caller must zero C before launching (submit helper does hipMemsetAsync). +// +// Pass 3 inner contraction has K_inner = 10 = 2 full 4-wide MFMA k-slices + 1 +// partial (α∈{0,1} valid, α∈{2,3} fed zero). Handled by a per-lane guard. 
+// ============================================================================ + +template +__global__ +__launch_bounds__(640, 2) +void transform_kernel_blocked_k20_split(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*ws unused*/) +{ + static_assert(std::is_same::value, "K=20 split: FP64 only"); + + constexpr int K = 20; + constexpr int K_HALF = K / 2; // 10 slabs per block + constexpr int BK = K + 1; // 21 - bank-conflict pad on j + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int K_SLICES_FULL = K / 4; // 5 (pass 1 & 2: full K) + // Pass 3 inner is only K_HALF=10 per block → 3 slices, last one partial. + constexpr int K_SLICES_P3 = (K_HALF + 3) / 4; // 3 + constexpr int WAVE_STRIDE = K * BK; // 20 * 21 = 420 + + extern __shared__ unsigned char _smem_split[]; + T* buf = reinterpret_cast(_smem_split); + T* B_lds = buf + K_HALF * WAVE_STRIDE; // after padded buf + + const int w = threadIdx.y; // 0..9 kp_local this wave owns + const int t = threadIdx.x; // 0..63 + const int tid = w * 64 + t; + const int nthr = blockDim.x * blockDim.y; // 640 + + // 4×4×4 lane decomposition + const int beta = t & 3; + const int g = (t >> 2) & 3; + const int alpha = (t >> 4) & 3; + // 16×16×4 lane decomposition + const int tmod16 = t & 15; + const int tdiv16 = t >> 4; + + const int cube = blockIdx.x; + const int block_half = blockIdx.y; // 0 or 1 + if (cube >= nfuncs) return; + const int kp_start = block_half * K_HALF; // 0 or 10 + + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Cache B into B_lds --- + for (int idx = tid; idx < K2; idx += nthr) B_lds[idx] = B[idx]; + + // --- Distribute: load A[:, :, kp ∈ block's range] into buf --- + // HBM layout: A[i*K² + j*K + kp]. LDS: buf[kp_local * WAVE_STRIDE + i * BK + j]. 
+ for (int idx = tid; idx < K_HALF * K2; idx += nthr) { + int kp_local = idx / K2; + int i = (idx / K) % K; + int j = idx % K; + int kp_g = kp_start + kp_local; + buf[kp_local * WAVE_STRIDE + i * BK + j] = a_ptr[i * K2 + j * K + kp_g]; + } + __syncthreads(); + + // ============ PASS 1 : D[a, j] = Σ_i B[i, a] · buf[i, j, kp] ============ + { + v4f64 acc_main[1]; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + acc_main[0] = v4f64{0.0, 0.0, 0.0, 0.0}; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_FULL; ++ks) { + // main 16×16 (a∈0..15, j∈0..15) + { + const int gi = 4*ks + tdiv16; + const int ga = tmod16; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + ga]; + acc_main[0] = mfma_16x16x4_f64_shared(av, bv, acc_main[0]); + } + // right strip 16×4 (a∈0..15, j∈16..19) + { + const int gi = 4*ks + alpha; + const int ga = 4*g + beta; + const int gj = 16 + beta; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + gj]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (a∈16..19, j∈0..15) + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 4*g + beta; + const double av = B_lds[gi * K + ga]; + const double bv = buf[w * WAVE_STRIDE + gi * BK + gj]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (a∈16..19, j∈16..19) group 0 only + { + const int gi = 4*ks + alpha; + const int ga = 16 + beta; + const int gj = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? B_lds[gi * K + ga] : 0.0; + const double bv = vg ? 
buf[w * WAVE_STRIDE + gi * BK + gj] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + __syncthreads(); + + // Writeback pass 1 → buf (same slab, reinterpret as (a, j, kp_local)) + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gj = tmod16; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_main[0][e]; + } + { + const int ga = 4*g + alpha; + const int gj = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_right; + } + { + const int ga = 16 + alpha; + const int gj = 4*g + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_bottom; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gj = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gj] = acc_corner; + } + } + __syncthreads(); + + // ============ PASS 2 : D[a, b] = Σ_j temp1[a, j, kp] · B[j, b] ============ + { + v4f64 acc_main[1]; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + acc_main[0] = v4f64{0.0, 0.0, 0.0, 0.0}; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_FULL; ++ks) { + // main 16×16 (a∈0..15, b∈0..15) + { + const int ga = tmod16; + const int gj = 4*ks + tdiv16; + const int gb = tmod16; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_main[0] = mfma_16x16x4_f64_shared(av, bv, acc_main[0]); + } + // right strip 16×4 (a∈0..15, b∈16..19) + { + const int ga = 4*g + beta; + const int gj = 4*ks + alpha; + const int gb = 16 + beta; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (a∈16..19, b∈0..15) + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + const int gb = 4*g + beta; + const double av = buf[w * WAVE_STRIDE + ga * BK + gj]; + const double bv = B_lds[gj * K + gb]; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (a∈16..19, b∈16..19) group 0 only + { + const int ga = 16 + beta; + const int gj = 4*ks + alpha; + 
const int gb = 16 + beta; + const bool vg = (g == 0); + const double av = vg ? buf[w * WAVE_STRIDE + ga * BK + gj] : 0.0; + const double bv = vg ? B_lds[gj * K + gb] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + __syncthreads(); + + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int ga = tdiv16 + 4*e; + const int gb = tmod16; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_main[0][e]; + } + { + const int ga = 4*g + alpha; + const int gb = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_right; + } + { + const int ga = 16 + alpha; + const int gb = 4*g + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_bottom; + } + if (g == 0) { + const int ga = 16 + alpha; + const int gb = 16 + beta; + buf[w * WAVE_STRIDE + ga * BK + gb] = acc_corner; + } + } + __syncthreads(); + + // ============ PASS 3 : partial[a, b, c] = Σ_{kp_local ∈ [0, K_HALF)} ============ + // buf[kp_local, a, b] · B[kp_start+kp_local, c] + // Output atomic-added to HBM C (caller pre-zeros C). + // + // Wave w owns 2 a-values: a ∈ {w, w + K_HALF} = {w, w+10}. + // K_inner = K_HALF = 10 → 3 MFMA k-slices (last partial, 2 of 4 valid). + #pragma unroll + for (int ss = 0; ss < 2; ++ss) { + const int a = w + ss * K_HALF; // wave's owned a-row + + v4f64 acc_main = v4f64{0.0, 0.0, 0.0, 0.0}; + double acc_right = 0.0, acc_bottom = 0.0, acc_corner = 0.0; + + #pragma unroll + for (int ks = 0; ks < K_SLICES_P3; ++ks) { + // For 16×16×4 MFMA, the K=4 inner lanes are the ones with tdiv16 ∈ [0, 4). + // gkp = 4*ks + tdiv16. Valid iff gkp < K_HALF. 
+ const int gkp_main = 4*ks + tdiv16; + const bool v_main = (gkp_main < K_HALF); + + const int gkp_strip = 4*ks + alpha; + const bool v_strip = (gkp_strip < K_HALF); + + // main 16×16 (b∈0..15, c∈0..15) + // A_mfma[m=b][k=kp_local] = temp2[a_fixed, b, kp_local] + // = buf[kp_local, a_fixed, b] + // Lane t's (m, k) = (t%16, t/16) → b = t%16, kp_local offset = t/16 + { + const int gb = tmod16; + const int gc = tmod16; + const int kpg = kp_start + gkp_main; + const double av = v_main ? buf[gkp_main * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_main ? B_lds[kpg * K + gc] : 0.0; + acc_main = mfma_16x16x4_f64_shared(av, bv, acc_main); + } + // right strip 16×4 (b∈0..15, c∈16..19) + { + const int gb = 4*g + beta; + const int gc = 16 + beta; + const int kpg = kp_start + gkp_strip; + const double av = v_strip ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_strip ? B_lds[kpg * K + gc] : 0.0; + acc_right = mfma_4x4x4_f64(av, bv, acc_right); + } + // bottom strip 4×16 (b∈16..19, c∈0..15) + { + const int gb = 16 + beta; + const int gc = 4*g + beta; + const int kpg = kp_start + gkp_strip; + const double av = v_strip ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = v_strip ? B_lds[kpg * K + gc] : 0.0; + acc_bottom = mfma_4x4x4_f64(av, bv, acc_bottom); + } + // corner 4×4 (b∈16..19, c∈16..19) + { + const int gb = 16 + beta; + const int gc = 16 + beta; + const int kpg = kp_start + gkp_strip; + const bool vg = (g == 0) && v_strip; + const double av = vg ? buf[gkp_strip * WAVE_STRIDE + a * BK + gb] : 0.0; + const double bv = vg ? 
B_lds[kpg * K + gc] : 0.0; + acc_corner = mfma_4x4x4_f64(av, bv, acc_corner); + } + } + + // --- atomic add partial to HBM C --- + #pragma unroll + for (int e = 0; e < 4; ++e) { + const int gb = tdiv16 + 4*e; + const int gc = tmod16; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], (double)acc_main[e]); + } + { + const int gb = 4*g + alpha; + const int gc = 16 + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_right); + } + { + const int gb = 16 + alpha; + const int gc = 4*g + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_bottom); + } + if (g == 0) { + const int gb = 16 + alpha; + const int gc = 16 + beta; + atomicAdd(&c_ptr[a * K2 + gb * K + gc], acc_corner); + } + } +} + +// ============================================================================ +// Dispatch helpers +// ============================================================================ + +template +inline size_type blocked_k20_shmem_size() { + return (size_type)(20 * 20 * 20 * sizeof(T)); // 64 KB at FP64 (1-block variants) +} + +template +inline size_type blocked_k20_split_shmem_size() { + // buf[K_HALF · K · BK] + B_lds[K · K] = 10·20·21 + 20·20 = 4600 doubles + return (size_type)((10 * 20 * 21 + 20 * 20) * sizeof(T)); // 36.8 KB +} + +template +inline Dim3 blocked_k20_blockdim() { + return Dim3(64, 10, 1); +} + +// Default path for K=20: hybrid 16×16×4 + 4×4×4 MFMA (1 block per tensor). +// +// The 2-block _split variant (with padded LDS + B_lds) was implemented and +// measured to be ~1.9× SLOWER at N=2048 despite eliminating bank conflicts +// (23.9 → 0.2-way). The overheads that killed it: +// * atomicAdd read-modify-write on C (HBM traffic 258 → 752 MB) +// * pre-zero hipMemsetAsync on C (+128 MB) +// * extra VALU for 420-stride indexing + partial-k-slice guards (+2.6×) +// Net: MFMA busy fell 11 → 6 % because the extra VALU starved them further. +// Conclusion: bank conflicts were a visible counter but not the critical path. 
+// +// Alternative kernels kept for reference / future experimentation: +// transform_kernel_blocked_k20_split (2 blocks + padded LDS + atomics) +// transform_kernel_blocked_k20_mfma (pure 4×4×4) +// transform_kernel_blocked_k20_scalar (correctness baseline) +template +inline void submit_transform_bench_blocked_k20(int nfuncs, + const T* A, const T* B, T* C, T* workspace, + Stream stream) +{ + Dim3 td = blocked_k20_blockdim(); + size_type smem = blocked_k20_shmem_size(); + CONFIGURE_KERNEL((transform_kernel_blocked_k20_hybrid), smem); + CALL_KERNEL((transform_kernel_blocked_k20_hybrid), nfuncs, td, smem, stream, + (nfuncs, A, B, C, workspace)); +} diff --git a/transform_blocked_rocwmma.h b/transform_blocked_rocwmma.h new file mode 100644 index 0000000..cfde62b --- /dev/null +++ b/transform_blocked_rocwmma.h @@ -0,0 +1,191 @@ +#pragma once +#include +#include "util.h" + +#if defined(__HIP__) && defined(__HIP_DEVICE_COMPILE__) +#include +#endif + +// transform_blocked_rocwmma.h — block-distributed 3D transform, rocWMMA variant. +// +// Mirrors the algorithm in transform_blocked.h (L7) but replaces the manual +// MFMA calls with rocwmma::mma_sync. rocWMMA hides the MFMA fragment layouts +// and emits the same underlying v_mfma_f64_16x16x4f64 on GFX90A. +// +// Layout differences from L7: +// * buf uses the per-wave layout buf[s*WB + i*BK + j] (not L7's canonical +// (i, j, k) form) so rocWMMA's standard row-major loads work directly +// with ldm = BK. +// * No register fusion between pass 1 and pass 2 (fragment contents are +// opaque to us), so pass 1's output goes through LDS. +// +// Thread block: K waves × 64 threads = 1024 at K=16. One block per tensor. +// LDS budget: buf (K × K × (K+1) padded) + B cache (K × K) = 36 KB at K=16. + +#if defined(__HIP__) && defined(__HIP_DEVICE_COMPILE__) +namespace rw = rocwmma; + +// Fragment types for f64 16×16×4 MFMA on GFX90A. 
+using FragA = rw::fragment; +using FragAT = rw::fragment; // used to load B as B^T +using FragB = rw::fragment; +using FragAcc = rw::fragment; +#endif + +template +__global__ +__launch_bounds__(K * 64, 1) +void transform_kernel_blocked_rocwmma(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ /*workspace unused*/) +{ +#if defined(__HIP__) && defined(__HIP_DEVICE_COMPILE__) + static_assert(std::is_same::value, "rocWMMA path is FP64 only"); + static_assert(K == 16, "rocWMMA path: K=16 only for now"); + + constexpr int K2 = K * K; + constexpr int K3 = K * K * K; + constexpr int ELEMS_PER_LANE = K2 / 64; + constexpr int NMFMA = K / 4; + + // Bank-conflict pad identical to L7. + constexpr int BK = K + 1; + constexpr int WB = K * BK; + + extern __shared__ unsigned char _smem_rwmma[]; + T* buf = reinterpret_cast(_smem_rwmma); + T* B_lds = buf + K * WB; + + const int s = threadIdx.y; // wave index (0..K-1) + const int t = threadIdx.x; // lane 0..63 + const int tid = s * 64 + t; + + // --- Cache B into LDS (wide double4 loads, like L7) --- + static_assert(K2 % 4 == 0); + if (tid < K2 / 4) { + const double4* B_vec = reinterpret_cast(B); + double4* B_lds_vec = reinterpret_cast(B_lds); + B_lds_vec[tid] = B_vec[tid]; + } + __syncthreads(); + + // One block = one tensor. + const int cube = blockIdx.x; + if (cube >= nfuncs) return; + const T* a_ptr = A + (size_t)cube * K3; + T* c_ptr = C + (size_t)cube * K3; + + // --- Distribute: coalesced HBM load into per-wave layout --- + // Each thread loads 4 contiguous HBM doubles as a double4, placing them + // across 4 WAVES (since consecutive HBM elements have consecutive k, and + // wave s owns k=s). Per-wave layout: buf[s*WB + i*BK + j] = A[i, j, s]. 
+ { + static_assert(K == 16); + const double4* a_ptr_vec = reinterpret_cast(a_ptr); + const int base_idx = 4 * tid; + const int i = base_idx >> 8; + const int j = (base_idx >> 4) & 15; + const int k_start = base_idx & 15; // 0, 4, 8, or 12 + + double4 v = a_ptr_vec[tid]; + // 4 cross-wave stores (destination wave = k_start + e for e=0..3). + buf[(k_start + 0) * WB + i * BK + j] = v.x; + buf[(k_start + 1) * WB + i * BK + j] = v.y; + buf[(k_start + 2) * WB + i * BK + j] = v.z; + buf[(k_start + 3) * WB + i * BK + j] = v.w; + } + __syncthreads(); + + // --- Pass 1 (SWAPPED operands): compute blk^T · B --- + // A_frag = blk^T: col_major load of the per-wave blk region. + // B_frag = B: row_major load. + // Output D[j][a] = temp1[a][j]. In the fragment, thread t's acc1.x[e] = + // temp1[t%16][(t/16) + 4e] = temp1[t%16][p*4 + t/16] (commutative). + // That is EXACTLY the value pass 2's matrix_a fragment wants at iter p=e, + // so we can feed acc1 straight into pass 2 without any LDS round-trip. + FragAcc acc1; + rw::fill_fragment(acc1, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragAT a_frag; // col_major → blk^T + FragB b_frag; // B + rw::load_matrix_sync(a_frag, &buf[s * WB + p * 4 * BK], BK); + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc1, a_frag, b_frag, acc1); + } + // No LDS writeback. acc1.x[p] is pass 2's a_frag value at iter p. + + // --- Pass 2: temp2 = temp1 · B (consumes acc1 directly) --- + FragAcc acc2; + rw::fill_fragment(acc2, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragA a_frag_p; + // Populate matrix_a's per-lane element from acc1's p-th acc value. + a_frag_p.x[0] = acc1.x[p]; + FragB b_frag; + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc2, a_frag_p, b_frag, acc2); + } + // acc2 holds pass 2 output (temp2) in the same per-lane layout as L7's + // acc2 — so acc2.x[e] IS the corner-turn stash for iter e. 
+ + // --- Corner turn: direct cross-wave write from acc2 (no stash read) --- + __syncthreads(); // all waves done with pass-1 reads from own region + #pragma unroll + for (int e = 0; e < ELEMS_PER_LANE; ++e) { + int idx = t + e * 64; + int a_ix = idx / K; + int b_ix = idx % K; + buf[a_ix * WB + b_ix * BK + s] = acc2.x[e]; + } + __syncthreads(); + + // --- Pass 3: result = temp2_reshuffled · B --- + // Wave s now represents a=s. buf[s*WB + b*BK + k'] = temp2[a=s, b, k']. + FragAcc acc3; + rw::fill_fragment(acc3, 0.0); + #pragma unroll + for (int p = 0; p < NMFMA; ++p) { + FragA a_frag; + FragB b_frag; + rw::load_matrix_sync(a_frag, &buf[s * WB + p * 4], BK); + rw::load_matrix_sync(b_frag, &B_lds[p * 4 * K], K); + rw::mma_sync(acc3, a_frag, b_frag, acc3); + } + + // --- Store to HBM --- + // Canonical C[a*K² + b*K + c] with a=s. + rw::store_matrix_sync(&c_ptr[(size_t)s * K2], acc3, K, rw::mem_row_major); +#endif // __HIP__ && __HIP_DEVICE_COMPILE__ +} + +template +inline size_type blocked_rocwmma_shmem_size(int K) { + return (size_type)((K * K * (K + 1) + K * K) * sizeof(T)); +} + +template +inline Dim3 blocked_rocwmma_blockdim(int K) { + return Dim3(64, K, 1); +} + +template +inline void submit_transform_bench_blocked_rocwmma(int nfuncs, int /*nblocks*/, int K, + const T* A, const T* B, T* C, T* workspace, + Stream stream) +{ + if (K == 16) { + constexpr int Kv = 16; + Dim3 td = blocked_rocwmma_blockdim(Kv); + size_type smem = blocked_rocwmma_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_blocked_rocwmma), smem); + CALL_KERNEL((transform_kernel_blocked_rocwmma), nfuncs, td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else { + fprintf(stderr, "blocked rocwmma transform: K=%d not supported (K=16 only)\n", K); + assert(false); + } +} diff --git a/transform_cublasdx.h b/transform_cublasdx.h deleted file mode 100644 index b72032f..0000000 --- a/transform_cublasdx.h +++ /dev/null @@ -1,233 +0,0 @@ -#ifndef HAVE_TRANSFORM_CUBLASDX_H -#define 
HAVE_TRANSFORM_CUBLASDX_H - -#include "util.h" -#include "mxm_cublasdx.h" - - -/********************************************************************************** - * cublasDx implementation of transform - * - * The cublasDx implementation uses the cublasdx library directly to perform - * the tensor transformation. It relies on a single shared memory tensor - * to which the result of a GEMM is written in each iteration. - * The register fragment saves us the additional shared memory tensor - * that would be required to store the result of the GEMM. - **********************************************************************************/ - -#if __has_include() - -#define MRA_HAVE_CUBLASDX 1 - -template -__forceinline__ __device__ -void transform_cublasdx_k( - const T* t, // input tensor - const T* c, // input matrix - T* result) -{ - constexpr const int ndim = 3; // fixed for benchmark - using GEMM = typename mra::detail::GEMMBuilder::GEMM; - - using alignment = cublasdx::alignment_of; - - extern __shared__ __align__(16) char smem[]; - - auto [smem_a, smem_b] = - cublasdx::shared_memory::slice_into_pointers( - smem, - cublasdx::alignment_of_v_a, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, a)), - cublasdx::alignment_of_v_b, cublasdx::cosize(GET_SHARED_LAYOUT(GEMM, b))); - - - /* global memory tensors */ - auto a_global_tensor = cublasdx::make_tensor(t, GEMM::get_layout_gmem_a()); - auto b_global_tensor = cublasdx::make_tensor(c, GEMM::get_layout_gmem_b()); - auto c_global_tensor = cublasdx::make_tensor(result, GEMM::get_layout_gmem_c()); - - /* shared memory tensors */ - auto a_shared_tensor = cublasdx::make_tensor(smem_a, GET_SHARED_LAYOUT(GEMM, a)); - auto b_shared_tensor = cublasdx::make_tensor(smem_b, GET_SHARED_LAYOUT(GEMM, b)); - - cublasdx::copy(a_global_tensor, a_shared_tensor); - cublasdx::copy(b_global_tensor, b_shared_tensor); - - /* wait for loads to complete */ - cublasdx::copy_wait(); - - for (int n=0; n(c_register_fragment, a_shared_tensor, partitioner); - - 
/* wait for stores to complete */ - cublasdx::copy_wait(); - } - /* copy the result from shared memory to global memory */ - cublasdx::copy(a_shared_tensor, c_global_tensor); - - /* wait for the copy to complete */ - cublasdx::copy_wait(); -} - - -template -__forceinline__ __device__ -void transform_cublasdx( - const T* t, - const T* c, - T*& result, - T* workspace) -{ - return transform_cublasdx_k(t, c, result); - -#if 0 - (void)workspace; // unused in this implementation - switch (K) { - case 8 : transform_cublasdx_k(t, c, result); break; - case 10: transform_cublasdx_k(t, c, result); break; - case 16: transform_cublasdx_k(t, c, result); break; - case 20: transform_cublasdx_k(t, c, result); break; - default: - printf("Unsupported K value: %d\n", K); - return; - } -#endif // 0 - /* no need to synchronize here, cublasdx::copy_wait() synchronizes */ -} - -template -constexpr int transform_cublasdx_shmem_size_k() -{ - using GEMM = typename mra::detail::GEMMBuilder::GEMM; - - auto calc = cublasdx::make_shared_storage_calculator() - .add(cublasdx::alignment_of_v_a, sizeof(typename GEMM::a_value_type), GET_SHARED_LAYOUT(GEMM, a)) - .add(cublasdx::alignment_of_v_b, sizeof(typename GEMM::b_value_type), GET_SHARED_LAYOUT(GEMM, b)); - auto smem_size = calc.get(); - return smem_size; -} - -template -int transform_cublasdx_shmem_size(int K) -{ - switch (K) { - case 8 : return transform_cublasdx_shmem_size_k(); - case 10: return transform_cublasdx_shmem_size_k(); - case 16: return transform_cublasdx_shmem_size_k(); - case 20: return transform_cublasdx_shmem_size_k(); - default: - printf("Unsupported K value: %d\n", K); - return 0; - } - /* no need to synchronize here, cublasdx::copy_wait() synchronizes */ -} - - -template -constexpr auto transform_cublasdx_block_dim() -{ - using GEMM = typename mra::detail::GEMMBuilder::GEMM; - - return GEMM::suggested_block_dim; -} - - -template -constexpr auto transform_cublasdx_block_size() -{ - auto blockdims = 
transform_cublasdx_block_dim(); - return blockdims.x * blockdims.y * blockdims.z; -} - - -template -LAUNCH_BOUNDS((transform_cublasdx_block_size()), 1) -__global__ void transform_cublasdx_kernel(int nfuncs, const T* A, const T* B, T* C, T* workspace) { - - const T *a, *b; - T *c, *w; - int K2NDIM = K*K*K; - /* workspace is allocated for each thread-block */ - w = workspace + blockIdx.x * K2NDIM; - /* iterate over all tensors */ - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - a = A + i * K2NDIM; - b = B; - c = C + i * K2NDIM; - transform_cublasdx(a, b, c, w); - } -} - -template -void submit_transform_cublasdx_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - cudaStream_t stream) -{ - auto smem_size = transform_cublasdx_shmem_size(K); - switch (K) { - case 8: { - CONFIGURE_KERNEL((transform_cublasdx_kernel), smem_size); - CALL_KERNEL((transform_cublasdx_kernel), std::min(nfuncs, nblocks), (transform_cublasdx_block_dim()), smem_size, stream, (nfuncs, A, B, C, workspace)); - break; - } - case 10: { - CONFIGURE_KERNEL((transform_cublasdx_kernel), smem_size); - CALL_KERNEL((transform_cublasdx_kernel), std::min(nfuncs, nblocks), (transform_cublasdx_block_dim()), smem_size, stream, (nfuncs, A, B, C, workspace)); - break; - } - case 16: { - CONFIGURE_KERNEL((transform_cublasdx_kernel), smem_size); - CALL_KERNEL((transform_cublasdx_kernel), std::min(nfuncs, nblocks), (transform_cublasdx_block_dim()), smem_size, stream, (nfuncs, A, B, C, workspace)); - break; - } - case 20: { - CONFIGURE_KERNEL((transform_cublasdx_kernel), smem_size); - CALL_KERNEL((transform_cublasdx_kernel), std::min(nfuncs, nblocks), (transform_cublasdx_block_dim()), smem_size, stream, (nfuncs, A, B, C, workspace)); - break; - } - default: - printf("Unsupported K value: %d\n", K); - throw std::runtime_error("Unsupported K value in transform_cublasdx_bench"); - } -} - -#else - -#define MRA_HAVE_CUBLASDX 0 - -template -void submit_transform_cublasdx_bench(int 
nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) { - std::printf("CUBLASdx not available, cannot run benchmark\n"); -} - - -template -constexpr auto transform_cublasdx_block_size() { - return 1; -} - -template -int transform_cublasdx_shmem_size(int K) { - return 0; -} - -#endif // __has_include() - -#endif // HAVE_TRANSFORM_CUBLASDX_H \ No newline at end of file diff --git a/transform_kron.h b/transform_kron.h deleted file mode 100644 index c95fe1d..0000000 --- a/transform_kron.h +++ /dev/null @@ -1,151 +0,0 @@ -#pragma once - -/** - * Level 6 — Kronecker product GEMM. - * - * MATHEMATICAL BACKGROUND - * ----------------------- - * The standard 3-pass transform applies B^T along each mode of a K×K×K tensor: - * - * Pass 1: T1[j₀,i₁,i₂] = Σ_{i₀} A[i₀,i₁,i₂] · B[i₀,j₀] (contract mode 0) - * Pass 2: T2[j₀,j₁,i₂] = Σ_{i₁} T1[j₀,i₁,i₂] · B[i₁,j₁] (contract mode 1) - * Pass 3: C [j₀,j₁,j₂] = Σ_{i₂} T2[j₀,j₁,i₂] · B[i₂,j₂] (contract mode 2) - * - * Vectorising the tensor (flattening to K³ elements) converts this to a single - * matrix-vector product: - * - * vec(C) = KronMat · vec(A) - * - * where KronMat = B^T ⊗ B^T ⊗ B^T is the three-fold Kronecker product - * (a K³ × K³ matrix). Each entry is: - * - * KronMat[β, α] = B[α%K][β%K] · B[(α/K)%K][(β/K)%K] · B[α/K²][β/K²] - * - * with α = input linear index (i₀ + K·i₁ + K²·i₂) - * β = output linear index (j₀ + K·j₁ + K²·j₂) [same decomposition] - * - * IMPLEMENTATION - * -------------- - * 1. build_kron_kernel — one GPU thread per (β, α) entry; called ONCE before - * the timing loop and cached for all subsequent batches. - * - * 2. submit_transform_kron_bench — calls hipblasDgemm / cublasDgemm: - * - * C [K³ × nfuncs] = KronMat [K³ × K³] × A [K³ × nfuncs] - * - * Tensors are stored contiguously (tensor f occupies A[f·K³ .. (f+1)·K³-1]), - * so the batch dimension maps naturally to GEMM columns. 
- * - * TRADE-OFFS - * ---------- - * Pros - * • Single API call; a K=8 GEMM (512×512 × 512×2048) saturates HBM and - * compute much better than 128 tiny 64×8 kernels. - * • Faster than L3 for K=6 and K=8 on MI250X despite 12–21× more FLOPs. - * - * Cons - * • KronMat memory = K⁶ × 8 bytes: 6 MB at K=10, 128 MB at K=16, - * 512 MB at K=20 — impractical for K > ~16. - * • FLOPs reported are 2·K⁶·N (actual GEMM work), not the 3·2·K⁴·N - * mathematical minimum, so raw GFlop/s numbers are not directly - * comparable to L1–L4. - * - * CORRECTNESS - * ----------- - * test_kron.hip verifies L6 ≡ L3 to floating-point precision - * (max relative error < 10⁻¹⁴ for K = 6, 8, 10). - */ - -#include "util.h" -#ifdef MRA_HAVE_HIP -# include - using blasHandle_t = hipblasHandle_t; -# define BLAS_OP_N HIPBLAS_OP_N -# define blasCreate hipblasCreate -# define blasDestroy hipblasDestroy -# define blasSetStream hipblasSetStream -# define blasDgemm hipblasDgemm -#elif defined(MRA_HAVE_CUDA) -# include - using blasHandle_t = cublasHandle_t; -# define BLAS_OP_N CUBLAS_OP_N -# define blasCreate cublasCreate -# define blasDestroy cublasDestroy -# define blasSetStream cublasSetStream -# define blasDgemm cublasDgemm -#endif - -// --------------------------------------------------------------------------- -// Kernel: build the K³×K³ Kronecker product matrix (column-major). -// -// KronMat[I, J] = B^T[i₀,j₀] · B^T[i₁,j₁] · B^T[i₂,j₂] -// -// Index decomposition (first index fastest = column-major vector): -// I = i₀ + K·i₁ + K²·i₂ -// J = j₀ + K·j₁ + K²·j₂ -// -// B is row-major K×K, so B^T[i,j] = B[j·K + i]. 
-// --------------------------------------------------------------------------- -template -__global__ void build_kron_kernel(int K, const T* __restrict__ B, - T* __restrict__ KronMat) -{ - const int K3 = K * K * K; - const int I = blockIdx.x * blockDim.x + threadIdx.x; - const int J = blockIdx.y * blockDim.y + threadIdx.y; - if (I >= K3 || J >= K3) return; - - const int i0 = I % K, j0 = J % K; - const int i1 = (I / K) % K, j1 = (J / K) % K; - const int i2 = I / (K * K), j2 = J / (K * K); - - // B^T[i,j] = B[j*K + i] (B is row-major) - KronMat[I + J * K3] = B[j0*K + i0] * B[j1*K + i1] * B[j2*K + i2]; -} - -// --------------------------------------------------------------------------- -// Build the Kronecker matrix on the device (call once before timing). -// KronMat must already be allocated with K³×K³ elements. -// --------------------------------------------------------------------------- -template -inline void build_kron_matrix(int K, const T* B_dev, T* KronMat_dev, - Stream stream) -{ - const int K3 = K * K * K; - dim3 block(16, 16); - dim3 grid((K3 + 15) / 16, (K3 + 15) / 16); - CALL_KERNEL(build_kron_kernel, grid, block, 0, stream, - (K, B_dev, KronMat_dev)); -} - -// --------------------------------------------------------------------------- -// Submit one round of the Kronecker GEMM (called inside the timing loop). -// -// C[K³ × nfuncs] = KronMat[K³ × K³] × A[K³ × nfuncs] -// -// A and C are treated as column-major (each contiguous K³-block = one tensor). 
-// --------------------------------------------------------------------------- -template -inline void submit_transform_kron_bench(int nfuncs, int K, - const T* A, const T* KronMat, T* C, - blasHandle_t blas_handle, - Stream stream) -{ - const int K3 = K * K * K; - const double alpha = 1.0, beta = 0.0; - blasSetStream(blas_handle, stream); - blasDgemm(blas_handle, - BLAS_OP_N, BLAS_OP_N, - K3, nfuncs, K3, - &alpha, - KronMat, K3, - A, K3, - &beta, - C, K3); -} - -// Required by the benchmark dispatch (values are unused for level 6). -template -inline int kron_shmem_size(int /*K*/) { return 0; } - -inline Dim3 kron_blockdim(int /*K*/) { return {1, 1, 1}; } diff --git a/transform_level2.h b/transform_level2.h index 363fa45..95791ea 100644 --- a/transform_level2.h +++ b/transform_level2.h @@ -1,65 +1,62 @@ #pragma once - +#include #include "util.h" +#include "mxm.h" #include "mxm_level2.h" -/** - * Transform wrapper for Level-2 (B in LDS). - * Follows the same structure as transform.h / transform_cublasdx.h. - */ +// transform_level2.h — L2: B staged into LDS once per block, then three mTxmq passes. +// Use a byte-typed shared buffer to avoid extern __shared__ type conflicts across TUs. 
template -__device__ void transform_level2( - int K, - const T* t, /* input tensor K^3 */ - const T* c, /* coefficient matrix K^2 */ - T*& result, /* output tensor K^3 (pointer updated on swap) */ - T* workspace) /* per-block scratch K^3 */ +__global__ void transform_kernel_L2(int nfuncs, int K, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ workspace) { - constexpr int ndim = 3; - const T* pc = c; - T *t0 = workspace, *t1 = result; - /* swap so t0 points at the output buffer first */ - auto tmp = t0; t0 = t1; t1 = tmp; - - const int dimj = K; - const int dimi = dimj * dimj; - - mra::mTxmq_level2(dimi, dimj, dimj, t0, t, pc); - for (int n = 1; n < ndim; ++n) { - mra::mTxmq_level2(dimi, dimj, dimj, t1, t0, pc); - auto tmp2 = t0; t0 = t1; t1 = tmp2; - } - /* mTxmq_level2 ends with __syncthreads(); no extra sync needed */ -} - -template -LAUNCH_BOUNDS(MAX_THREADS_PER_BLOCK, 4) -__global__ void transform_kernel_level2(int nfuncs, int K, - const T* A, const T* B, T* C, T* workspace) { - const int K2NDIM = K * K * K; - T* w = workspace + blockIdx.x * K2NDIM; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K2NDIM; - T* c = C + i * K2NDIM; - transform_level2(K, a, B, c, w); - } + extern __shared__ unsigned char _shmem_l2[]; + T* b_shm = reinterpret_cast(_shmem_l2); + + // Stage B into LDS once per block (shared across all cubes in grid-stride loop). + int nthr = blockDim.x * blockDim.y; + int tid = blockDim.x * threadIdx.y + threadIdx.x; + for (int i = tid; i < K * K; i += nthr) b_shm[i] = B[i]; + __syncthreads(); // ensure B visible to all threads before first mTxmq + + const int K3 = K * K * K; + T* w = workspace + (size_t)blockIdx.x * K3; // one workspace slab per block + + for (int cube = blockIdx.x; cube < nfuncs; cube += gridDim.x) { + const T* a = A + (size_t)cube * K3; + T* c = C + (size_t)cube * K3; + + // Three passes: ping-pong between c (output) and w (workspace). 
+ // Pass 1 → c, Pass 2 → w, Pass 3 → c (result in c). + T* t0 = c; + T* t1 = w; + const int dimi = K * K, dimj = K; + + mra::mTxmq(dimi, dimj, dimj, t0, a, b_shm); // pass 1 → c + mra::mTxmq(dimi, dimj, dimj, t1, t0, b_shm); // pass 2 → w + mra::mTxmq(dimi, dimj, dimj, t0, t1, b_shm); // pass 3 → c + // mTxmq ends with SYNCTHREADS, so c is visible to all before next iteration. + } } -template -inline int transform_level2_shmem_size(int K) { - return mra::mTxmq_level2_shmem_size(K); +template +inline size_type transform_L2_shmem_size(int K) { + return mra::mTxmq_L2_shmem_size(K); } -template -inline void submit_transform_level2_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) +template +inline void submit_transform_bench_L2(int nfuncs, int nblocks, int K, + const T* A, const T* B, T* C, T* workspace, + Stream stream) { - Dim3 thread_dims = mra::mTxmq_level2_blockdim(K); - auto smem_size = mra::mTxmq_level2_shmem_size(K); - CONFIGURE_KERNEL(transform_kernel_level2, smem_size); - CALL_KERNEL(transform_kernel_level2, std::min(nfuncs, nblocks), - thread_dims, smem_size, stream, - (nfuncs, K, A, B, C, workspace)); + Dim3 thread_dims = mra::mTxmq_L2_blockdim(K); + assert(block_size(thread_dims) <= MAX_THREADS_PER_BLOCK); + size_type smem_size = mra::mTxmq_L2_shmem_size(K); + CONFIGURE_KERNEL(transform_kernel_L2, smem_size); + CALL_KERNEL(transform_kernel_L2, std::min(nfuncs, nblocks), thread_dims, smem_size, stream, + (nfuncs, K, A, B, C, workspace)); } diff --git a/transform_level3.h b/transform_level3.h index c377277..ad5a31c 100644 --- a/transform_level3.h +++ b/transform_level3.h @@ -1,82 +1,74 @@ #pragma once - +#include #include "util.h" #include "mxm_level3.h" +#include "transform_level2.h" // fallback for K not in {16, 32} -/** - * Transform wrapper for Level-3 (B in LDS + register accumulation). 
- * Each K value gets its own kernel binary via template, - * isolating register pressure to acc[K] for that specific K. - */ +// transform_level3.h — L3: K-templated register-blocking kernel. +// Falls back to L2 for K values other than 16 and 32. -template -__device__ void transform_level3_k( - const T* t, - const T* c, - T*& result, - T* workspace) +template +__global__ void transform_kernel_L3(int nfuncs, + const T* __restrict__ A, + const T* __restrict__ B, + T* __restrict__ C, + T* __restrict__ workspace) { - constexpr int ndim = 3; - constexpr int K2NDIM = K * K * K; + extern __shared__ unsigned char _shmem_l3[]; + T* b_shm = reinterpret_cast(_shmem_l3); - T *t0 = workspace, *t1 = result; - auto tmp = t0; t0 = t1; t1 = tmp; + // Stage B into LDS once per block. + constexpr int K2 = K * K; + int nthr = blockDim.x * blockDim.y; + int tid = blockDim.x * threadIdx.y + threadIdx.x; + for (int i = tid; i < K2; i += nthr) b_shm[i] = B[i]; + __syncthreads(); - /* B is already in LDS — mTxmq_level3_k loads it each call */ - mra::mTxmq_level3_k(t0, t, c); - for (int n = 1; n < ndim; ++n) { - mra::mTxmq_level3_k(t1, t0, c); - auto tmp2 = t0; t0 = t1; t1 = tmp2; - } -} + constexpr int K3 = K * K * K; + T* w = workspace + (size_t)blockIdx.x * K3; -/* One kernel binary per K — register pressure is proportional to K, not max(K). 
*/ -template -LAUNCH_BOUNDS(MAX_THREADS_PER_BLOCK, 1) -__global__ void transform_kernel_level3_k(int nfuncs, - const T* A, const T* B, T* C, T* workspace) { - constexpr int K2NDIM = K * K * K; - T* w = workspace + blockIdx.x * K2NDIM; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K2NDIM; - T* c = C + i * K2NDIM; - /* result pointer starts at c; workspace is w */ - T* result = c; - transform_level3_k(a, B, result, w); - } -} + for (int cube = blockIdx.x; cube < nfuncs; cube += gridDim.x) { + const T* a = A + (size_t)cube * K3; + T* c = C + (size_t)cube * K3; -template -inline int transform_level3_shmem_size(int K) { - return mra::mTxmq_level3_shmem_size(K); + // Pass 1 → c, Pass 2 → w, Pass 3 → c. + mra::mTxmq_L3(c, a, b_shm); // pass 1 → c + mra::mTxmq_L3(w, c, b_shm); // pass 2 → w + mra::mTxmq_L3(c, w, b_shm); // pass 3 → c + } } -template -inline void submit_transform_level3_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) +template +inline void submit_transform_bench_L3(int nfuncs, int nblocks, int K, + const T* A, const T* B, T* C, T* workspace, + Stream stream) { - Dim3 thread_dims = mra::mTxmq_level3_blockdim(K); - int smem_size = mra::mTxmq_level3_shmem_size(K); - -#define DISPATCH_L3(Kval) \ - case Kval: \ - CONFIGURE_KERNEL((transform_kernel_level3_k), smem_size); \ - CALL_KERNEL((transform_kernel_level3_k), std::min(nfuncs, nblocks), \ - thread_dims, smem_size, stream, \ - (nfuncs, A, B, C, workspace)); \ - break; + if (K == 16) { + constexpr int Kv = 16; + Dim3 td = mra::mTxmq_L3_blockdim(Kv); + size_type smem = mra::L3_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_L3), smem); + CALL_KERNEL((transform_kernel_L3), std::min(nfuncs, nblocks), td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else if (K == 20) { + constexpr int Kv = 20; + Dim3 td = mra::mTxmq_L3_blockdim(Kv); + size_type smem = mra::L3_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_L3), 
smem); + CALL_KERNEL((transform_kernel_L3), std::min(nfuncs, nblocks), td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else if (K == 32) { + constexpr int Kv = 32; + Dim3 td = mra::mTxmq_L3_blockdim(Kv); + size_type smem = mra::L3_shmem_size(Kv); + CONFIGURE_KERNEL((transform_kernel_L3), smem); + CALL_KERNEL((transform_kernel_L3), std::min(nfuncs, nblocks), td, smem, stream, + (nfuncs, A, B, C, workspace)); + } else { + // Fall back to L2 for other K values. + //submit_transform_bench_L2(nfuncs, nblocks, K, A, B, C, workspace, stream); + fprintf(stderr, "Unsupported K=%d for L3 kernel\n", K); + assert(false); - switch (K) { - DISPATCH_L3( 6) - DISPATCH_L3( 8) - DISPATCH_L3(10) - DISPATCH_L3(12) - DISPATCH_L3(16) - DISPATCH_L3(20) - DISPATCH_L3(32) - default: - printf("submit_transform_level3_bench: unsupported K=%d\n", K); - } -#undef DISPATCH_L3 + } } diff --git a/transform_level4.h b/transform_level4.h deleted file mode 100644 index d3e00f9..0000000 --- a/transform_level4.h +++ /dev/null @@ -1,82 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level4.h" - -/** - * Transform wrapper for Level-4 (MFMA on AMD GFX90A/GFX940, falls back to - * Level-3 on other targets or for K values without native MFMA support). - * - * Each K value gets its own kernel binary via template, - * isolating register pressure to the specific K being compiled. - * Block dimension is 64 threads (one wavefront) for MFMA path. - */ - -template -__device__ void transform_level4_k( - const T* t, - const T* c, - T*& result, - T* workspace) -{ - constexpr int ndim = 3; - - T *t0 = workspace, *t1 = result; - auto tmp = t0; t0 = t1; t1 = tmp; - - mra::mTxmq_level4_k(t0, t, c); - for (int n = 1; n < ndim; ++n) { - mra::mTxmq_level4_k(t1, t0, c); - auto tmp2 = t0; t0 = t1; t1 = tmp2; - } -} - -/* One kernel binary per K. 
*/ -template -LAUNCH_BOUNDS(64, 1) -__global__ void transform_kernel_level4_k(int nfuncs, - const T* A, const T* B, T* C, T* workspace) { - constexpr int K2NDIM = K * K * K; - T* w = workspace + blockIdx.x * K2NDIM; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K2NDIM; - T* c = C + i * K2NDIM; - T* result = c; - transform_level4_k(a, B, result, w); - } -} - -template -inline int transform_level4_shmem_size(int K) { - return mra::mTxmq_level4_shmem_size(K); -} - -template -inline void submit_transform_level4_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) -{ - Dim3 thread_dims = mra::mTxmq_level4_blockdim(K); - int smem_size = mra::mTxmq_level4_shmem_size(K); - -#define DISPATCH_L4(Kval) \ - case Kval: \ - CONFIGURE_KERNEL((transform_kernel_level4_k), smem_size); \ - CALL_KERNEL((transform_kernel_level4_k), std::min(nfuncs, nblocks), \ - thread_dims, smem_size, stream, \ - (nfuncs, A, B, C, workspace)); \ - break; - - switch (K) { - DISPATCH_L4( 6) - DISPATCH_L4( 8) - DISPATCH_L4(10) - DISPATCH_L4(12) - DISPATCH_L4(16) - DISPATCH_L4(20) - DISPATCH_L4(32) - default: - printf("submit_transform_level4_bench: unsupported K=%d\n", K); - } -#undef DISPATCH_L4 -} diff --git a/transform_level5.h b/transform_level5.h deleted file mode 100644 index 7955894..0000000 --- a/transform_level5.h +++ /dev/null @@ -1,81 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level5.h" - -/** - * Transform wrapper for Level-5 (wave-specialized double-buffering MFMA). - * - * Block = 256 threads (4 wavefronts, all doing MFMA). The 3-GEMM chain is - * preserved: each call to mTxmq_level5_k loads B once then calls the MFMA - * kernel which iterates over A chunks internally. __syncthreads() between - * steps ensures the full output of one GEMM is visible before the next starts. 
- */ - -template -__device__ void transform_level5_k( - const T* t, - const T* c, - T*& result, - T* workspace) -{ - constexpr int ndim = 3; - - T *t0 = workspace, *t1 = result; - auto tmp = t0; t0 = t1; t1 = tmp; - - mra::mTxmq_level5_k(t0, t, c); - for (int n = 1; n < ndim; ++n) { - mra::mTxmq_level5_k(t1, t0, c); - auto tmp2 = t0; t0 = t1; t1 = tmp2; - } -} - -/* One kernel binary per K. */ -template -LAUNCH_BOUNDS(256, 1) -__global__ void transform_kernel_level5_k(int nfuncs, - const T* A, const T* B, T* C, T* workspace) { - constexpr int K2NDIM = K * K * K; - T* w = workspace + blockIdx.x * K2NDIM; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K2NDIM; - T* c = C + i * K2NDIM; - T* result = c; - transform_level5_k(a, B, result, w); - } -} - -template -inline int transform_level5_shmem_size(int K) { - return mra::mTxmq_level5_shmem_size(K); -} - -template -inline void submit_transform_level5_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) -{ - Dim3 thread_dims = mra::mTxmq_level5_blockdim(K); - int smem_size = mra::mTxmq_level5_shmem_size(K); - -#define DISPATCH_L5(Kval) \ - case Kval: \ - CONFIGURE_KERNEL((transform_kernel_level5_k), smem_size); \ - CALL_KERNEL((transform_kernel_level5_k), std::min(nfuncs, nblocks), \ - thread_dims, smem_size, stream, \ - (nfuncs, A, B, C, workspace)); \ - break; - - switch (K) { - DISPATCH_L5( 8) - DISPATCH_L5(12) - DISPATCH_L5(16) - DISPATCH_L5(20) - DISPATCH_L5(32) - default: - printf("submit_transform_level5_bench: unsupported K=%d\n", K); - } -#undef DISPATCH_L5 -} - diff --git a/transform_level7.h b/transform_level7.h deleted file mode 100644 index acd2cab..0000000 --- a/transform_level7.h +++ /dev/null @@ -1,62 +0,0 @@ -#pragma once - -#include "util.h" -#include "mxm_level7.h" - -/** - * Transform wrapper for Level 7. 
- * - * Unlike levels 1-6, the three-GEMM chain is executed inside a single call to - * mTxmq_level7_k because B must remain resident in VGPRs across all three GEMMs. - * A single LDS buffer of K*(K²+1) doubles is reused across all three GEMMs; only - * the final output is written to global memory C. - * - * For K=16: LDS = 16*257*8 = 32,864 bytes (~32 KB), well within the 64 KB limit. - * occupancy=1 is retained to maximise VGPR headroom for B and A register arrays. - */ - -template -LAUNCH_BOUNDS(256, 1) -__global__ void transform_kernel_level7_k(int nfuncs, - const T* A, const T* B, T* C) -{ - constexpr int K3 = K * K * K; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K3; - T* c = C + i * K3; - mra::mTxmq_level7_k(c, a, B); - } -} - -template -inline int transform_level7_shmem_size(int K) { - return (int)mra::mTxmq_level7_shmem_size(K); -} - -template -inline void submit_transform_level7_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* /*workspace*/, - Stream stream) -{ - Dim3 thread_dims = mra::mTxmq_level7_blockdim(K); - int smem_size = transform_level7_shmem_size(K); - -#define DISPATCH_L7(Kval) \ - case Kval: \ - CONFIGURE_KERNEL((transform_kernel_level7_k), smem_size); \ - CALL_KERNEL((transform_kernel_level7_k), std::min(nfuncs, nblocks), \ - thread_dims, smem_size, stream, \ - (nfuncs, A, B, C)); \ - break; - - switch (K) { - DISPATCH_L7( 8) - DISPATCH_L7(12) - DISPATCH_L7(16) - DISPATCH_L7(20) - DISPATCH_L7(32) - default: - printf("submit_transform_level7_bench: unsupported K=%d\n", K); - } -#undef DISPATCH_L7 -} diff --git a/transform_rocwmma.h b/transform_rocwmma.h deleted file mode 100644 index 999e004..0000000 --- a/transform_rocwmma.h +++ /dev/null @@ -1,234 +0,0 @@ -#ifndef HAVE_TRANSFORM_ROCWMMA_H -#define HAVE_TRANSFORM_ROCWMMA_H - -#include "util.h" -#include "mxm.h" - -#if defined(__HIP_DEVICE_COMPILE__) -#include -#include - -template -__device__ void transform_klt16( - const T* a, - 
const T* b, - T*& c) -{ - /* hold everything in shared memory */ - extern __shared__ char smem[]; - T* shmem = reinterpret_cast(smem); - T* b_shmem = shmem; - T* a_shmem = b_shmem + K * K; - T* c_shmem = a_shmem + K * K * K; - - const size_type tid = thread_id(); - const size_type num_threads = block_size(); - - /* load A and B into shared memory */ - for (int idx = tid; idx < K * K; idx += num_threads) { - b_shmem[idx] = b[idx]; - } - for (int idx = tid; idx < K * K * K; idx += num_threads) { - a_shmem[idx] = a[idx]; - } - __syncthreads(); - - for (int d = 0; d < 3; ++d) { - /* compute c = a * b, with c also in shared memory */ - for (int i = tid/K; i < K * K; i += num_threads/K) { - T* ci = c_shmem + i * K; - int j = tid % K; - T sum = 0; - for (long k = 0; k < K; ++k) { /* not parallelized */ - sum += a_shmem[k * K * K + i] * b_shmem[k * K + j]; - } - if (d == 0) { - ci[j] = sum; - } else { - ci[j] += sum; - } - } - __syncthreads(); - - /* swap A and C for the next iteration, so we always read from A and write to C */ - std::swap(a_shmem, c_shmem); - } - - // write back result to global memory - for (int idx = tid; idx < K * K * K; idx += num_threads) { - c[idx] = a_shmem[idx]; // a_shmem is the final result after 3 iterations - } - -} - -/** - * This implementation only works on K=16. For other K values, we fall back to the Level-3 implementation. - * The fragment size is 16x16x16. - * The block dimension is 256 threads (one wavefront) to match the MFMA requirements. - * We load B into a fragment and keep it there. - * We load A into fragments. Each wave-front stores 4 input fragments and 4 output fragments. 
- * - */ -template -__device__ void transform_rocwmma_k( - const T* a, - const T* b, - T*& c, - T* workspace) -{ - constexpr uint32_t WM = 16, WN = 16, WK = 16; - constexpr uint32_t WAVE = 64; // CDNA wavefront size - constexpr const int ndim = 3; // fixed for benchmark - - using FragmentA = rocwmma::fragment; - using FragmentB = rocwmma::fragment; - using FragmentAcc = rocwmma::fragment; - - if constexpr (K < 16) { - // Fallback to non mma implementation - transform_klt16(a, b, c); - return; - } else if constexpr (K > 16) { - // Not supported, fallback to Level-3 - transform_level3_k(a, b, c, workspace); - return; - } else { - - /* single shared memory region, holds A and C */ - extern __shared__ char smem[]; - T* shmem = reinterpret_cast(smem); - - int wave_id = thread_id() / WAVE; - constexpr int num_waves = (MAX_THREADS_PER_BLOCK / WAVE); - constexpr int frags_per_wave = (K / num_waves); - - // load b into a fragment - FragmentB b_frag; - rocwmma::load_matrix_sync(b_frag, b, K); - - /* load A into shared memory */ - for (int idx = thread_id(); idx < K * K; idx += block_size()) { - shmem[idx] = a[idx]; - } - __syncthreads(); - - /* every wavefront handles 4 fragments */ - FragmentA a_frags[frags_per_wave]; - FragmentAcc acc_frags[frags_per_wave]; - - for (int d = 0; d < ndim; ++d) { - /* load all wavefront fragments */ - for (int i = 0; i < frags_per_wave; ++i) - { - /* load the current fragment */ - if (i < frags_per_wave - 1 || frags_per_wave == 1) { - rocwmma::load_matrix_sync(a_frags[i], shmem + (i + wave_id * frags_per_wave) * K, K*K); - // TODO: is it worth prefetching the next fragment? 
- //if constexpr (frags_per_wave > 1) { - // rocwmma::load_matrix_sync(a_frags[i+1], shmem + (i+1 + wave_id * frags_per_wave) * K, K*K); - //} - } - rocwmma::fill_fragment(acc_frags[i], static_cast(0)); - rocwmma::mma_sync(acc_frags[i], a_frags[i], b_frag, acc_frags[i]); - } - - /* write back all fragments */ - if (d == ndim - 1) { - /* last iteration, write back to global memory */ - for (int i = 0; i < frags_per_wave; ++i) - { - rocwmma::store_matrix_sync(c + (i + wave_id * frags_per_wave) * K * K, - acc_frags[i], K); - } - } else { - /* wait for all fragments to be loaded from shared memory */ - rocwmma::synchronize_workgroup(); - /* write back to shared memory */ - for (int i = 0; i < frags_per_wave; ++i) - { - rocwmma::store_matrix_sync(shmem + (i + wave_id * frags_per_wave) * K * K, - acc_frags[i], K); - } - } - - rocwmma::synchronize_workgroup(); - } - } -} - -#endif // not __HIP_DEVICE_COMPILE__ - -// fwd-decl for kernel -template -__device__ void transform_rocwmma_k( - const T* a, - const T* b, - T*& c, - T* workspace); - -/* One kernel binary per K — register pressure is proportional to K, not max(K). */ -template -LAUNCH_BOUNDS(MAX_THREADS_PER_BLOCK, 1) -__global__ void transform_rocwmma(int nfuncs, - const T* A, const T* B, T* C, T* workspace) { - constexpr int K2NDIM = K * K * K; - T* w = workspace + blockIdx.x * K2NDIM; - for (int i = blockIdx.x; i < nfuncs; i += gridDim.x) { - const T* a = A + i * K2NDIM; - T* c = C + i * K2NDIM; - /* result pointer starts at c; workspace is w */ - T* result = c; - transform_rocwmma_k(a, B, result, w); - } -} - -template -inline int transform_rocwmma_shmem_size(int K) { - if (K <= 16) { - // For K<=16, we load A and B into shared memory. We need space for A (K^3), B (K^2), and C (K^3). 
- return (K*K*K + K*K); - } else if (K == 16) { - // For K==16, we hold one copy of A/C in LDS - return K*K*K; - } else { - return transform_level3_shmem_size(K); - } -} - -template -inline Dim3 transform_rocwmma_blockdim(int K) { - return {256, 1, 1}; -} - -template -inline void submit_transform_rocwmma_bench(int nfuncs, int nblocks, int K, - const T* A, const T* B, T* C, T* workspace, - Stream stream) -{ - Dim3 thread_dims = transform_rocwmma_blockdim(K); - int smem_size = transform_rocwmma_shmem_size(K); - -#define DISPATCH_L3(Kval) \ - case Kval: \ - CONFIGURE_KERNEL((transform_rocwmma), smem_size); \ - CALL_KERNEL((transform_rocwmma), std::min(nfuncs, nblocks), \ - thread_dims, smem_size, stream, \ - (nfuncs, A, B, C, workspace)); \ - break; - - switch (K) { - DISPATCH_L3( 6) - DISPATCH_L3( 8) - DISPATCH_L3(10) - DISPATCH_L3(12) - DISPATCH_L3(16) - DISPATCH_L3(20) - DISPATCH_L3(32) - default: - printf("submit_transform_level3_bench: unsupported K=%d\n", K); - } -#undef DISPATCH_L3 -} - - -#endif // HAVE_TRANSFORM_ROCWMMA_H \ No newline at end of file diff --git a/transformbench.cu b/transformbench.cu index b81b3cf..ecb7533 100644 --- a/transformbench.cu +++ b/transformbench.cu @@ -1,164 +1,116 @@ - -#include #include +#include +#include +#include #include "transform.h" -#include "transform_cublasdx.h" #include "transform_level2.h" #include "transform_level3.h" -#include "transform_level4.h" -#include "transform_kron.h" -#include "mxm_cublasdx.h" +#include "transform_blocked.h" +#include "transform_blocked_rocwmma.h" #include "util.h" -/** - * Optimization levels: - * 1 - L1: thread-parallel over j, serial k-loop, all global memory (mxm.h fallback) - * 2 - L2: B in LDS, threads distributed over rows - * 3 - L3: B in LDS + register accumulation (acc[K] in VGPRs) - * 4 - L4: AMD MFMA (GFX90A/GFX940) for K=16,32; falls back to L3 elsewhere - * 5 - L5: cuBLASDx (NVIDIA only, double-buffered block GEMM with Tensor Cores) - * 6 - L6: Single GEMM via K³×K³ Kronecker 
product (B^T ⊗ B^T ⊗ B^T) - */ - -template -void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, int level, int num_streams) { - - std::vector streams(num_streams); // PaRSEC uses 4 streams by default - T* A, *B, *C, *workspace; - MALLOC(&A, nfuncs * K * K * K * sizeof(T)); // N x KxKxK tensors - MALLOC(&B, K * K * sizeof(T)); // KxK matrix - MALLOC(&C, nfuncs * K * K * K * sizeof(T)); // N x KxKxK tensors - MALLOC(&workspace, nblocks * K * K * K * sizeof(T)); // per-block scratch +template +void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, + int num_streams, int level) { + // Auto-select level if not specified. + if (level < 1) { + // crash out - for (int i = 0; i < num_streams; ++i) { - CREATE_STREAM(&streams[i]); - } - - /* Warn early if a level is unavailable */ - if (level == 5 && !MRA_HAVE_CUBLASDX) { - std::cerr << "Warning: level 5 (cuBLASDx) requested but not available; " - "falling back to level 1\n"; - level = 1; + fprintf(stderr, "Invalid level %d. Must be 1, 2, or 3.\n", level); + assert(false); + return; } - /* Resolve default level */ - if (level <= 0) { - level = (MRA_HAVE_CUBLASDX) ? 5 : 3; - } + // Level metadata for output line. 
+ Dim3 thread_dims; + int smem_size; + std::string level_name; - const char* level_names[] = { - "", /* unused [0] */ - "L1-global", /* 1 */ - "L2-lds_b", /* 2 */ - "L3-regblk", /* 3 */ - "L4-mfma", /* 4 */ - "L5-cublasdx",/* 5 */ - "L6-kron" /* 6 */ - }; - - /* Print shmem and thread dims for this level */ - int smem_size = 0; - Dim3 thread_dims = {1, 1, 1}; switch (level) { - case 1: - smem_size = mra::mTxmq_shmem_size(K); - thread_dims = mra::mTxmq_blockdim(K); - break; - case 2: - smem_size = transform_level2_shmem_size(K); - thread_dims = mra::mTxmq_level2_blockdim(K); - break; - case 3: - smem_size = transform_level3_shmem_size(K); - thread_dims = mra::mTxmq_level3_blockdim(K); - break; - case 4: - smem_size = transform_level4_shmem_size(K); - thread_dims = mra::mTxmq_level4_blockdim(K); - break; - case 5: - smem_size = transform_cublasdx_shmem_size(K); - thread_dims = mra::mTxmq_blockdim(K); - break; - case 6: - smem_size = kron_shmem_size(K); - thread_dims = kron_blockdim(K); - break; + case 2: + thread_dims = mra::mTxmq_L2_blockdim(K); + smem_size = (int)mra::mTxmq_L2_shmem_size(K); + level_name = "L2-lds"; + break; + case 3: + thread_dims = mra::mTxmq_L3_blockdim(K); + smem_size = (int)mra::L3_shmem_size(K); + level_name = "L3-regblk"; + break; + case 7: + thread_dims = blocked_blockdim(K); + smem_size = (int)blocked_shmem_size(K); + level_name = "L7-blocked"; + break; + case 8: + thread_dims = blocked_rocwmma_blockdim(K); + smem_size = (int)blocked_rocwmma_shmem_size(K); + level_name = "L8-blocked-rocwmma"; + break; + default: // level == 1 + thread_dims = mra::mTxmq_blockdim(K); + smem_size = 0; + level_name = "L1-global"; + break; } - /* Level 6: build Kronecker matrix once, before the timing loop */ - T* KronMat = nullptr; - blasHandle_t blas_handle{}; - if (level == 6) { - const int K3 = K * K * K; - const size_t kron_bytes = (size_t)K3 * K3 * sizeof(T); - std::cout << "L6-kron: allocating " << kron_bytes / (1024*1024.0) - << " MB for " << K3 << "x" 
<< K3 << " Kronecker matrix\n"; - MALLOC(&KronMat, kron_bytes); - blasCreate(&blas_handle); - build_kron_matrix(K, B, KronMat, streams[0]); - SYNC_STREAM(streams[0]); - } + std::vector streams(num_streams); + T *A; + T *B; + T *C; + T *workspace; + MALLOC(&A, (size_t)nfuncs * K * K * K * sizeof(T)); + MALLOC(&B, (size_t)K * K * sizeof(T)); + MALLOC(&C, (size_t)nfuncs * K * K * K * sizeof(T)); + MALLOC(&workspace, (size_t)nblocks * K * K * K * sizeof(T)); + + for (int i = 0; i < num_streams; ++i) + CREATE_STREAM(&streams[i]); std::chrono::time_point beg, end; - for (int i = 0; i < nreps+1; ++i) { + for (int i = 0; i < nreps + 1; ++i) { beg = std::chrono::high_resolution_clock::now(); for (int t = 0; t < ntasks; ++t) { + Stream s = streams[t % num_streams]; switch (level) { - case 1: - submit_transform_bench(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); - break; - case 2: - submit_transform_level2_bench(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); - break; - case 3: - submit_transform_level3_bench(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); - break; - case 4: - submit_transform_level4_bench(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); - break; - case 5: - submit_transform_cublasdx_bench(nfuncs, nblocks, K, A, B, C, workspace, streams[t%num_streams]); - break; - case 6: - submit_transform_kron_bench(nfuncs, K, A, KronMat, C, blas_handle, streams[t%num_streams]); - break; + case 2: + submit_transform_bench_L2(nfuncs, nblocks, K, A, B, C, workspace, s); + break; + case 3: + submit_transform_bench_L3(nfuncs, nblocks, K, A, B, C, workspace, s); + break; + case 7: + submit_transform_bench_blocked(nfuncs, nblocks, K, A, B, C, workspace, s); + break; + case 8: + submit_transform_bench_blocked_rocwmma(nfuncs, nblocks, K, A, B, C, workspace, s); + break; + default: + submit_transform_bench(nfuncs, nblocks, K, A, B, C, workspace, s); + break; } } - for (int t = 0; t < num_streams; ++t) { + for (int 
t = 0; t < num_streams; ++t) SYNC_STREAM(streams[t]); - } end = std::chrono::high_resolution_clock::now(); - /* skip warm-up */ if (i > 0) { - auto us = (std::chrono::duration_cast(end - beg).count()); - /* L6 does one K³×K³ GEMM per task (2·K⁶ FLOPs); others do 3 passes (2·3·K⁴ FLOPs) */ - uint64_t flops = (level == 6) - ? (uint64_t)ntasks * 2 * (uint64_t)K*K*K * (uint64_t)K*K*K * nfuncs - : (uint64_t)ntasks * K * K * K * K * 3 * 2 * nfuncs; + auto us = std::chrono::duration_cast(end - beg) + .count(); + uint64_t flops = (uint64_t)ntasks * K * K * K * K * 3 * 2 * nfuncs; std::cout << "Transform" - << ";level=" << level_names[level] - << ";nfuncs=" << nfuncs - << ";nblocks=" << nblocks - << ";K=" << K - << ";tasks=" << ntasks - << ";threads={" << thread_dims.x << "," << thread_dims.y << "," << thread_dims.z << "}" - << ";smem=" << smem_size - << ";Time(us)=" << us - << ";GFlop=" << flops*1e-9 - << ";Gflop/s=" << (1e-3 * flops) / us - << std::endl; + << ";level=" << level_name << ";nfuncs=" << nfuncs + << ";nblocks=" << nblocks << ";K=" << K << ";tasks=" << ntasks + << ";threads={" << thread_dims.x << "," << thread_dims.y << "," + << thread_dims.z << "}" + << ";smem=" << smem_size << ";Time(us)=" << us + << ";GFlop=" << flops * 1e-9 + << ";Gflop/s=" << (1e-3 * flops) / us << std::endl; } } - if (level == 6) { - blasDestroy(blas_handle); - FREE(KronMat); - } - FREE(A); FREE(B); FREE(C); @@ -166,28 +118,19 @@ void transform_bench(int nreps, int ntasks, int nfuncs, int nblocks, int K, int } int main(int argc, char **argv) { - auto opt = OptionParser(argc, argv); - int nreps = opt.parse("-r", 5); - int ntasks = opt.parse("-n", 500); - int N = opt.parse("-N", 2048); /* number of functions */ - int K = opt.parse("-K", 16); /* number of coefficients */ - int M = opt.parse("-M", 512); /* max number of blocks */ - int level = opt.parse("-l", 0); /* 0 = auto, 1-5 = explicit */ - int num_streams = opt.parse("-s", 4); /* number of concurrent streams to use */ - - /* Legacy -m 
flag: force level 1 */ - if (opt.exists("-m")) level = 1; + int nreps = opt.parse("-r", 5); + int ntasks = opt.parse("-n", 500); // Number + int N = opt.parse("-N", 2048); // Number of functions + int K = opt.parse("-K", 16); + int M = opt.parse("-M", 512); + int num_streams = opt.parse("-s", 4); + int level = opt.parse("-l", -1); std::cout << "Running benchmark" - << " nreps=" << nreps - << " ntasks=" << ntasks - << " N=" << N - << " K=" << K - << " M=" << M - << " level=" << (level <= 0 ? (MRA_HAVE_CUBLASDX ? 5 : 3) : level) - << std::endl; + << " nreps=" << nreps << " ntasks=" << ntasks << " N=" << N + << " K=" << K << " M=" << M << " level=" << level << std::endl; - transform_bench(nreps, ntasks, N, M, K, level, num_streams); + transform_bench(nreps, ntasks, N, M, K, num_streams, level); } diff --git a/validate.hip b/validate.hip new file mode 100644 index 0000000..c17cb39 --- /dev/null +++ b/validate.hip @@ -0,0 +1,263 @@ +/** + * validate.hip — compare GPU L1 output against the CPU reference transform. 
+ * + * CPU reference mirrors madness/src/madness/tensor/transform3d.cc :: transform3d(): + * + * result(a,b,c) = sum(i',j',k') A(i',j',k') C(i',a) C(j',b) C(k',c) + * + * implemented as three sequential mTxm passes: + * mTxm(K², K, K, r, A, B) pass 1: contract dim-0 + * mTxm(K², K, K, tmp, r, B) pass 2: contract dim-1 + * mTxm(K², K, K, r, tmp, B) pass 3: contract dim-2 + * + * where mTxm(dimi, dimj, dimk, c, a, b): + * c(i,j) += sum(k) a(k,i) * b(k,j) + */ + +#include +#include +#include +#include +#include + +#include "transform.h" +#include "util.h" + +// --------------------------------------------------------------------------- +// CPU reference +// --------------------------------------------------------------------------- + +// c(i,j) += sum(k) a(k,i) * b(k,j) (from MADNESS mTxm) +static void cpu_mTxm(int dimi, int dimj, int dimk, + double* c, const double* a, const double* b) { + for (int i = 0; i < dimi; ++i) { + double* ci = c + i * dimj; + const double* aik = a + i; // a(k,i) strides by dimi + for (int k = 0; k < dimk; ++k, aik += dimi) { + double aki = *aik; + for (int j = 0; j < dimj; ++j) + ci[j] += aki * b[k * dimj + j]; + } + } +} + +// Three-pass transform matching transform3d.cc +static void cpu_transform3d(int K, const double* A, const double* B, double* result) { + int K3 = K * K * K; + std::vector tmp(K3, 0.0); + std::fill(result, result + K3, 0.0); + + cpu_mTxm(K * K, K, K, result, A, B); // pass 1 + cpu_mTxm(K * K, K, K, tmp.data(), result, B); // pass 2 + std::fill(result, result + K3, 0.0); + cpu_mTxm(K * K, K, K, result, tmp.data(), B); // pass 3 +} + +// --------------------------------------------------------------------------- +// Block-distributed transform (corner-turn variant). +// +// Block A along the fastest tensor index k': block s owns A[:,:,k'=s] as a +// K x K matrix indexed by (i', j'). +// +// Pass 1 (contract i'): local. blk_s <- B^T . blk_s -> [a, j'] +// Pass 2 (contract j'): local. blk_s <- blk_s . 
B -> [a, b] +// Pass 3 (contract k'): all-to-all "corner turn" where block t pulls row +// a=t from every other block, storing into the +// (b, k') layout. Then local right-multiply: +// blk_t <- blk_t . B -> [b, c] +// +// Applying B on the right in pass 3 makes the output land in canonical (b, c) +// order directly; the in-block transpose is absorbed into the corner turn for +// free (same number of loads/stores, different destination cell). +// --------------------------------------------------------------------------- + +static void print_blocks(const char* label, int K, + const std::vector>& blk) { + std::cout << " " << label << ":\n"; + for (int s = 0; s < (int)blk.size(); ++s) { + std::cout << " block " << s << ":\n"; + for (int i = 0; i < K; ++i) { + std::cout << " "; + for (int j = 0; j < K; ++j) + std::cout << " " << std::setw(8) << blk[s][i*K + j]; + std::cout << "\n"; + } + } +} + +static void cpu_transform3d_blocked(int K, const double* A, const double* B, + double* out, bool trace = false) { + const int K2 = K * K; + std::vector> blk(K, std::vector(K2)); + + // distribute: block s := A[:,:,k=s] (strided read from the flat layout) + for (int s = 0; s < K; ++s) + for (int i = 0; i < K; ++i) + for (int j = 0; j < K; ++j) + blk[s][i*K + j] = A[i*K2 + j*K + s]; + if (trace) print_blocks("after distribute (block_s[i,j] = A[i,j,s])", K, blk); + + // Pass 1: contract i'. blk_s <- B^T . blk_s -> indexed [a, j] + for (int s = 0; s < K; ++s) { + std::vector t(K2, 0.0); + for (int a = 0; a < K; ++a) + for (int j = 0; j < K; ++j) + for (int i = 0; i < K; ++i) + t[a*K + j] += B[i*K + a] * blk[s][i*K + j]; + blk[s] = std::move(t); + } + if (trace) print_blocks("after pass 1 (block_s[a,j])", K, blk); + + // Pass 2: contract j'. blk_s <- blk_s . 
B -> indexed [a, b] + for (int s = 0; s < K; ++s) { + std::vector t(K2, 0.0); + for (int a = 0; a < K; ++a) + for (int b = 0; b < K; ++b) + for (int j = 0; j < K; ++j) + t[a*K + b] += blk[s][a*K + j] * B[j*K + b]; + blk[s] = std::move(t); + } + if (trace) print_blocks("after pass 2 (block_s[a,b], fixed k'=s)", K, blk); + + // Corner turn: block t pulls row a=t from every block, stored as (b, k'). + // Before: blk[s][a, b] (block-index = k' = s) + // After : blk[t][b, k] (block-index = a = t) + { + std::vector> exch(K, std::vector(K2)); + for (int t = 0; t < K; ++t) + for (int b = 0; b < K; ++b) + for (int k = 0; k < K; ++k) + exch[t][b*K + k] = blk[k][t*K + b]; + blk = std::move(exch); + } + if (trace) print_blocks("after corner turn (block_t[b,k], fixed a=t)", K, blk); + + // Pass 3: contract k'. blk_t <- blk_t . B -> indexed [b, c] + for (int t = 0; t < K; ++t) { + std::vector r(K2, 0.0); + for (int b = 0; b < K; ++b) + for (int c = 0; c < K; ++c) + for (int k = 0; k < K; ++k) + r[b*K + c] += blk[t][b*K + k] * B[k*K + c]; + blk[t] = std::move(r); + } + if (trace) print_blocks("after pass 3 (block_t[b,c] = result[t,b,c])", K, blk); + + // Canonical write-out: block_a holds result[a, :, :] in (b, c) order. + for (int a = 0; a < K; ++a) + for (int b = 0; b < K; ++b) + for (int c = 0; c < K; ++c) + out[a*K2 + b*K + c] = blk[a][b*K + c]; +} + +// K=2 walkthrough with the counting tensor A[i,j,k] = 100i + 10j + k and B = I. +// With B = I the GEMMs become no-ops, so every stage just shows the current +// distribution of elements across blocks -- the shuffle pattern is laid bare. +// Expected final output == A (since B = I acts as identity in all three passes). 
+static void debug_blocked_K2() { + const int K = 2, K2 = 4, K3 = 8; + + std::vector A(K3); + for (int i = 0; i < K; ++i) + for (int j = 0; j < K; ++j) + for (int k = 0; k < K; ++k) + A[i*K2 + j*K + k] = 100.0*i + 10.0*j + k; + + std::vector B(K2, 0.0); + for (int i = 0; i < K; ++i) B[i*K + i] = 1.0; + + std::cout << "=== debug_blocked_K2 (A[i,j,k] = 100i+10j+k, B = I) ===\n"; + std::cout << " B = I makes every GEMM a no-op; the only motion is the shuffle.\n"; + std::cout << " expected final output == A.\n"; + + std::vector out(K3, 0.0); + cpu_transform3d_blocked(K, A.data(), B.data(), out.data(), /*trace=*/true); + + double max_err = 0.0; + for (int i = 0; i < K3; ++i) + max_err = std::max(max_err, std::abs(out[i] - A[i])); + std::cout << " final max_abs_err vs A = " << max_err + << (max_err < 1e-10 ? " MATCH" : " DIFFER") << "\n\n"; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- + +int main(int argc, char** argv) { + auto opt = OptionParser(argc, argv); + int K = opt.parse("-K", 16); + int N = opt.parse("-N", 8); + int debug = opt.parse("-debug", 0); + + if (debug) { + debug_blocked_K2(); + return 0; + } + + const int K3 = K * K * K; + + // Deterministic input + std::vector h_A(N * K3); + std::vector h_B(K * K); + for (int i = 0; i < N * K3; ++i) h_A[i] = (double)(i % 13) * 0.1 - 0.6; + for (int i = 0; i < K * K; ++i) h_B[i] = (double)(i % 7) * 0.1 - 0.3; + + // CPU reference (3D tensor transform) + std::vector h_ref(N * K3); + for (int n = 0; n < N; ++n) + cpu_transform3d(K, h_A.data() + n * K3, h_B.data(), h_ref.data() + n * K3); + + // Blocked algorithm: per-tensor distribute along k', corner-turn for pass 3 + std::vector h_blk(N * K3, 0.0); + for (int n = 0; n < N; ++n) + cpu_transform3d_blocked(K, h_A.data() + n * K3, h_B.data(), h_blk.data() + n * K3); + + double blk_max_err = 0.0; + for (int i = 0; i < N * K3; ++i) + blk_max_err = 
std::max(blk_max_err, std::abs(h_blk[i] - h_ref[i])); + std::cout << "K=" << K << " N=" << N + << " blk_vs_ref max_abs_err=" << blk_max_err + << (blk_max_err < 1e-10 ? " MATCH" : " DIFFER") + << std::endl; + + // GPU L1 + double *d_A, *d_B, *d_C, *d_workspace; + MALLOC(&d_A, N * K3 * sizeof(double)); + MALLOC(&d_B, K * K * sizeof(double)); + MALLOC(&d_C, N * K3 * sizeof(double)); + MALLOC(&d_workspace, K3 * sizeof(double)); // 1 block + + MEMCPY_H2D(d_A, h_A.data(), N * K3 * sizeof(double)); + MEMCPY_H2D(d_B, h_B.data(), K * K * sizeof(double)); + + Stream stream; + CREATE_STREAM(&stream); + submit_transform_bench(N, /*nblocks=*/1, K, d_A, d_B, d_C, d_workspace, stream); + SYNC_STREAM(stream); + + std::vector h_gpu(N * K3); + MEMCPY_D2H(h_gpu.data(), d_C, N * K3 * sizeof(double)); + + FREE(d_A); FREE(d_B); FREE(d_C); FREE(d_workspace); + + // Compare + double max_abs_err = 0.0, max_ref_val = 0.0; + for (int i = 0; i < N * K3; ++i) { + double err = std::abs(h_gpu[i] - h_ref[i]); + double ref = std::abs(h_ref[i]); + if (err > max_abs_err) max_abs_err = err; + if (ref > max_ref_val) max_ref_val = ref; + } + double rel_err = (max_ref_val > 0.0) ? max_abs_err / max_ref_val : max_abs_err; + + bool pass = rel_err < 1e-10; + std::cout << "K=" << K + << " N=" << N + << " max_abs_err=" << max_abs_err + << " max_rel_err=" << rel_err + << (pass ? " PASS" : " FAIL") + << std::endl; + + return pass ? 0 : 1; +} diff --git a/validate_levels.hip b/validate_levels.hip index f7177e0..3989ae9 100644 --- a/validate_levels.hip +++ b/validate_levels.hip @@ -1,170 +1,140 @@ /** - * Correctness test: compare any optimization level against the level 1 reference. + * validate_levels.hip — correctness test for all optimization levels. * - * Usage: - * ./validate_levels [-l ] [-K ] [-N ] + * Compares GPU output from the selected level against the CPU reference transform: + * result(a,b,c) = sum(i,j,k) A(i,j,k) * B(i,a) * B(j,b) * B(k,c) + * implemented via three sequential mTxm passes. 
* - * -l level to validate (2-7, default 3) - * 2 L2: B cached in LDS - * 3 L3: register blocking (K-templated) - * 4 L4: AMD MFMA + L3 fallback - * 5 L5: wave-specialised double-buffering MFMA (HIP only) - * 6 L6: Kronecker product GEMM (hipBLAS) - * 7 L7: multi-wave MFMA, B resident in VGPRs (HIP only) - * -K single K value; if omitted sweeps K in {4,6,8,10} - * -N batch size (default 16) + * Usage: + * ./validate_levels -l -K -N + * level: 1=L1-global, 2=L2-lds, 3=L3-regblk, 7=blocked (corner-turn) (default: 3) + * K: transform order (default: 16) + * N: number of cubes (default: 8) */ -#include #include +#include #include #include +#include "transform.h" +#include "transform_level2.h" +#include "transform_level3.h" +#include "transform_blocked.h" +#include "transform_blocked_rocwmma.h" #include "util.h" -#include "transform.h" // L1 — reference -#include "transform_level2.h" // L2 -#include "transform_level3.h" // L3 -#include "transform_level4.h" // L4 -#ifdef MRA_HAVE_HIP -# include "transform_level5.h" // L5 — wave-specialised MFMA (HIP only) -# include "transform_level7.h" // L7 — multi-wave MFMA, B in VGPRs (HIP only) -#else -# include "transform_cublasdx.h" // L5 — cuBLASDx (NVIDIA only) -#endif -#include "transform_kron.h" // L6 - -template -void test_level(int level, int K, int nfuncs) { - const int K3 = K * K * K; - const int nblocks = nfuncs; - - // Allocate and fill host arrays with random data - std::vector h_A(nfuncs * K3), h_B(K * K); - std::vector h_Cref(nfuncs * K3), h_Ctest(nfuncs * K3); - std::srand(42); - for (auto& v : h_A) v = (T)std::rand() / RAND_MAX; - for (auto& v : h_B) v = (T)std::rand() / RAND_MAX; - - // Device allocations - T *d_A, *d_B, *d_Cref, *d_Ctest, *d_workspace_ref, *d_workspace; - MALLOC(&d_A, nfuncs * K3 * sizeof(T)); - MALLOC(&d_B, K * K * sizeof(T)); - MALLOC(&d_Cref, nfuncs * K3 * sizeof(T)); - MALLOC(&d_Ctest, nfuncs * K3 * sizeof(T)); - MALLOC(&d_workspace_ref, nfuncs * K3 * sizeof(T)); - MALLOC(&d_workspace, nfuncs 
* K3 * sizeof(T)); - - T *d_KronMat = nullptr; - if (level == 6) { - MALLOC(&d_KronMat, (size_t)K3 * K3 * sizeof(T)); + +// --------------------------------------------------------------------------- +// CPU reference (matches madness/src/madness/tensor/transform3d.cc) +// c(i,j) += sum(k) a(k,i) * b(k,j) +// --------------------------------------------------------------------------- +static void cpu_mTxm(int dimi, int dimj, int dimk, + double* c, const double* a, const double* b) +{ + for (int i = 0; i < dimi; ++i) { + double* ci = c + i * dimj; + const double* aik = a + i; // a(k,i) strides by dimi + for (int k = 0; k < dimk; ++k, aik += dimi) { + double aki = *aik; + for (int j = 0; j < dimj; ++j) + ci[j] += aki * b[k * dimj + j]; + } } +} + +static void cpu_transform3d(int K, const double* A, const double* B, double* result) +{ + int K3 = K * K * K; + std::vector tmp(K3, 0.0); + std::fill(result, result + K3, 0.0); - // Copy inputs to device - MEMCPY_H2D(d_A, h_A.data(), nfuncs * K3 * sizeof(T)); - MEMCPY_H2D(d_B, h_B.data(), K * K * sizeof(T)); + cpu_mTxm(K * K, K, K, result, A, B); // pass 1 + cpu_mTxm(K * K, K, K, tmp.data(), result, B); // pass 2 + std::fill(result, result + K3, 0.0); + cpu_mTxm(K * K, K, K, result, tmp.data(), B); // pass 3 +} + +// --------------------------------------------------------------------------- +int main(int argc, char** argv) +{ + auto opt = OptionParser(argc, argv); + int K = opt.parse("-K", 16); + int N = opt.parse("-N", 8); + int level = opt.parse("-l", 3); + + const int K3 = K * K * K; + + // Deterministic inputs (same as validate.hip) + std::vector h_A(N * K3); + std::vector h_B(K * K); + for (int i = 0; i < N * K3; ++i) h_A[i] = (double)(i % 13) * 0.1 - 0.6; + for (int i = 0; i < K * K; ++i) h_B[i] = (double)(i % 7) * 0.1 - 0.3; + + // CPU reference + std::vector h_ref(N * K3); + for (int n = 0; n < N; ++n) + cpu_transform3d(K, h_A.data() + n * K3, h_B.data(), h_ref.data() + n * K3); + + // GPU buffers + double *d_A, 
*d_B, *d_C, *d_workspace; + MALLOC(&d_A, (size_t)N * K3 * sizeof(double)); + MALLOC(&d_B, (size_t)K * K * sizeof(double)); + MALLOC(&d_C, (size_t)N * K3 * sizeof(double)); + MALLOC(&d_workspace, (size_t)K3 * sizeof(double)); // 1 block + + MEMCPY_H2D(d_A, h_A.data(), (size_t)N * K3 * sizeof(double)); + MEMCPY_H2D(d_B, h_B.data(), (size_t)K * K * sizeof(double)); Stream stream; CREATE_STREAM(&stream); - // --- Reference: level 1 --- - submit_transform_bench(nfuncs, nblocks, K, d_A, d_B, d_Cref, d_workspace_ref, stream); - SYNC_STREAM(stream); - - // --- Tested level --- + // Dispatch to selected level + const int nblocks = 1; switch (level) { + case 1: + submit_transform_bench(N, nblocks, K, d_A, d_B, d_C, d_workspace, stream); + break; case 2: - submit_transform_level2_bench(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); + submit_transform_bench_L2(N, nblocks, K, d_A, d_B, d_C, d_workspace, stream); break; case 3: - submit_transform_level3_bench(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); - break; - case 4: - submit_transform_level4_bench(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); + submit_transform_bench_L3(N, nblocks, K, d_A, d_B, d_C, d_workspace, stream); break; -#ifdef MRA_HAVE_HIP - case 5: - submit_transform_level5_bench(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); - break; -#else - case 5: - submit_transform_cublasdx_bench(nfuncs, nblocks, K, d_A, d_B, d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); - break; -#endif - case 6: { - build_kron_matrix(K, d_B, d_KronMat, stream); - SYNC_STREAM(stream); - blasHandle_t blas_handle; - blasCreate(&blas_handle); - submit_transform_kron_bench(nfuncs, K, d_A, d_KronMat, d_Ctest, blas_handle, stream); - SYNC_STREAM(stream); - blasDestroy(blas_handle); - break; - } -#ifdef MRA_HAVE_HIP case 7: - submit_transform_level7_bench(nfuncs, nblocks, K, d_A, d_B, 
d_Ctest, d_workspace, stream); - SYNC_STREAM(stream); + submit_transform_bench_blocked(N, nblocks, K, d_A, d_B, d_C, d_workspace, stream); + break; + case 8: + submit_transform_bench_blocked_rocwmma(N, nblocks, K, d_A, d_B, d_C, d_workspace, stream); break; -#endif default: - std::cerr << "Unknown level " << level << " (valid: 2-7)\n"; - FREE(d_A); FREE(d_B); FREE(d_Cref); FREE(d_Ctest); - FREE(d_workspace_ref); FREE(d_workspace); - if (d_KronMat) FREE(d_KronMat); -#ifdef MRA_HAVE_HIP - (void)hipStreamDestroy(stream); -#else - (void)cudaStreamDestroy(stream); -#endif - return; + std::cerr << "Unknown level " << level << " (supported: 1, 2, 3, 7, 8)\n"; + return 1; } + SYNC_STREAM(stream); - // Copy results to host - MEMCPY_D2H(h_Cref.data(), d_Cref, nfuncs * K3 * sizeof(T)); - MEMCPY_D2H(h_Ctest.data(), d_Ctest, nfuncs * K3 * sizeof(T)); - - // Compare - T max_abs_err = 0, max_rel_err = 0; - for (int i = 0; i < nfuncs * K3; ++i) { - T abs_err = std::abs(h_Cref[i] - h_Ctest[i]); - T rel_err = abs_err / (std::abs(h_Cref[i]) + 1e-14); - max_abs_err = std::max(max_abs_err, abs_err); - max_rel_err = std::max(max_rel_err, rel_err); - } + std::vector h_gpu(N * K3); + MEMCPY_D2H(h_gpu.data(), d_C, (size_t)N * K3 * sizeof(double)); - std::cout << "K=" << K << " nfuncs=" << nfuncs << " level=" << level - << " max_abs_err=" << max_abs_err - << " max_rel_err=" << max_rel_err - << (max_rel_err < 1e-10 ? 
" PASS" : " FAIL") - << "\n"; - - FREE(d_A); FREE(d_B); FREE(d_Cref); FREE(d_Ctest); - FREE(d_workspace_ref); FREE(d_workspace); - if (d_KronMat) FREE(d_KronMat); -#ifdef MRA_HAVE_HIP - (void)hipStreamDestroy(stream); -#else - (void)cudaStreamDestroy(stream); -#endif -} + FREE(d_A); FREE(d_B); FREE(d_C); FREE(d_workspace); -int main(int argc, char** argv) { - OptionParser opts(argc, argv); - int level = opts.parse(std::string("-l"), 3); - int nfuncs = opts.parse(std::string("-N"), 16); - - if (opts.exists(std::string("-K"))) { - int K = opts.parse(std::string("-K"), 8); - test_level(level, K, nfuncs); - } else { - for (int K : {4, 6, 8, 10}) { - test_level(level, K, nfuncs); - } + // Compare + double max_abs_err = 0.0, max_ref_val = 0.0; + for (int i = 0; i < N * K3; ++i) { + double err = std::abs(h_gpu[i] - h_ref[i]); + double ref = std::abs(h_ref[i]); + if (err > max_abs_err) max_abs_err = err; + if (ref > max_ref_val) max_ref_val = ref; } - return 0; + double rel_err = (max_ref_val > 0.0) ? max_abs_err / max_ref_val : max_abs_err; + bool pass = (rel_err < 1e-10); + + std::cout << "level=L" << level + << " K=" << K + << " N=" << N + << " max_abs_err=" << max_abs_err + << " max_rel_err=" << rel_err + << (pass ? " PASS" : " FAIL") + << std::endl; + + return pass ? 0 : 1; }