[Experiment] ROCm backend #2300
|
What an unexpected and amazing surprise! I'm absolutely thrilled. |
|
@awni |
|
I think this is good to stay as an experiment branch for some time while we work on core and CUDA. I don't think we have the bandwidth to merge this for a few months at least. Sorry if this is disappointing @NripeshN I don't mean to discourage you working on it. |
|
I would love to see the ROCm backend get more traction. The new AI series of processors by AMD have a similar advantage to Apple Silicon with unified memory and getting MLX to run on those processors would be neat. |
|
Stole my idea :( |
|
How is this even possible for such an awesome PR to be left like this? |
Pull request overview
This PR adds experimental ROCm backend support to MLX, enabling execution on AMD GPUs. The implementation mirrors the CUDA backend structure, providing HIP-based implementations of core operations, memory management, and device handling.
Changes:
- Added ROCm backend infrastructure with device management, memory allocation, and stream handling
- Implemented HIP kernels for unary, binary, ternary operations, reductions, normalization (softmax, layer_norm, rms_norm), RoPE, and sorting
- Updated build system (CMake) to support ROCm compilation with configurable GPU architectures
Reviewed changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 13 comments.
| File | Description |
|---|---|
| CMakeLists.txt | Added MLX_BUILD_ROCM option and ROCm library detection |
| mlx/CMakeLists.txt | Integrated ROCm backend build configuration |
| mlx/device.cpp | Added ROCm device availability checks |
| mlx/backend/rocm/*.hip | HIP kernel implementations for various operations |
| mlx/backend/rocm/device.* | ROCm device and stream management |
| mlx/backend/rocm/allocator.* | ROCm-specific memory allocator using HIP unified memory |
| mlx/backend/rocm/worker.* | Async task execution worker for stream synchronization |
| mlx/backend/rocm/utils.* | HIP utility functions and error handling |
| mlx/backend/rocm/jit_module.* | JIT compilation support using HIPRTC |
| mlx/backend/rocm/device/*.hpp | Device-side utility functions and type definitions |
| mlx/backend/rocm/CMakeLists.txt | ROCm backend build configuration |
|
👑👑👑 |
|
Can anyone run CMAKE_ARGS="-DMLX_BUILD_ROCM=ON" pip install -e .
CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES={based on your GPU}" pip install -e .
Replace {based on your GPU} with your GPU architecture. You can run rocm-smi to get your GPU information. |
|
I'm getting this CMake error: Running on Strix Halo (gfx1151) |
Could you retry with the latest push please (p.s. keep your fingers crossed while it compiles, worked for me 138th time)😅 |
Now what can I test? 😍 |
|
I'm getting this: |
I forgot to test the Python build, my bad. Can you try it now? Unfortunately I might not be able to help after it compiles; I don't have an AMD GPU to run tests😔 I've tried replicating most things from CUDA, so hopefully it works |
|
Now fails on load with this: |
Omg I don't believe you did it without AMD card 😱😱 |
Haha docker literally saves me and humbles me at the same time |
|
Wait it works?😅 Ah unfortunately unless a magic fairy sends me a PC with AMD GPU I cannot help after this😭 With the ram prices I doubt the magic fairy has the funds either🥲 |
|
Lemme try adding a fix for both the issues above actually. I had just made a stub implementation earlier. |
|
@goniz give the last push a try maybe. It might not work, but you will definitely not have the same error at least |
|
|
Might fix it(????) |
|
… resident
Follow-up to the device-binding fix. Three changes that keep the discrete GPU out of queue-wedge states:
- AtomicEvent: signal via hipLaunchHostFunc and wait by host poll instead of hipStreamWriteValue64/WaitValue64. The value ops require hipMallocSignalMemory and silently no-op on a plain pinned-host counter, so the GPU-side wait never observes the value and the queue spins forever (100% busy, 0 mem traffic).
- set_cache_limit: trim the reuse pool down to the new cap immediately (caller is at an idle point) instead of lazily on the next malloc, so the eviction's blocking hipFree doesn't fire mid-forward and wedge the queue.
- device(): set blocking-sync flags per device index, not behind a single global bool (which left device 1 unflagged if device 0 was touched first).
- Keep fine-grained (VRAM-resident, host-mappable) allocations on both the APU and discrete GPU; bump the discrete driver reserve 256MB -> 512MB.
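To illustrate the device() item above, here is a minimal host-only sketch of tracking the flag per device index instead of behind one global bool. `DeviceFlagTracker` and `needs_flag` are hypothetical names for illustration, not identifiers from the PR, and the actual HIP call that sets the flag is left out.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in: remembers, per device index, whether the
// blocking-sync flag has already been applied.
class DeviceFlagTracker {
 public:
  explicit DeviceFlagTracker(std::size_t device_count)
      : flagged_(device_count, false) {}

  // Returns true the first time a device index is seen. That is the point
  // where a real backend would set the flag for that specific device.
  bool needs_flag(std::size_t device) {
    if (flagged_.at(device)) return false;
    flagged_[device] = true;
    return true;
  }

 private:
  std::vector<bool> flagged_;
};
```

With a single global bool, touching device 0 first would make the second query below return false; with per-index state each device is flagged exactly once.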
The reduce-type kernel launch macro referenced data_offset, a structured binding, which C++17 forbids capturing. Copy it to a plain local first so the file compiles when rebuilt. No behavior change; the KV/None path is unaffected.
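The workaround described above can be shown in a few lines of plain C++17. This is a sketch only: `reduce_plan` and `captured_offset` are made-up names, and the lambda stands in for the kernel launch macro.

```cpp
#include <utility>

// Hypothetical helper standing in for whatever produces the launch plan.
std::pair<int, int> reduce_plan() { return {4, 128}; }

int captured_offset() {
  auto [num_blocks, data_offset] = reduce_plan();
  (void)num_blocks;  // unused in this sketch

  // C++17 does not allow a lambda to capture the structured binding
  // `data_offset` directly, so copy it into a plain local first.
  const int offset = data_offset;

  auto launch = [offset]() { return offset; };
  return launch();
}
```

C++20 relaxed this restriction, but the copy keeps the file buildable under C++17 and does not change behavior.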
The dims_<D "RoPE without copy" path shared a donatable input's buffer and indexed it with a contiguous T*D head stride (strides[0]=mat_size) while taking the row/element strides from the input's actual layout. For a non-contiguous (transposed [B,H,T,D]) Q view that mix produces out-of-bounds addressing; the error grows with head index and T, so it stays in-bounds for short prompts but runs off the buffer at T=512, faulting the GPU command processor and wedging the queue on long (>512-token) prefills. Gate the in-place path on row_contiguous; fall back to copy_gpu otherwise (matches the CUDA backend, which always materializes a contiguous output here).
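A rough sketch of the kind of check that could gate the in-place path. `is_row_contiguous` is a hypothetical helper written for this comment, not the flag MLX actually stores on the array; it assumes `shape` and `strides` have the same length and that strides are in elements.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// True when the strides describe a densely packed row-major layout.
// Dimensions of size 1 are ignored because their stride never matters.
bool is_row_contiguous(const std::vector<std::int64_t>& shape,
                       const std::vector<std::int64_t>& strides) {
  std::int64_t expected = 1;
  for (std::size_t i = shape.size(); i-- > 0;) {
    if (shape[i] != 1 && strides[i] != expected) return false;
    expected *= shape[i];
  }
  return true;
}
```

For a [B,H,T,D] = [1,8,16,64] array, packed strides {8192, 1024, 64, 1} pass, while a transposed view of [B,T,H,D] storage has strides {8192, 64, 512, 1} and fails, so it would take the copy fallback.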
RMSNorm::eval_gpu forced a contiguous_copy_gpu whenever the input rows were not tightly packed (stride[-2] != axis_size), e.g. the sliced per-head q/k norm where each head's vector sits in a wider stride. The kernel indexed rows as row*axis_size, so it could only handle packed input. Add a strided kernel that computes each row's base offset from the leading dims' shape/strides (last dim must be contiguous) and writes a packed output, and route the previously-copying case to it. The packed fast path is unchanged; only inputs that would otherwise be copied now run in place. On Qwen3 MoE q4 this removes ~490 copy launches over a 40-token run with identical output.
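The row-base computation described above can be sketched on the host as follows. `row_base_offset` is an illustrative name, it assumes at least one dimension and element strides, and it is not the kernel code from the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Decompose a flat row index over the leading dims (all but the last axis)
// and accumulate the element offset of that row's first value.
std::int64_t row_base_offset(std::int64_t row,
                             const std::vector<std::int64_t>& shape,
                             const std::vector<std::int64_t>& strides) {
  std::int64_t offset = 0;
  // Walk from the innermost leading dim outward; the last axis is skipped
  // because it is required to be contiguous.
  for (std::size_t i = shape.size() - 1; i-- > 0;) {
    offset += (row % shape[i]) * strides[i];
    row /= shape[i];
  }
  return offset;
}
```

For shape {2, 3, 4} with packed strides {12, 4, 1}, row 4 starts at 16, which equals row * axis_size. With wider strides {24, 8, 1}, the same row starts at 32, which is the case the packed kernel could not address.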
The general gather kernel runs one thread per output element and redoes the full src/index stride decomposition (mod/div loops) for every element. For an axis-0 gather of a row-contiguous source (e.g. the MoE token reorder, gathering [N, hidden] rows), all elements of a row share the same source-row base, so this is pure integer-math overhead. Add a fast path (gather_rows_kernel) for that case: one block per output row, source-row base computed once, coalesced copy of the contiguous row. Gated to nidx==1, axis 0, row-contiguous src, ndim>=2, full-row slices, contiguous index; everything else still uses the general kernel.
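A CPU-side sketch of the fast-path idea, with one loop iteration playing the role of one block per output row. `gather_rows` here is a hypothetical host function, not the `gather_rows_kernel` from the PR, and it assumes every index is in range.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Axis-0 gather of a row-contiguous [N, row_len] source.
std::vector<float> gather_rows(const std::vector<float>& src,
                               std::size_t row_len,
                               const std::vector<std::uint32_t>& indices) {
  std::vector<float> out(indices.size() * row_len);
  for (std::size_t r = 0; r < indices.size(); ++r) {
    // The source-row base is computed once per row, not once per element.
    const float* src_row =
        src.data() + static_cast<std::size_t>(indices[r]) * row_len;
    // Contiguous copy of the whole row (the coalesced part on a GPU).
    std::copy(src_row, src_row + row_len, out.data() + r * row_len);
  }
  return out;
}
```

The general kernel would instead redo the stride decomposition for each of the `indices.size() * row_len` output elements.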
- DynamicSliceUpdate (gpu/primitives.cpp): donate the input buffer when uniquely owned (contiguous, full) so a device-position slice_update writes IN PLACE instead of copying the whole buffer every call — O(1) preallocated KV updates with a stable address. Mirrors the existing SliceUpdate donation. - CustomKernel output->input aliasing (fast.h, fast_primitives.h, rocm/metal/cuda custom_kernel.cpp): hip_kernel() takes an optional output_input_aliases map; an aliased output reuses the input's buffer in place. Lets a recurrent-state kernel (gated-delta SSM) write its new state into the same buffer it read, so a captured HIP graph's recurrence accumulates across replays. Honored on all three GPU backends; no-op when unset. - indexing.hip: in-place device-scalar kernels (gpu_kv_pos_set/increment) and an in-place KV row-write (gpu_kv_row_write) for the device-position decode loop. Raw kernels (no host-constant upload) so the value survives graph capture/replay.
The unconditional per-GEMM fprintf(stderr) serialized the host thread on every matmul (prefill hot path). Gate it behind an env flag (off by default).
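One common way to gate such logging, sketched with the standard library only. The variable name `MLX_ROCM_GEMM_DEBUG` is a placeholder invented for this example; the commit message does not say which flag the PR uses.

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>

// True when the variable is set to something other than "" or "0".
bool env_flag_enabled(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && value[0] != '\0' && std::strcmp(value, "0") != 0;
}

// Reads the (hypothetical) flag once and caches the answer, so the hot path
// pays for a single branch instead of an fprintf per matmul.
bool gemm_debug_enabled() {
  static const bool enabled = env_flag_enabled("MLX_ROCM_GEMM_DEBUG");
  return enabled;
}

void log_gemm(int m, int n, int k) {
  if (!gemm_debug_enabled()) return;
  std::fprintf(stderr, "gemm m=%d n=%d k=%d\n", m, n, k);
}
```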
6-bit QMV ran on qmv_warp_shared at half-wave (block=16) occupancy because the generic tiled kernel needs integer pack_factor (32/6 isn't). New qmv_tiled_6bit_kernel gives 6-bit the tiled kernel's full Wave32 + column-tiling + LDS-X structure with byte-aligned 6-bit loads (K%64==0). +26% decode on gfx1151 (30.3 -> 38.3 tok/s), +2% on gfx1201. On by default; MLX_ROCM_QMV_6BIT_SLOW reverts to warp_shared.
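For context on the byte-aligned loads: 32/6 is not an integer, but four 6-bit values fit exactly in 3 bytes, so 64 values occupy 48 bytes. The sketch below unpacks one such 3-byte group on the host. The LSB-first bit order is an assumption made for illustration and may not match the packing MLX uses.

```cpp
#include <array>
#include <cstdint>

// Unpack four 6-bit values from three consecutive bytes.
// ASSUMPTION: value i occupies bits [6*i, 6*i + 5] of the little-endian
// 24-bit word formed by the three bytes.
std::array<std::uint8_t, 4> unpack4x6(const std::uint8_t* bytes) {
  const std::uint32_t word = static_cast<std::uint32_t>(bytes[0]) |
                             (static_cast<std::uint32_t>(bytes[1]) << 8) |
                             (static_cast<std::uint32_t>(bytes[2]) << 16);
  return {static_cast<std::uint8_t>(word & 0x3F),
          static_cast<std::uint8_t>((word >> 6) & 0x3F),
          static_cast<std::uint8_t>((word >> 12) & 0x3F),
          static_cast<std::uint8_t>((word >> 18) & 0x3F)};
}
```

Under that assumed layout, the bytes 0x81, 0x30, 0x10 decode to the values 1, 2, 3, 4.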
MoE expert gather-QMV (gather_qmv_warp_shared) ran at half-wave occupancy like the dense path. gather_qmv_tiled_6bit_kernel mirrors qmv_tiled_6bit_kernel (full Wave32 + column-tiling + LDS-X, byte-aligned 6-bit loads) with the expert-index gather. Dense + MoE together: +33% decode on gfx1151 (30.2 -> 40.2 tok/s). Default-on; MLX_ROCM_QMV_6BIT_SLOW reverts.
…apability table
- Route dequant prefill GEMM through hipBLASLt (all dtypes), eliminating the rocBLAS Tensile missing-kernel churn on gfx1201.
- Cache the selected hipBLASLt algorithm per (shape,dtype,transpose,device) so warm GEMMs skip AlgoGetHeuristic; recovers prefill parity with rocBLAS.
- Probe GEMM input-type support (bf16/fp8 e4m3/e5m2/int8) once per device at first use and print a capability table; select precision via enum instead of an arch-string match.
- Add hipblaslt_gemm_fp8_raw (e4m3 inputs, scale pointers, bf16 out, best-algo tuned) primitive for the gfx1201 fp8 path.
- Gate allocator slab hints (hipMemAdvise/prefetch) to integrated GPUs only.
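The per-(shape,dtype,transpose,device) cache from the list above might look roughly like this. `GemmKey` and `AlgoCache` are hypothetical types, the algorithm is reduced to an `int`, and the heuristic is passed in as a callable so the sketch does not touch hipBLASLt at all.

```cpp
#include <cstddef>
#include <map>
#include <tuple>

// Hypothetical cache key: everything that influences algorithm selection.
struct GemmKey {
  int m, n, k;
  int dtype;
  bool transpose_a, transpose_b;
  int device;

  bool operator<(const GemmKey& o) const {
    return std::tie(m, n, k, dtype, transpose_a, transpose_b, device) <
           std::tie(o.m, o.n, o.k, o.dtype, o.transpose_a, o.transpose_b,
                    o.device);
  }
};

class AlgoCache {
 public:
  // Runs `pick` (a stand-in for the heuristic query) only on a cache miss.
  template <class Heuristic>
  int get(const GemmKey& key, Heuristic&& pick) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;
    ++misses_;
    const int algo = pick();
    cache_.emplace(key, algo);
    return algo;
  }

  std::size_t misses() const { return misses_; }

 private:
  std::map<GemmKey, int> cache_;
  std::size_t misses_ = 0;
};
```

A warm GEMM with the same key returns the cached choice without invoking the heuristic again.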
…e, bf16) Dequantize packed affine weights straight to e4m3 (no bf16 intermediate) and cast activations to e4m3, then run the projection GEMM on fp8 matrix cores via hipblaslt_gemm_fp8_raw, descaled back to bf16. Per-tensor weight scale is derived from quant-param endpoints (no full-weight pass). Capability-gated to devices with e4m3 kernels; bf16 path elsewhere. ~+20% warm prefill on gfx1201.
free() ran a blocking hipFree on the completion-worker thread when the reuse cache was full; on the APU's fine-grained unified memory that free waits on GPU completions the worker itself delivers — a self-deadlock that wedged decode under heavy async load (MTP speculative decode). Defer such frees to a pending list drained by malloc on the eval thread, where blocking is safe. Also size the integrated memory_limit_ to system RAM (the unified/GTT pool the allocations actually draw from) rather than the device VRAM figure, so the reuse pool never evicts mid-generation.
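A minimal sketch of the deferral pattern, assuming a mutex-protected pending list. `DeferredFreeList` is an illustrative name, and `std::free` stands in for the blocking hipFree so the example runs without a GPU.

```cpp
#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <vector>

class DeferredFreeList {
 public:
  // Called from the completion-worker thread: never blocks on the GPU,
  // only queues the pointer.
  void defer(void* ptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    pending_.push_back(ptr);
  }

  // Called from the eval thread (for example at the start of malloc),
  // where a blocking free is safe. Returns how many buffers were released.
  std::size_t drain() {
    std::vector<void*> batch;
    {
      std::lock_guard<std::mutex> lock(mutex_);
      batch.swap(pending_);
    }
    // The lock is released before the (potentially blocking) frees run.
    for (void* ptr : batch) std::free(ptr);  // stand-in for hipFree
    return batch.size();
  }

 private:
  std::mutex mutex_;
  std::vector<void*> pending_;
};
```

Swapping the list out under the lock and freeing outside it keeps the worker from ever waiting behind a blocking free.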
…adlock clear_cache() freed every cached buffer with a blocking hipFree while holding the allocator mutex. On unified memory that free waits for outstanding GPU work whose completion the worker thread delivers — and the worker frees through the same mutex, so a long-prompt prefill (large cache + many in-flight frees) deadlocked with the GPU idle. Synchronize the device first so the frees have nothing to wait on, and release any deferred frees in the same pass.
…ync) Adopt the CUDA backend's stream-ordered allocation model. Primitive output buffers allocate from a per-device hipMemPool via malloc_async(size, encoder) on the encoder's stream, and free non-blocking via hipFreeAsync on that same stream so the frees retire in order behind the buffer's last use and the pool reclaims memory (a separate free stream never executes mid-forward, leaking VRAM). CPU access to pool buffers (device>=0, non-coherent) is served by the existing pinned host-shadow path. Wired malloc_async into every primitive that allocates an output, mirroring the CUDA backend: copy, binary_two, reductions, softmax, logsumexp, scan, norms, rope, random, arange, sort, indexing, attention (sdpa/flash/wmma), conv, distributed, quantized (qmm/gather/convert_fp8), matmul. The pool is always on where the device supports memory pools. Stream-less allocations (model load, KV, non-wired ops) stay on the unified path with deferred frees off the completion-worker thread. clear_cache trims the pool instead of blocking-freeing under handler pressure. Verified stable on gfx1151 (APU) and gfx1201 (R9700) across prefill, decode, and MTP: D1 297 pp/s / 47.8 tps, D0 247 pp/s / 42.1 tps; no wedge, no OOM.
…ture During capture the async pipeline inflates the input buffer use_count, so can_donate fails and the update copies into a fresh buffer — the captured graph then reconstructs (frozen capture input + current row) every replay and loses accumulation (growing KV cache freezes -> repeated tokens). Force the in-place donation for a contiguous, fully-materialized buffer while a graph is being captured.
Add a mark-based rewind so per-token sampling allocations reuse [mark, ...) while the captured graph's deterministic buffer region [0, mark) stays reserved across replays.
- malloc_async routes through DecodeArena during capture (was emitting MemAlloc graph nodes that fail on the 2nd replay with 'invalid argument')
- DecodeArena: reserve 16384 descriptors so the descriptor vector never reallocates (returned RocmBuffer* point into it; realloc dangled them -> heap corruption)
- DecodeArena::reset_to(byte_mark, desc_mark) rewinds BOTH counters so the graph region stays reserved while per-token sampling reuses the tail
- is_hipblaslt_available() returns false during capture (force rocBLAS): a warm hipBLASLt handle still runs AlgoGetHeuristic/workspace hipMalloc that invalidates the capture
With these + the DynamicSliceUpdate donation fix, capture-once graph decode replays the full forward coherently on gfx1151.
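The mark-based rewind and the reserved descriptor vector can be sketched together as a small bump arena. `DecodeArenaSketch` is a toy written for this comment: it tracks offsets only, owns no device memory, and its interface is an assumption apart from the `reset_to(byte_mark, desc_mark)` shape named in the commit message.

```cpp
#include <cstddef>
#include <vector>

class DecodeArenaSketch {
 public:
  struct Desc {
    std::size_t offset;
    std::size_t size;
  };

  DecodeArenaSketch(std::size_t capacity_bytes, std::size_t max_descs)
      : capacity_(capacity_bytes) {
    // Reserve up front so push_back never reallocates and the Desc*
    // handed out below stay valid.
    descs_.reserve(max_descs);
  }

  // Bump allocation; returns nullptr instead of growing either resource.
  const Desc* alloc(std::size_t size) {
    if (bytes_ + size > capacity_ || descs_.size() == descs_.capacity()) {
      return nullptr;
    }
    descs_.push_back(Desc{bytes_, size});
    bytes_ += size;
    return &descs_.back();
  }

  std::size_t byte_mark() const { return bytes_; }
  std::size_t desc_mark() const { return descs_.size(); }

  // Rewind BOTH counters: [0, mark) stays reserved for the captured graph,
  // [mark, ...) is reused by the next token's sampling allocations.
  void reset_to(std::size_t byte_mark, std::size_t desc_mark) {
    bytes_ = byte_mark;
    descs_.erase(descs_.begin() + static_cast<std::ptrdiff_t>(desc_mark),
                 descs_.end());
  }

 private:
  std::size_t capacity_;
  std::size_t bytes_ = 0;
  std::vector<Desc> descs_;
};
```

Taking the marks right after the graph's buffers are allocated and calling `reset_to` once per token gives sampling the same tail region every step while the graph region is never handed out again.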
…6 verified coherent
After a capture-once graph is built, set_paused(true) keeps the arena backing valid (captured-graph buffers stay at baked addresses) but routes per-token sampling allocations to the pool, so sampling can't clobber graph buffers and corrupt the next replay. Fixes replay token N+1 corruption from arena reset_to.
…[gated off] Foundation mirroring the CUDA backend: CommandEncoder gains add_kernel_node / add_kernel_node_raw, a build_graph_ accumulator, dependency tracking in set_input_array/set_output_array, needs_commit(), and commit() that builds the per-eval HIP graph, reuses the exec via hipGraphExecUpdate (LRU keyed on topology hash), and submits one hipGraphLaunch. eval.cpp wires needs_commit/commit. hipBLASLt workspace pre-allocated so capture never hipMallocs. Gated behind MLX_USE_HIP_GRAPHS (default OFF); the default build is unchanged eager (verified coherent, 41 tok/s on gfx1151). The graphs-ON path currently uses a per-lambda stream-capture bridge in launch_kernel which DEADLOCKS on the first eval (library/alloc calls under capture); it is to be replaced by real per-kernel migration to add_kernel_node (host-side node construction).
…des) Convert elementwise (unary/binary/binary_two/ternary), norms (rms_norm/layer_norm), softmax/logsumexp, scan, arg_reduce, sort, rope, indexing (gather/scatter/slice_update/masked_scatter), random, and attention (sdpa/flash/flash_wmma) launch sites from launch_kernel(lambda) to encoder.add_kernel_node(&kernel, grid, block, smem, args...). Fix the add_kernel_node_ex param marshalling to strip const (gpu_ptr returns const for const inputs). Graphs-OFF (default) is unchanged immediate-launch and builds clean; sets up automatic per-eval graph batching when graphs-ON. Residual launch_kernel sites (memsets, rocprim sort path, copy/ subdir, KV helpers, JIT custom_kernel/compiled) still pending migration.
Wave 2: copy/ subdir, reduce/ subdir (row/col/all/init), quantized (affine_quantize, fp_quantize, convert_fp8), qmm.hip (~63 sites: qmv/qvm tiled+warp+gather, all bit/group/dtype combos), gemv.hip. Builds clean, graphs-OFF unchanged. Residual launch_kernel: copy/arg_reduce memsets (-> memset nodes), JIT custom_kernel/compiled, GEMM library (rocblas/hipblaslt), gemv malloc fallback, rocprim sort.
…idge add_kernel_node_ex now copies arg VALUES into a heap pack kept alive through commit() (HIP graph nodes reference kernelParams until instantiate/exec-update, after which the pack is cleared) — fixes dangling kernelParams. The per-op micro-capture bridge in launch_kernel is now behind MLX_HIP_GRAPH_BRIDGE. graphs-OFF (default) unchanged.
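The "copy arg VALUES into a heap pack" idea can be sketched without any HIP calls. `ArgPack` below is a hypothetical class; it only shows how a `void**` parameter array can point at heap copies that outlive the caller's locals, which is the property the kernelParams fix relies on.

```cpp
#include <cstddef>
#include <cstring>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>

class ArgPack {
 public:
  template <class... Args>
  explicit ArgPack(const Args&... args) {
    (push(args), ...);  // C++17 fold: copy each argument in order
  }

  // Array of pointers in the shape a kernel-node parameter list expects.
  void** params() { return params_.data(); }
  std::size_t size() const { return params_.size(); }

 private:
  template <class T>
  void push(const T& value) {
    static_assert(std::is_trivially_copyable<T>::value,
                  "kernel arguments are copied bytewise");
    auto slot = std::make_unique<unsigned char[]>(sizeof(T));
    std::memcpy(slot.get(), &value, sizeof(T));
    params_.push_back(slot.get());
    storage_.push_back(std::move(slot));  // keeps the copy alive
  }

  std::vector<std::unique_ptr<unsigned char[]>> storage_;
  std::vector<void*> params_;
};
```

Because the pack owns the copies, later changes to (or destruction of) the original variables do not affect what the parameter array points at; the pack just has to live until the graph no longer references it.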
Diagnostic: pure add_kernel_node kernel-node graphs launch correctly on this ROCm build (model-load evals pass). Remaining graphs-ON blockers are the non-kernel residuals only: library GEMM (aborts/crashes under graph) and the child-graph bridge nodes. graphs-OFF (default) unaffected.
…lifetime graphs-ON (MLX_USE_HIP_GRAPHS, default OFF) now RUNS end-to-end on the ROCm 7.13 runtime (7.12 segfaulted hipGraphLaunch). launch_kernel graph-splits un-graphable residuals (JIT module kernels, GEMM, memsets): flush+launch the accumulated kernel-node graph, run the residual immediately on the same stream, start a fresh graph. hipBLASLt forced to rocBLAS in graph mode (its lazy init aborts under graph activity). kernelParams arg-packs freed at synchronize (exec references them through async launch). KNOWN WIP: graphs-ON output is incorrect (incomplete set_input/output_array dependency edges -> races) and slower than eager due to graph-split fragmentation. Default graphs-OFF unchanged (41 tok/s).
graphs-ON (default OFF): graph nodes serialized into a linear chain in submission order (matches eager stream order; robust vs incomplete set_input/output_array edges) and arg-packs freed at synchronize. Runs on the 7.13 runtime without crashing but output is still incorrect (an unisolated race) and slower than eager due to graph-split fragmentation. Default graphs-OFF eager unchanged (41 tok/s coherent).
Bisection of the graphs-ON correctness bug (all on 7.13 runtime, graphs-OFF default unaffected, eager 41 tok/s coherent):
- 1 node/graph is ALSO wrong -> not multi-node dependency/race.
- exec-cache keyed by node-type-only collided distinct kernel sequences -> hipGraphExecUpdate mis-reused execs -> garbage. Now key by func ptr + dims.
- fresh hipGraphInstantiate per commit + destroy-at-synchronize (no reuse) -> segfaults; ExecUpdate-reuse -> runs but garbage.
Both point to a deeper hipGraph instantiate/exec instability for this GDN+MoE workload on ROCm 7.13. graphs-ON still not correct; eager + 7.13 is the working path.
…solated Standalone repro proved hipGraphAddKernelNode + tuple-marshaling are correct on 7.13 (identical to hipLaunchKernel). Bisection of full-forward graphs-ON:
- BUG1 buffer lifetime: graph nodes execute at commit, but the allocator frees intermediates at eval time -> reused before the graph runs -> segfault. Deferring frees (graph_active) prevents the segfault but balloons memory.
- BUG2 computation: even with buffers kept alive/non-aliased, output is garbage -> a remaining error in the full multi-kernel forward not reproduced by the single-kernel repro. Needs per-kernel eager-vs-graph output bisection.
Default graphs-OFF eager unchanged (41 tok/s coherent on 7.12 and 7.13).
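To make the exec-cache collision concrete, here is a sketch of a key built from function pointer plus launch dims. `KernelNodeDesc`, `ExecKey` and `make_exec_key` are hypothetical names, and the key is an exact comparable value (not a hash), which is a simplification chosen for this example and not necessarily what the PR does.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical description of one kernel node in submission order.
struct KernelNodeDesc {
  const void* func;
  std::array<unsigned, 3> grid;
  std::array<unsigned, 3> block;
};

// One entry per node: function address followed by grid and block dims.
using ExecKey = std::vector<std::array<std::uintptr_t, 7>>;

ExecKey make_exec_key(const std::vector<KernelNodeDesc>& nodes) {
  ExecKey key;
  key.reserve(nodes.size());
  for (const KernelNodeDesc& n : nodes) {
    std::array<std::uintptr_t, 7> entry = {
        reinterpret_cast<std::uintptr_t>(n.func),
        n.grid[0], n.grid[1], n.grid[2],
        n.block[0], n.block[1], n.block[2]};
    key.push_back(entry);
  }
  return key;
}
```

Two graphs that both consist of "one kernel node" but launch different kernels now produce different keys, so a cached exec for one is not reused for the other.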



Experiment with ROCm backend.
install MLX with ROCm backend using:
closes #2556
Inspired by @zcbenz