Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
291 commits
Select commit Hold shift + click to select a range
4746543
Add proper CastOp for ROCm copy to handle all type conversions
NripeshN Feb 4, 2026
aa4ff37
Add missing half/bfloat16 conversions in CastOp
NripeshN Feb 4, 2026
6e4d799
Remove duplicate is_complex definition, use from utils.hpp
NripeshN Feb 4, 2026
97afbd5
Improve ROCm backend to match CUDA functionality
NripeshN Feb 4, 2026
ad9c9cc
Fix reduce operations to match CUDA type constraints
NripeshN Feb 4, 2026
5269e6a
Fix Max/Min reduce ops to use explicit specializations instead of con…
NripeshN Feb 4, 2026
6e4e202
Exclude complex types from reduce operations (not yet supported on ROCm)
NripeshN Feb 4, 2026
4aec5ec
Fix type_identity usage - use mlx::core::type_identity instead of std…
NripeshN Feb 4, 2026
a17961e
Include reduce_utils.hpp for allocate_same_layout
NripeshN Feb 4, 2026
216e533
Enhance ROCm support in CMake and backend
NripeshN Feb 4, 2026
4bf5f22
Add hipFloatComplex support for scan and reduce operations
NripeshN Feb 4, 2026
abc2634
Add debug output to copy_contiguous
NripeshN Feb 4, 2026
833bfc7
Fix const cast in debug output
NripeshN Feb 4, 2026
f10845a
Add more debug output to copy_contiguous
NripeshN Feb 4, 2026
f2f976b
Add stream sync before kernel launch
NripeshN Feb 4, 2026
9426d6c
Use hipMemcpy for small copies
NripeshN Feb 4, 2026
94868fa
Revert to simple kernel launch
NripeshN Feb 4, 2026
3990c3d
Remove debug output from copy_contiguous
NripeshN Feb 4, 2026
a74e904
Fix WARP_SIZE mismatch between host and device code
NripeshN Feb 4, 2026
9a05cd0
Refactor all_reduce to support all types using dispatch_all_types
NripeshN Feb 4, 2026
474f921
Fix all_reduce type casting for And/Or operations
NripeshN Feb 4, 2026
700de96
Add is_valid_reduce_op check to skip invalid type/op combinations
NripeshN Feb 4, 2026
5a9b067
Add complex type support for Min/Max reduce operations
NripeshN Feb 4, 2026
e2c5fcd
Add complex type support to reduce.hpp operators
NripeshN Feb 4, 2026
1766e04
Use SFINAE instead of if constexpr for complex type handling in reduc…
NripeshN Feb 4, 2026
af0acd6
Add complex type support for unary operations
NripeshN Feb 4, 2026
d655bbe
Include hip_complex.h in fp16_math.hpp for hipFloatComplex type
NripeshN Feb 4, 2026
d33bd4c
Refactor unary ops to use dispatch_all_types with type checking
NripeshN Feb 4, 2026
59e8097
Handle -inf case in complex exp function
NripeshN Feb 4, 2026
363b7eb
Add float16 and bfloat16 support to arange
NripeshN Feb 4, 2026
edb9cd7
Fix GPU architecture string in JIT module - gcnArchName already conta…
NripeshN Feb 4, 2026
f2a7f4f
Replace hip/std/array with simple array implementation for JIT
NripeshN Feb 4, 2026
31093f5
Add standard type definitions for JIT compilation
NripeshN Feb 4, 2026
6cf9a3f
Add missing unary and binary ops to JIT includes
NripeshN Feb 4, 2026
3082c41
Add uint16_t, int16_t, uint8_t, int8_t type definitions for JIT
NripeshN Feb 4, 2026
5f1a4d4
Add complex64 support to binary_op_gpu_inplace
NripeshN Feb 4, 2026
0c7e7ea
Add if constexpr check for supports_binary_op in launch_kernel
NripeshN Feb 4, 2026
5d0deba
Fix supports_binary_op for comparison operators with complex types
NripeshN Feb 4, 2026
f61797f
Remove complex64 from binary_op_gpu_inplace (not all ops support it)
NripeshN Feb 4, 2026
6870081
Fix supports_binary_op to use else if constexpr chain
NripeshN Feb 4, 2026
eed4267
Remove if constexpr check from launch_kernel (was causing issues)
NripeshN Feb 4, 2026
cef0bbc
Enhance ROCm backend with general binary operation support and improv…
NripeshN Feb 5, 2026
49c1dce
Enhance ROCm backend with dynamic memory management and kernel optimi…
NripeshN Feb 5, 2026
1fa3a44
Enhance ROCm backend with dynamic memory initialization and kernel ar…
NripeshN Feb 5, 2026
8a21489
Enhance ROCm backend with new all-reduce functionality and kernel opt…
NripeshN Feb 5, 2026
780a83d
Remove input dilation check from gemm_conv function in ROCm backend t…
NripeshN Feb 5, 2026
c40fd68
Refactor ROCm backend gather and scatter operations for improved perf…
NripeshN Feb 6, 2026
5993979
lint
NripeshN Feb 6, 2026
436b65d
Add hip_kernel support for ROCm backend and enhance Python bindings
NripeshN Feb 6, 2026
d6019c0
Enhance row_reduce function in ROCm backend to support contiguous data
NripeshN Feb 6, 2026
3be5a10
Remove unused type traits from ROCm unary kernel implementation to st…
NripeshN Feb 6, 2026
7672448
Implement single position RoPE kernel in ROCm backend
NripeshN Feb 7, 2026
b4a2a36
Refactor warp reduction logic in ROCm layer and RMS normalization ker…
NripeshN Feb 7, 2026
c550158
Add support for bfloat16 data type in scaled dot product attention ke…
NripeshN Feb 7, 2026
16c1ef4
Disable ROCm SDPA kernel due to warp size incompatibility
NripeshN Feb 7, 2026
f5aac8d
Rewrite ROCm SDPA kernel to be warp-size agnostic
NripeshN Feb 7, 2026
a6bf8cb
Temporarily disable ROCm SDPA kernel to debug memory fault
NripeshN Feb 7, 2026
af26ee9
Re-enable warp-agnostic ROCm SDPA kernel
NripeshN Feb 7, 2026
c6d9a92
ci trigger
NripeshN Feb 8, 2026
9d73b71
Added github workflow for rocm strix halo
goniz Jan 27, 2026
2285120
Fix ROCm bfloat16 matmul and kernel type handling
goniz Feb 25, 2026
0a08672
Fix ROCm non-uniform batched matmul for fp16/bfloat16
goniz Feb 25, 2026
3a9c39b
Fix ROCm affine quantized matmul sign handling
goniz Feb 25, 2026
8684c46
Fix ROCm non-power-of-two quantized packing
goniz Feb 25, 2026
fb3a67e
Replace Qwen3 smoke script with pytest suite
goniz Feb 25, 2026
a01a7bd
Merge original MLX main into rocm-support-fixes
goniz Feb 25, 2026
8dec0d4
Fix ROCm LogAddExp bf16 handling and expand generation matrix
goniz Feb 25, 2026
9c8718d
Fix ROCm GatherQMM index contiguity
goniz Feb 25, 2026
ac27e78
Support strided GatherQMM indices on ROCm
goniz Feb 25, 2026
11b2920
Fix ROCm hot-path pointer access to avoid host synchronization
goniz Feb 25, 2026
4758c15
Accelerate ROCm depthwise Conv1d grouped path
goniz Feb 25, 2026
1e7e977
Fix ROCm GatherMM hard sync in fallback path
goniz Feb 25, 2026
cbcd332
Fix ROCm BLAS pytest failures in direct test runs
goniz Feb 25, 2026
f3a30e0
Implement ROCm MaskedScatter kernel for boolean indexing
goniz Feb 25, 2026
926fdee
Fix ROCm SDPA crashes in GQA causal paths
goniz Feb 25, 2026
1d95664
Fix ROCm fp quantized matmul decode paths
goniz Feb 26, 2026
b5c0ba3
Fix ROCm quantized fallback paths for fp and qqmm
goniz Feb 26, 2026
77320af
Accelerate ROCm quantized decode path for generation
goniz Feb 26, 2026
9d23561
Optimize ROCm quantized matmul decode kernels
goniz Feb 26, 2026
04805fd
Optimize ROCm GatherQMM warp decode path
goniz Feb 26, 2026
fed4ca0
Tune ROCm quantized warp kernels for decode throughput
goniz Feb 26, 2026
0618c69
Tune ROCm 8-bit quantized decode kernels
goniz Feb 26, 2026
ff3fcfc
Tune ROCm quantized subgroup threading for decode
goniz Feb 26, 2026
43cd9dc
Optimize ROCm GEMV batched launch parameter handling
goniz Feb 26, 2026
2f5964f
Fix ROCm gather GEMV indexing for batched layouts
goniz Feb 26, 2026
698f86c
Optimize ROCm APU allocator and fix high CPU spin-wait
goniz Feb 27, 2026
17b7cb8
Add bfloat16 support for rocBLAS GEMM operations
goniz Feb 27, 2026
f29e4e4
Optimize ROCm GEMV with vectorized loads and wider n_per_thread
goniz Feb 27, 2026
a6967d2
Increase ROCm max ops per buffer from 20 to 1000
goniz Feb 27, 2026
8c56f29
Fix quantized matmul array creation bug and simplify kernels
goniz Feb 27, 2026
197e844
Merge branch 'main' of github.com:ml-explore/mlx into rocm-support-fixes
goniz Feb 27, 2026
a1a642e
Optimize ROCm backend: Fix SDPA fallback, enable QMM rocBLAS dequant,…
goniz Feb 28, 2026
719dc9d
Add optimized Flash Attention and reduce rocBLAS dispatch overhead
goniz Feb 28, 2026
0c5144a
ROCm: Add MLA Flash Attention support and fix rocBLAS dispatch
goniz Mar 1, 2026
7d5eb69
benchmark: update default max-tokens to 1000
goniz Mar 1, 2026
e8e3a45
benchmark: remove --no-warmup from llama-completion
goniz Mar 1, 2026
958240a
benchmark: redact prompt from logs to reduce terminal clutter
goniz Mar 1, 2026
d55d2a2
ROCm: Fix JIT compilation 'File name too long' error
goniz Mar 1, 2026
805d272
ROCm: Add math function overloads for bfloat16 and half types
goniz Mar 1, 2026
b44396a
ROCm: Fix quantized GEMM fallback correctness
goniz Mar 1, 2026
f1687cc
ROCm: fix 5/6-bit affine quantized matmul page faults
goniz Mar 1, 2026
108195a
ROCm: Fix quantized matmul with singleton batch dimensions
goniz Mar 1, 2026
ec84dfd
ROCm: Optimize quantized matmul and MoE gather for decode shapes
goniz Mar 2, 2026
f4634b4
ROCm: Vectorize 4-bit and 6-bit memory access in qmv_warp_shared_kernel
goniz Mar 2, 2026
a69c471
ROCm: Set default THREADS_PER_COL to 16 for qmv warp kernels
goniz Mar 2, 2026
24ecc76
ROCm: Optimize RoPE kernel for decode with sincosf and 1D layout
goniz Mar 2, 2026
4353b1b
ROCm: vectorize 6-bit fallback QMV kernels
goniz Mar 2, 2026
b811a89
ROCm: optimize QMM dispatch and extend SDPA head-dim support
goniz Mar 3, 2026
b38695f
ROCm: harden QMM cache keys and tune QMV launch defaults
goniz Mar 3, 2026
bc3bd38
ROCm: improve SDPA decode dispatch and avoid AddMM copy
goniz Mar 3, 2026
2884e85
ROCm: broaden batched GEMM fast-path stride detection
goniz Mar 3, 2026
7c80030
ROCm: add configurable rocBLAS GEMM solution-index dispatch
goniz Mar 3, 2026
184ef21
ROCm: make QMV launch defaults shape-adaptive
goniz Mar 3, 2026
c6883ca
ROCm: increase shared QMV tile size for decode
goniz Mar 3, 2026
d5d8b31
ROCm: reduce command-encoder scheduling overhead
goniz Mar 3, 2026
7bca990
ROCm: add sorted-rhs gather scheduling fast path
goniz Mar 3, 2026
20bcdd2
ROCm: extend sorted-rhs gather schedule across QMV dispatch
goniz Mar 3, 2026
d07f6a5
Benchmarks: route Qwen3.5 vision models through mlx-vlm
goniz Mar 3, 2026
1c93a6f
ROCm: add architecture-aware QMV crossover and tiny-K dispatch
goniz Mar 3, 2026
6be6435
ROCm: add alignment-aware QMV variant selection
goniz Mar 3, 2026
3ca29dc
ROCm: fix no-shared QMV accumulator shadowing
goniz Mar 3, 2026
879a200
Merge branch 'main' of github.com:ml-explore/mlx into rocm-support-fixes
goniz Mar 3, 2026
9193df5
Merge NripeshN/mlx rocm-support into upstream main
Geramy Mar 25, 2026
9fddf1c
Add RDNA 3.5/4 architectures and parallel HIP compilation
Geramy Mar 25, 2026
3ae44dc
Fix parallel-jobs flag: single dash for hipcc/clang
Geramy Mar 25, 2026
2b8a7d1
Limit HIP parallel-jobs to half of available CPUs
Geramy Mar 25, 2026
c2eb919
Add missing gpu::init() and SliceUpdate::eval_gpu stub for ROCm
Geramy Mar 25, 2026
26e733c
Implement ROCm-optimized SliceUpdate::eval_gpu
Geramy Mar 25, 2026
edd89a1
Fix bfloat16/half JIT compilation for ROCm fused kernels
Geramy Mar 25, 2026
1ab4186
Simplify JIT preamble ops: always promote through float
Geramy Mar 25, 2026
d03fa7c
Fix critical bug: JIT KernelArgs passed CPU pointers instead of GPU
Geramy Mar 25, 2026
76741bc
Remove gfx1150/1151/1152/1200/1201 from rocBLAS supported list
Geramy Mar 25, 2026
9336df8
Add rocBLAS fallback to naive_gemm when Tensile kernel missing
Geramy Mar 25, 2026
f92d2d2
Add missing kernel_utils.hpp include for gpu_ptr in rocblas_gemm
Geramy Mar 25, 2026
8acadb4
Probe rocBLAS bf16 GEMM at device init, fallback to naive_gemm
Geramy Mar 25, 2026
bfab6fb
Always use naive_gemm for bfloat16 GEMM on ROCm
Geramy Mar 25, 2026
c8c9c8e
ROCm bug fixes + optimized quantized GEMV kernel
Geramy Mar 26, 2026
2f47aeb
Promote JIT binary ops through float, restore rocBLAS for gfx1151
Geramy Mar 26, 2026
6520667
GatherQMM: ensure contiguous indices, SDPA: add head_dim=256
Geramy Mar 26, 2026
00d8c2e
SDPA GPU decomposition, naive_gemm for all types, GatherQMM contiguou…
Geramy Mar 26, 2026
4a5bb0f
Metal-compatible QMM accumulation, JIT stderr suppression
Geramy Mar 26, 2026
73470d8
Fix GatherQMM memory corruption, add index bounds clamping
Geramy Mar 26, 2026
1e50c74
Kernel audit: match Metal precision across RMSNorm, sort, softmax, ops
Geramy Mar 26, 2026
1793485
Fix batched matmul: missing bfloat16/float16 in loop-based GQA path
Geramy Mar 27, 2026
840d028
Add head_dim=256 dispatch to SDPA vector kernel
Geramy Mar 27, 2026
b48adae
Merge upstream main into rocm-support
NripeshN Mar 27, 2026
fe75135
Merge goniz/rocm-support-fixes with extensive ROCm optimizations
NripeshN Mar 27, 2026
d30fe29
Merge upstream NripeshN/mlx rocm-support with ROCm optimizations
Geramy Mar 27, 2026
5ffb863
Enable 4-bit fast gather QMV dispatch for MoE decode
Geramy Mar 27, 2026
b1300b9
Optimize ROCm allocator for integrated GPUs (APU)
Geramy Mar 27, 2026
780b4fe
Prefer shared-memory QMV over noshared variant for decode
Geramy Mar 27, 2026
0ec6b45
Add expert-grouped prefill kernel for GatherQMM (3.4x prompt speedup)
Geramy Mar 27, 2026
c9167d2
Allocator: prefer hipExtMallocWithFlags for APU, fallback to hipMallo…
Geramy Mar 27, 2026
a66e273
Add WMMA-accelerated prefill kernel for GatherQMM on RDNA 3/3.5/4
Geramy Mar 27, 2026
e35d6aa
WMMA prefill kernel: support non-aligned M, sort unsorted indices
Geramy Mar 27, 2026
435afdc
Add GPU-only expert-batched gather QMV kernel for low-expert MoE
Geramy Mar 27, 2026
bc4d62f
Add hipBLASLt GEMM integration for bf16/fp16 matmul on ROCm
Geramy Mar 27, 2026
b8b56b1
hipBLASLt: add to QMM dequant+GEMM path for bf16 (2.6x prompt speedup)
Geramy Mar 27, 2026
7ac6efd
hipBLASLt in QMM dequant path + CommandEncoder graph capture API
Geramy Mar 27, 2026
b913c68
Strided copy kernels for ensure_row_contiguous in QMM
Geramy Mar 27, 2026
da1925b
Allocator: power-of-2 rounding for large allocs (>= 1MB)
Geramy Mar 28, 2026
65958fa
Allocator: use system RAM limit for iGPU, power-of-2 rounding for lar…
Geramy Mar 28, 2026
b010eee
Allocator: revert power-of-2 rounding, keep hipExtMallocWithFlags
Geramy Mar 28, 2026
f26c802
Fix CU count comment: 40 CUs (20 WGPs) on gfx1151
Geramy Mar 28, 2026
251c8d8
Merge pull request #5 from lemonade-sdk/rocm-optimizations
Geramy Mar 30, 2026
ce31887
Add multi-tier slab allocator for ROCm backend
Geramy Mar 31, 2026
ef8190c
hipBLASLt auto-tune + eliminate hipMemcpyAsync in copy kernels
Geramy Mar 31, 2026
25f5912
Skip hipStreamSynchronize on iGPU when stream is idle
Geramy Mar 31, 2026
a057095
Disable hipBLASLt auto-tune by default, fix warm prompt regression
Geramy Mar 31, 2026
3ddba6a
Replace hipEventSynchronize with spin-wait polling on iGPU
Geramy Mar 31, 2026
6b3713e
Add L2-optimized tiled QMV kernel with TILE_N=16 column blocking
Geramy Mar 31, 2026
e6563a6
ROCm backend: arch-tunable QMV, WMMA flash attention, arena allocator…
Geramy Mar 31, 2026
bc9d8ba
Fix custom kernel stdout spam breaking MoE model output, vectorize QM…
Geramy Apr 1, 2026
ed8c2aa
Merge branch 'ml-explore:main' into rocm-support
Geramy Apr 4, 2026
a866ff4
[ROCm] Guard placement new/delete to fix build on ROCm 7.12+
Geramy Apr 4, 2026
71d03e5
[ROCm] Add hip_kernel stub to no_rocm.cpp to fix undefined symbol
Geramy Apr 4, 2026
4f60779
[ROCm] Guard WMMA compilation for non-WMMA architectures
Geramy Apr 13, 2026
39fac95
ROCm: wire up tiled 8-bit QMV launches for fp16 and bf16
soloish90 Apr 23, 2026
d999ca6
Merge branch 'ml-explore:main' into rocm-support
Geramy Apr 23, 2026
516b5a1
Merge pull request #6 from soloish90/fix/rocm-qmv-tiled-8bit
Geramy Apr 23, 2026
971451e
Merge branch 'ml-explore:main' into rocm-support
Geramy May 1, 2026
526dbbd
[ROCm] Guard rocWMMA dispatch on per-device arch allowlist
Geramy May 4, 2026
767b0aa
Merge branch 'ml-explore:main' into rocm-support
Geramy May 7, 2026
e15fcef
ROCm: fix 8-bit affine QMV miscompile from uint4 weight load
Geramy May 20, 2026
9e768e4
Merge pull request #7 from NripeshN/geramy/fix-rocm-qmv-8bit-uint4-mi…
Geramy May 20, 2026
597ccd3
Add clear_streams() for upstream MLX PR #3395 compatibility
antmikinka Jun 3, 2026
94d325c
Merge pull request #9 from antmikinka/add-clear-streams
Geramy Jun 3, 2026
bb8c8bc
fix(rocm): Avoid #pragma unroll in affine_dequantize_packed_kernel on…
antmikinka Jun 8, 2026
647f452
style(rocm): Run clang-format on affine_quantize.hip
antmikinka Jun 8, 2026
8342ccb
style(rocm): Fix pre-commit formatting across ROCm backend
antmikinka Jun 8, 2026
45f6ee1
Merge pull request #10 from antmikinka/fix/affine-dequantize-rdna35
Geramy Jun 8, 2026
b0b905e
ROCm QMV: persistent grid-stride + streaming weight loads + RDNA4 til…
Geramy Jun 13, 2026
2e4e90e
Merge branch 'main' into rocm-support
Geramy Jun 13, 2026
05f4ed3
ROCm QMV: dual-issue-friendly dequant (4 independent accumulators)
Geramy Jun 13, 2026
a268fe8
rocm: key the JIT HSACO cache by GPU arch
Geramy Jun 14, 2026
830bf1d
rocm: VRAM-resident memory for discrete RDNA4 GPUs (gfx1201)
Geramy Jun 14, 2026
ca79450
rocm/qmv: 4-accumulator dual-issue for 4-bit decode/prefill matvec
Geramy Jun 14, 2026
e803eca
rocm/sdpa: use the vector kernel for single-query decode, not flash
Geramy Jun 14, 2026
e296a56
rocm: make HIP graph capture/replay work (capture-aware completion Ev…
Geramy Jun 14, 2026
94a8a39
rocm: make HIP graph replay work on RDNA4 — pass gather metadata by v…
Geramy Jun 15, 2026
24c1065
rocm: defer hipBLASLt init while a HIP graph is capturing
Geramy Jun 15, 2026
d263cf1
rocm/graph: add async (non-draining) replay variant
Geramy Jun 15, 2026
cca7da2
rocm/quantized: fix 6-bit (and 2/5-bit) matmul producing garbage
Geramy Jun 16, 2026
e0ab9b6
rocm/quantized: optional tiled gather-QMV for MoE decode (env-gated)
Geramy Jun 16, 2026
e48368a
rocm/rope: skip the copy for partial rope on donatable inputs (PR #37…
Geramy Jun 16, 2026
e0ad799
rocm: pass general elementwise shape/strides by value (capture-safe +…
Geramy Jun 16, 2026
7cf95c7
rocm: defer buffer frees while a captured graph is alive
Geramy Jun 16, 2026
60ec82d
rocm: make remaining strided kernels capture-safe (by-value metadata)
Geramy Jun 16, 2026
10d03fa
fix(rocm/jit): invalidate hsaco disk cache when kernel source changes
Geramy Jun 16, 2026
d6b26ad
rocm: print the bound HIP device and arch on Device creation
Geramy Jun 16, 2026
cf71d31
rocm: allocate on the active default GPU, not the current HIP device
Geramy Jun 16, 2026
1220969
Revert "rocm: allocate on the active default GPU, not the current HIP…
Geramy Jun 17, 2026
12c3e10
rocm: run inference on a discrete GPU over a non-coherent link (TB5 e…
Geramy Jun 17, 2026
906835d
rocm: discrete-GPU CPU readback, event signaling, and per-arch JIT cache
Geramy Jun 17, 2026
812b843
rocm: KV memcpy and graph encoder use the selected device's stream
Geramy Jun 17, 2026
f80aa8f
rocm: discrete-GPU memory limit, read-only CPU mirror, device-flags r…
Geramy Jun 17, 2026
40dcd92
rocm: per-device hipEvent pool + bind stream's device on encoder access
Geramy Jun 17, 2026
6c10072
rocm: WMMA flash attention supports all head dims (incl. D=256) withi…
Geramy Jun 17, 2026
275019b
rocm: bind selected device on worker thread, encoder, and JIT load
Geramy Jun 17, 2026
f6e7d54
rocm: reliable cross-stream signaling, eager cache trim, fine-grained…
Geramy Jun 17, 2026
e7248b1
rocm: fix SliceUpdate reduce-path compile (structured binding capture)
Geramy Jun 17, 2026
3b270c9
rocm: fix rope partial-rotary no-copy path (gate on row_contiguous)
Geramy Jun 17, 2026
74c3c16
rocm: strided-input RMSNorm to avoid per-row contiguous copies
Geramy Jun 17, 2026
a6751bf
rocm: fast contiguous row-gather path for axis-0 gather
Geramy Jun 17, 2026
e1851e2
rocm: in-place KV / recurrent-state writes for graph-replay decode
Geramy Jun 18, 2026
db19eae
rocm: gate per-call hipBLASLt GEMM trace behind MLX_ROCM_GEMM_DEBUG
Geramy Jun 18, 2026
df2211d
rocm: full-wave tiled 6-bit QMV decode kernel (default)
Geramy Jun 18, 2026
7a80afd
rocm: full-wave tiled 6-bit gather-QMV for MoE expert decode
Geramy Jun 18, 2026
00d6024
rocm: trim over-verbose comments to one-line descriptions (comment-only)
Geramy Jun 18, 2026
44a9935
rocm: hipBLASLt-first quantized GEMM, per-shape algo cache, runtime c…
Geramy Jun 18, 2026
172025c
rocm: fused dequant-to-fp8 GEMM for RDNA4 prefill (non-batched, affin…
Geramy Jun 18, 2026
2285627
rocm: fix unified-memory free deadlock on integrated APU
Geramy Jun 19, 2026
abf1af8
rocm: drain device before clear_cache to avoid unified-memory free de…
Geramy Jun 19, 2026
3ac3f80
rocm: CUDA-style stream-ordered memory pool (hipMallocAsync/hipFreeAs…
Geramy Jun 19, 2026
b3e5088
rocm: force DynamicSliceUpdate in-place donation during HIP-graph cap…
Geramy Jun 20, 2026
63d445c
rocm: DecodeArena reset_to(mark) for capture-once graph replay
Geramy Jun 20, 2026
5b8ac9e
rocm: make HIP-graph capture-once decode replay-safe
Geramy Jun 20, 2026
f730214
rocm: MLX_NO_HIPBLASLT env to force rocBLAS (diagnostic); rocBLAS bf1…
Geramy Jun 20, 2026
0ce67b8
rocm: DecodeArena pause (keep backing, route new allocs to pool)
Geramy Jun 20, 2026
91d8a40
rocm: WIP auto graph-batching infra (node construction, exec-update) …
Geramy Jun 20, 2026
cc17b6b
rocm: migrate kernel launches to add_kernel_node (CUDA-style graph no…
Geramy Jun 21, 2026
b6e5858
rocm: migrate copy/reduce/quantized/qmm/gemv to add_kernel_node
Geramy Jun 21, 2026
4ffea35
rocm: persist kernel-node args for graph build; gate micro-capture br…
Geramy Jun 21, 2026
f0737c5
rocm: set capture flag in graph bridge so GEMM uses capture-safe rocBLAS
Geramy Jun 21, 2026
0908d96
rocm: graph node-type histogram + dot dump (MLX_HIP_GRAPH_DUMP) for d…
Geramy Jun 21, 2026
f3cb2e0
rocm: graph-split residuals + force rocBLAS in graph mode + arg-pack …
Geramy Jun 21, 2026
51843ec
rocm: linear-chain graph deps + free arg-packs at sync + nocache toggle
Geramy Jun 21, 2026
0c15717
rocm: func-keyed graph nodes + fresh-instantiate exec lifetime (WIP)
Geramy Jun 21, 2026
90f557b
rocm: graphs-ON WIP — func-keyed nodes, fresh-instantiate, two bugs i…
Geramy Jun 21, 2026
8070d50
rocm: HIP-graph decode WORKS — coherent 1000 tok (cap graphs at 2 nodes)
Geramy Jun 21, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions .github/workflows/build_rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
name: Build ROCm and Test

on:
push:
branches: [ rocm-support ]
workflow_dispatch:

jobs:
build-and-test:
runs-on: strix-halo

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python
run: |
uv venv venv
source venv/bin/activate
uv pip install --upgrade mlx-lm

- name: Build and install MLX ROCm wheel
run: |
source venv/bin/activate
export CMAKE_ARGS="-DMLX_BUILD_ROCM=ON -DMLX_ROCM_ARCHITECTURES=gfx1151 -DBLA_VENDOR=OpenBLAS -DCMAKE_BUILD_TYPE=RelWithDebInfo"
rm -rf wheelhouse
mkdir -p wheelhouse
uv build --wheel --out-dir wheelhouse .
uv pip install --force-reinstall wheelhouse/mlx-*.whl

- name: Basic MLX GPU test
run: |
source venv/bin/activate
python3 -c "
import mlx.core as mx
print('MLX version:', mx.__version__)
print('Default device:', mx.default_device())
mx.set_default_device(mx.gpu)
print('GPU device set')

# Test basic operations
a = mx.ones((10, 10))
mx.eval(a)
print('Basic array creation: OK')

# Test matmul
b = mx.random.normal((256, 256))
c = mx.matmul(b, b)
mx.eval(c)
print('Matmul test: OK')

# Test softmax
d = mx.softmax(b, axis=-1)
mx.eval(d)
print('Softmax test: OK')

print('All basic tests passed!')
"

- name: Run inference tests
run: |
source venv/bin/activate
export HIP_LAUNCH_BLOCKING=1
export PYTHONFAULTHANDLER=1
mkdir -p "${GITHUB_WORKSPACE}/rocm-stacktraces"

run_and_trace() {
local name="$1"
shift
lldb -Q -b \
-o "run" \
-k "bt" \
-k "quit 1" \
-- python3 "$(which mlx_lm.generate)" "$@" \
> >(tee "${GITHUB_WORKSPACE}/rocm-stacktraces/${name}.log") 2>&1
}

run_and_trace qwen3_bf16 --model mlx-community/Qwen3-0.6B-bf16 --prompt "Hi" --max-tokens 5
run_and_trace qwen3_8bit --model mlx-community/Qwen3-0.6B-8bit --prompt "How tall is Mt Everest?" --max-tokens 128

- name: Upload ROCm wheel artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
with:
name: rocm-wheel-${{ github.run_attempt }}
path: wheelhouse/mlx-*.whl
if-no-files-found: warn
retention-days: 14

- name: Upload ROCm stacktrace artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
with:
name: rocm-stacktraces-${{ github.run_attempt }}
path: ${{ github.workspace }}/rocm-stacktraces/*
if-no-files-found: warn
retention-days: 14
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,10 @@ uv.lock
.cache/
# vim
*.swp

# keys
*.pem

build.sh
github-runner/
sync_fork.sh
44 changes: 42 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_BUILD_ROCM "Build rocm backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
Expand Down Expand Up @@ -164,6 +165,43 @@ if(MLX_BUILD_CUDA)
endif()
endif()

if(MLX_BUILD_ROCM)
# Set HIP architectures - these will be used by the ROCm backend
# CMakeLists.txt
#
# Supported architectures from ROCm 6.4.0 - 7.2.0 compatibility matrix: CDNA:
# gfx908 (MI100), gfx90a (MI200), gfx942 (MI300) CDNA4: gfx950 (MI400 series)
# RDNA2: gfx1030 (RX 6000 series) RDNA3: gfx1100 (RX 7900), gfx1101 (RX 7600)
# RDNA4: gfx1200, gfx1201 (RX 8000 series)
if(NOT DEFINED CMAKE_HIP_ARCHITECTURES)
if(DEFINED MLX_ROCM_ARCHITECTURES)
set(CMAKE_HIP_ARCHITECTURES
${MLX_ROCM_ARCHITECTURES}
CACHE STRING "HIP architectures")
else()
set(CMAKE_HIP_ARCHITECTURES
"gfx908;gfx90a;gfx942;gfx1010;gfx1011;gfx1012;gfx1030;gfx1031;gfx1032;gfx1100;gfx1101;gfx1102"
CACHE STRING "HIP architectures")
endif()
endif()
message(
STATUS "Setting CMAKE_HIP_ARCHITECTURES to: ${CMAKE_HIP_ARCHITECTURES}")
# Note: We don't enable_language(HIP) here because it causes CMake to add -x
# hip to all CXX files in targets that link to HIP libraries. Instead, we
# compile HIP files using custom commands in the ROCm backend CMakeLists.txt.
# Find the HIP compiler
find_program(
CMAKE_HIP_COMPILER
NAMES hipcc clang++
PATHS /opt/rocm/bin /opt/rocm-6.0.0/bin /opt/rocm/llvm/bin
PATH_SUFFIXES bin
DOC "HIP compiler")
if(NOT CMAKE_HIP_COMPILER)
message(FATAL_ERROR "Could not find HIP compiler (hipcc or clang++)")
endif()
message(STATUS "Found HIP compiler: ${CMAKE_HIP_COMPILER}")
endif()

if(MLX_BUILD_METAL)
find_library(METAL_LIB Metal)
find_library(FOUNDATION_LIB Foundation)
Expand Down Expand Up @@ -310,10 +348,12 @@ if(MLX_BUILD_CPU)
message(FATAL_ERROR "Must have LAPACK installed")
endif()
find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
/usr/local/opt/openblas/include)
/usr/local/opt/openblas/include /usr/include/openblas)
message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
if(LAPACK_INCLUDE_DIRS)
target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
endif()
target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
# List blas after lapack otherwise we may accidentally incldue an old
# version of lapack.h from the include dirs of blas.
Expand Down
Loading