Add a ROCm/HIP build of k2 for AMD GPUs#1353
Conversation
k2 previously had no AMD path. This adds a from-scratch CUDA-to-HIP port that
builds and runs on ROCm. k2 is a CMake project that consumes a PyTorch build but
does NOT use torch's build-time hipify (setup.py shells out to cmake), so the
port enables HIP directly: enable_language(HIP), one force-included compat
header (k2/csrc/cuda_to_hip.h) that aliases the ~30 cudaXxx runtime symbols and
maps cub -> hipcub, and the existing .cu sources marked LANGUAGE HIP. We have made every
effort to leave the NVIDIA build unchanged: every change is under
K2_WITH_HIP/USE_HIP, and
the compat header is never compiled on the CUDA path.
The substance is replacement of two non-portable third-party deps, not symbol
renaming. Review in this order:
1. moderngpu is not portable to ROCm (its intrinsics.hxx #errors under non-nvcc,
hardcodes a 32-lane warp, and uses inline PTX). On the HIP build it is not
compiled. The small surface k2 uses (transform_lbs, mergesort,
segmented_sort[_indices], load_balance_search, sorted_search, transform_scan,
plus the allocator base) is reimplemented in k2/csrc/moderngpu_shim.h with the
exact mgpu::-shaped API, backed by rocThrust and a few small kernels, so every
call site compiles unchanged. moderngpu.h includes the shim on HIP, and
moderngpu_allocator.cu is replaced by moderngpu_allocator_hip.cu.
2. The vendored CUDPP segmented scan is warp-synchronous (WARP_SIZE=32, two
32-lane warps per block) and so is incorrect on a 64-lane wavefront. cudpp.cu
is dropped on HIP; cudpp/cudpp_hip.cu reimplements SegmentedExclusiveSum with
hipCUB (inclusive-scan the head flags into a monotone segment key, then
DeviceScan::ExclusiveScanByKey). No warp arithmetic; correct on both wave32
and wave64.
3. The rest is build wiring (the K2_WITH_HIP option, relocatable device code via
HIP_SEPARABLE_COMPILATION on the library and every consumer, .cu marked
LANGUAGE HIP, hip::host linkage), the cub.h/rand.cu library swaps to
hipCUB/hipRAND, and a handful of clang/HIP and C++20 fixes. <cuda/std/*> is
supplied by the vendored ROCm fork of libcu++ on the include path. On Windows the torch import libraries are
linked via file(GLOB) over the torch lib dir (clang+HIP and MSVC+CUDA
alike), because CMake 4.x does not expand imported SHARED targets in the
Ninja/HIP link rules; the Linux path keeps the cmake imported targets.
The port is wave-size aware: the cooperative-groups sub-warp tiles k2 uses are
all <= 32 lanes (safe on wave64 and wave32), and the two replaced primitives use
hipCUB/rocThrust, which are wave-agnostic. Several clang/HIP and C++20 specifics
are handled: a __host__ __device__ function is preprocessed once in the host pass
under clang, so device-vs-host dispatch keys on K2_DEVICE_CODE
(__HIP_DEVICE_COMPILE__) rather than a bare #ifdef __CUDA_ARCH__; the ROCm torch
forces C++20, where a user-declared (even = default) special member disqualifies
aggregate initialization, so RaggedShapeLayer's defaulted members are left
implicit; and the _k2 pybind module is built NO_EXTRAS with IPO off so HIP LTO
does not strip PyInit__k2. On Windows the LLP64 ABI required __builtin_clzll for
64-bit bit counting.
The c10 device namespace is selected by the torch source-hipify generation, not
by the OS. torch's hipify has two generations that disagree on the c10 device
classes: generation 1 renamed them (only c10::hip exists), while generation 2
(pytorch#174087) stops renaming so the CUDA spelling c10::cuda is the public
masquerading API and c10::hip survives only as thin wrappers. Because k2 never
runs torch source-hipify, cmake/torch.cmake probes torch.utils.hipify.__version__
at build time and defines TORCH_HIPIFY_V2 when it is >= 2.0.0; the c10 call sites
use c10::hip under hipify v1 and c10::cuda under v2 (or a pure CUDA build). The
c10/hip/* headers are included on every generation because c10/cuda/* pulls a
generated macro header that only exists as the hip variant on a ROCm torch.
Keying on the hipify generation rather than the OS keeps the selection correct
when a platform pairs with either generation.
Known limitations (not part of this change):
- k2/torch, the standalone libtorch C++ decoder layer, is not yet built on the
ROCm path. It pulls in kaldifeat (a separate CUDA project). The Python _k2
module and the C++ gtest suite -- the FSA core that icefall and sherpa consume
-- are built and validated.
- The moderngpu fast-path is replaced, not ported; the replacements favor
correctness via hipCUB/rocThrust over moderngpu's cached segmented loads.
Documentation for the ROCm build is added to the Sphinx install guide alongside
the existing CUDA instructions.
Test Plan:
Validated on real GPUs on three AMD architectures: gfx90a (MI250X, CDNA2,
wave64), gfx1100 (Radeon Pro W7800, RDNA3, wave32), and gfx1201 (Radeon RX
9070 XT, RDNA4, wave32).
Build (Linux, gfx90a, against a ROCm PyTorch, C++20):
cmake -DK2_WITH_HIP=ON -DK2_WITH_CUDA=OFF -DCMAKE_HIP_ARCHITECTURES=gfx90a \
-DCMAKE_CXX_STANDARD=20 -DK2_ENABLE_TESTS=ON \
-DK2_LIBHIPCXX_INCLUDE_DIR=/path/to/libhipcxx/include ..
cmake --build . -j 16
C++ gtests (each compares CPU vs GPU internally), one GPU, serial:
HIP_VISIBLE_DEVICES=0 ctest --output-on-failure
Result: 30/30 cu_*_test executables pass (298 individual tests), including the
replacement coverage -- SegmentedExclusiveSum (CUDPP replacement),
Prune+SortSublists (segmented_sort_indices), transform_lbs/mergesort/
transform_scan, the cooperative-groups sub-warp-tile kernels on wave64, and
hipRAND.
Python integration slice on GPU:
HIP_VISIBLE_DEVICES=0 python3 -m pytest k2/python/tests
Result: all pass except two device-independent pre-existing artifacts that
reproduce on a CUDA build with the same torch -- the pickle/setstate tests
(torch 2.6+ weights_only=True default refuses the custom RaggedTensor global)
and one float32-only normalize_scores tolerance case (catastrophic
cancellation in 10 - log(exp(2)+exp(10)); GPU float64 is exact).
Followers (gfx1100/gfx1201) build with only
-DCMAKE_HIP_ARCHITECTURES=<arch> changed and pass at identical counts.
Authored with assistance from Claude (Anthropic).
There was a problem hiding this comment.
Code Review
This pull request introduces AMD GPU (ROCm/HIP) support to the k2 library, enabling compilation of .cu sources as HIP and replacing non-portable dependencies with hipCUB and rocThrust. Key changes include build system updates, a CUDA-to-HIP compatibility header, a HIP-based ModernGPU shim, and C++20 compatibility adjustments. Feedback on these changes highlights a performance bottleneck in segmented_sort due to host-side synchronization and looping, which should be refactored into a single global GPU sort. Additionally, a null check for reduction_out in transform_scan is recommended to prevent potential GPU segmentation faults, and PairOutputIterator should implement operator-(size_t) for consistency on 64-bit platforms.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| void segmented_sort(Key *keys, int32_t count, const int32_t *segments, | ||
| int32_t num_segments, Comp comp, context_t &ctx) { | ||
| if (count <= 0) return; | ||
| std::vector<int32_t> off = | ||
| shim_internal::SegmentsToHost(ctx, segments, num_segments); | ||
| auto policy = thrust::hip::par.on(ctx.stream()); | ||
| for (int32_t s = 0; s < num_segments; ++s) { | ||
| int32_t begin = off[s], end = off[s + 1]; | ||
| if (end - begin > 1) | ||
| thrust::stable_sort(policy, shim_internal::dptr(keys) + begin, | ||
| shim_internal::dptr(keys) + end, comp); | ||
| } | ||
| } | ||
|
|
||
| template <typename Key, typename Index, typename Comp> | ||
| void segmented_sort_indices(Key *keys, Index *indices, int32_t count, | ||
| const int32_t *segments, int32_t num_segments, | ||
| Comp comp, context_t &ctx) { | ||
| if (count <= 0) return; | ||
| // moderngpu's segmented_sort_indices fills `indices` with the GLOBAL identity | ||
| // permutation (0..count-1) and then stable-sorts each segment's slice | ||
| // alongside the keys, so afterwards indices[p] is the original global index of | ||
| // the element now at p. k2 relies on this (it does NOT pre-seed `indices`, | ||
| // unlike mergesort which seeds with Range()): PruneRaggedAxis1 reads | ||
| // order_map[idx01] as a global original index. Seed the identity here. | ||
| auto policy = thrust::hip::par.on(ctx.stream()); | ||
| thrust::sequence(policy, shim_internal::dptr(indices), | ||
| shim_internal::dptr(indices) + count, Index(0)); | ||
| std::vector<int32_t> off = | ||
| shim_internal::SegmentsToHost(ctx, segments, num_segments); | ||
| for (int32_t s = 0; s < num_segments; ++s) { | ||
| int32_t begin = off[s], end = off[s + 1]; | ||
| if (end - begin > 1) | ||
| thrust::stable_sort_by_key(policy, shim_internal::dptr(keys) + begin, | ||
| shim_internal::dptr(keys) + end, | ||
| shim_internal::dptr(indices) + begin, comp); | ||
| } | ||
| } |
There was a problem hiding this comment.
The current implementation of segmented_sort and segmented_sort_indices copies the segment offsets to the host using SegmentsToHost, which calls hipStreamSynchronize. This introduces a major performance bottleneck by blocking the host CPU thread and destroying stream concurrency. Furthermore, launching thrust::stable_sort in a host-side loop num_segments times causes massive driver overhead and poor GPU utilization when num_segments is large.
To avoid host-side loops and stream synchronizations, you can perform a single global sort on the GPU:
- Use
ComputeRowIds(which is already implemented) to compute the segment ID for each element on the GPU. - Perform a global
thrust::stable_sortorthrust::stable_sort_by_keyusing a custom comparator that first compares the segment IDs, and if they are equal, compares the keys using the user-providedcomp.
There was a problem hiding this comment.
Valid point on the host sync and the per-segment loop. I'm keeping the current implementation in this PR: it's correct and validated, and the goal here is AMD-enablement / functional parity rather than perf-tuning. The global stable_sort_by_key with a composite (row_id, key) comparator via ComputeRowIds is the right optimization, but it's a non-trivial rewrite of a validated kernel -- the composite comparator has to reproduce the per-segment ordering and stability exactly -- so it's better done and benchmarked as a separate change. Tracked as a follow-up; not blocking this PR.
| if (count <= 0) { | ||
| if (count == 0) | ||
| K2_CHECK_EQ(hipMemsetAsync(reduction_out, 0, sizeof(T), ctx.stream()), | ||
| hipSuccess); | ||
| return; | ||
| } | ||
| // values: count+1 transformed inputs (index `count` is a readable don't-care). | ||
| shim_internal::DeviceScratch values_buf(ctx.k2_context, | ||
| (count + 1) * sizeof(T)); | ||
| T *values = values_buf.as<T>(); | ||
| int32_t grid = (count + shim_internal::kBlockSize - 1) / | ||
| shim_internal::kBlockSize; | ||
| shim_internal::MaterializeKernel<T, Lambda> | ||
| <<<grid, shim_internal::kBlockSize, 0, ctx.stream()>>>(count, f, values); | ||
|
|
||
| size_t temp_bytes = 0; | ||
| K2_CHECK_EQ(hipcub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, values, | ||
| output, count + 1, ctx.stream()), | ||
| hipSuccess); | ||
| shim_internal::DeviceScratch temp_buf(ctx.k2_context, temp_bytes); | ||
| K2_CHECK_EQ( | ||
| hipcub::DeviceScan::ExclusiveSum(temp_buf.data, temp_bytes, values, output, | ||
| count + 1, ctx.stream()), | ||
| hipSuccess); | ||
| // output[count] now holds the total; reduction_out is output + count for k2's | ||
| // K2_TRANS_EXCSUM, so it is already populated, but write it explicitly in case | ||
| // reduction_out aliases elsewhere. | ||
| if (reduction_out != output + count) | ||
| K2_CHECK_EQ(hipMemcpyAsync(reduction_out, output + count, sizeof(T), | ||
| hipMemcpyDeviceToDevice, ctx.stream()), | ||
| hipSuccess); |
There was a problem hiding this comment.
In transform_scan, if reduction_out is nullptr, calling hipMemsetAsync or hipMemcpyAsync on it will cause a GPU segmentation fault or runtime error. Adding a null check for reduction_out ensures defensive programming and robustness.
if (count <= 0) {
if (count == 0 && reduction_out != nullptr)
K2_CHECK_EQ(hipMemsetAsync(reduction_out, 0, sizeof(T), ctx.stream()),
hipSuccess);
return;
}
// values: count+1 transformed inputs (index `count` is a readable don't-care).
shim_internal::DeviceScratch values_buf(ctx.k2_context,
(count + 1) * sizeof(T));
T *values = values_buf.as<T>();
int32_t grid = (count + shim_internal::kBlockSize - 1) /
shim_internal::kBlockSize;
shim_internal::MaterializeKernel<T, Lambda>
<<<grid, shim_internal::kBlockSize, 0, ctx.stream()>>>(count, f, values);
size_t temp_bytes = 0;
K2_CHECK_EQ(hipcub::DeviceScan::ExclusiveSum(nullptr, temp_bytes, values,
output, count + 1, ctx.stream()),
hipSuccess);
shim_internal::DeviceScratch temp_buf(ctx.k2_context, temp_bytes);
K2_CHECK_EQ(
hipcub::DeviceScan::ExclusiveSum(temp_buf.data, temp_bytes, values, output,
count + 1, ctx.stream()),
hipSuccess);
// output[count] now holds the total; reduction_out is output + count for k2's
// K2_TRANS_EXCSUM, so it is already populated, but write it explicitly in case
// reduction_out aliases elsewhere.
if (reduction_out != nullptr && reduction_out != output + count)
K2_CHECK_EQ(hipMemcpyAsync(reduction_out, output + count, sizeof(T),
hipMemcpyDeviceToDevice, ctx.stream()),
hipSuccess);There was a problem hiding this comment.
In k2's call sites reduction_out is always output + count (K2_TRANS_EXCSUM passes output + count), so neither the count==0 memset nor the final memcpy ever receives a null pointer -- the guard would be purely defensive against a caller that doesn't exist here. I'd rather not add a speculative null-check no path exercises; it can go in if a real caller can pass null. Leaving as-is for now.
| __host__ __device__ __forceinline__ PairOutputIterator | ||
| operator-(int32_t offset) const { | ||
| return PairOutputIterator{i_ - offset}; | ||
| } |
There was a problem hiding this comment.
For consistency with the operator+(size_t) overload, PairOutputIterator should also implement operator-(size_t) to prevent potential compilation errors on 64-bit platforms where size_t offsets are used with subtraction.
__host__ __device__ __forceinline__ PairOutputIterator
operator-(int32_t offset) const {
return PairOutputIterator{i_ - offset};
}
__host__ __device__ __forceinline__ PairOutputIterator
operator-(size_t offset) const {
return PairOutputIterator{i_ - offset};
}There was a problem hiding this comment.
The build compiles without operator-(size_t) -- hipCUB doesn't instantiate a size_t subtraction on this iterator in any path k2 uses -- so this isn't a current compile error, just a symmetry gap with operator+(size_t). Leaving it for now; it's a trivial add if a future path needs it.
|
Thanks for the review. None of these are blocking for this AMD-enablement PR; details in the inline replies. In short: the |
Brings the files added or modified by the ROCm port into line with the project's cpplint style: wraps the long comments and a few statements to 80 columns, adds the two include-what-you-use headers (<tuple>, <utility>) the new moderngpu_shim.h uses, and marks the third-party thrust includes that cpplint misclassifies as C system headers with NOLINT(build/include_order) (matching the NOLINT convention already used elsewhere in k2, e.g. for cub/regex/mutex). Only lines introduced by the port are touched; no pre-existing upstream code changed. This is a style-only change with no effect on generated code. Note: the style_check CI job is currently red for an unrelated, upstream reason -- its actions/setup-python@v1 step cannot find Python 3.9 on the current ubuntu-24 runner image (the same job fails on upstream master), so flake8/cpplint never actually run there. This commit keeps the port cpplint-clean for when that job is repaired. Authored with assistance from Claude.
|
Thanks for the AMD support! @csukuangfj would you mind taking a quick look at this before we merge-- IDK if anything might have to be done regarding the python packaging. |
The HIP build pinned CMAKE_HIP_ARCHITECTURES to gfx90a before project()
enables the HIP language, which preempts CMake's own host-GPU detection.
A user on a non-gfx90a AMD GPU (e.g. gfx1100/gfx1201) who does not pass
-DCMAKE_HIP_ARCHITECTURES would silently build gfx90a code objects that
fail to load on their card at runtime ("no kernel image").
Removing the pin lets the project() HIP-language probe honor an explicit
-DCMAKE_HIP_ARCHITECTURES, otherwise auto-detect the host GPU via
rocm_agent_enumerator, and otherwise error out (the desired safety net on
a build host with no GPU). The arch propagates to targets unchanged via
the per-target HIP_ARCHITECTURES property.
Test Plan: explicit-arch builds are unaffected (the pin only applied when
CMAKE_HIP_ARCHITECTURES was unset), so a build with
-DCMAKE_HIP_ARCHITECTURES=gfx90a is byte-identical before and after.
Authored with assistance from Claude.
Sure, I am looking into it. |
csukuangfj
left a comment
There was a problem hiding this comment.
Thank you for your contribution! The changes look great to me.
I am trying to build k2 with AMD GPU support. Please help test it when it is released.
|
Can you have a look at #1356 and help test the provided wheels? |
k2 previously had no AMD path. This adds a from-scratch CUDA-to-HIP port that builds and runs on ROCm. k2 is a CMake project that consumes a PyTorch build but does NOT use torch's build-time hipify (setup.py shells out to cmake), so the port enables HIP directly:
enable_language(HIP), one force-included compat header (k2/csrc/cuda_to_hip.h) that aliases the ~30cudaXxxruntime symbols and maps cub -> hipcub, and the existing.cusources markedLANGUAGE HIP. We have made every effort to leave the NVIDIA build unchanged: every change is underK2_WITH_HIP/USE_HIP, and the compat header is never compiled on the CUDA path.This work was authored with the assistance of Claude (Anthropic), an AI coding assistant, and validated on real AMD GPU hardware (see Test Plan).
The substance is replacement of two non-portable third-party deps, not symbol renaming. Review in this order:
moderngpu is not portable to ROCm (its
intrinsics.hxx#errors under non-nvcc, hardcodes a 32-lane warp, and uses inline PTX). On the HIP build it is not compiled. The small surface k2 uses (transform_lbs,mergesort,segmented_sort[_indices],load_balance_search,sorted_search,transform_scan, plus the allocator base) is reimplemented ink2/csrc/moderngpu_shim.hwith the exactmgpu::-shaped API, backed by rocThrust and a few small kernels, so every call site compiles unchanged.moderngpu.hincludes the shim on HIP, andmoderngpu_allocator.cuis replaced bymoderngpu_allocator_hip.cu.The vendored CUDPP segmented scan is warp-synchronous (
WARP_SIZE=32, two 32-lane warps per block) and so is incorrect on a 64-lane wavefront.cudpp.cuis dropped on HIP;cudpp/cudpp_hip.cureimplementsSegmentedExclusiveSumwith hipCUB (inclusive-scan the head flags into a monotone segment key, thenDeviceScan::ExclusiveScanByKey). No warp arithmetic; correct on both wave32 and wave64.The rest is build wiring (the
K2_WITH_HIPoption, relocatable device code viaHIP_SEPARABLE_COMPILATION,.cumarkedLANGUAGE HIP,hip::hostlinkage), thecub.h/rand.culibrary swaps to hipCUB/hipRAND, and a handful of clang/HIP and C++20 fixes.<cuda/std/*>is supplied by the vendored ROCm fork of libcu++ on the include path. On Windows the torch import libraries are linked viafile(GLOB)over the torch lib dir (clang+HIP and MSVC+CUDA alike), because CMake 4.x does not expand imported SHARED targets in the Ninja/HIP link rules; the Linux path keeps the cmake imported targets.The port is wave-size aware: the cooperative-groups sub-warp tiles k2 uses are all <= 32 lanes (safe on wave64 and wave32), and the two replaced primitives use hipCUB/rocThrust, which are wave-agnostic. Several clang/HIP and C++20 specifics are handled: a
__host__ __device__function is preprocessed once in the host pass under clang, so device-vs-host dispatch keys onK2_DEVICE_CODE(__HIP_DEVICE_COMPILE__) rather than a bare#ifdef __CUDA_ARCH__; the ROCm torch forces C++20, where a user-declared (even= default) special member disqualifies aggregate initialization, soRaggedShapeLayer's defaulted members are left implicit; and the_k2pybind module is builtNO_EXTRASwith IPO off so HIP LTO does not stripPyInit__k2. On Windows the LLP64 ABI required__builtin_clzllfor 64-bit bit counting.The c10 device namespace is selected by the torch source-hipify generation, not by the OS. torch's hipify has two generations that disagree on the c10 device classes: generation 1 renamed them (only
c10::hipexists), while generation 2 (pytorch#174087) stops renaming so the CUDA spellingc10::cudais the public masquerading API andc10::hipsurvives only as thin wrappers. Because k2 never runs torch source-hipify,cmake/torch.cmakeprobestorch.utils.hipify.__version__at build time and definesTORCH_HIPIFY_V2when it is >= 2.0.0; the c10 call sites usec10::hipunder hipify v1 andc10::cudaunder v2 (or a pure CUDA build). Keying on the hipify generation rather than the OS keeps the selection correct when a platform pairs with either generation.Known limitations (not part of this change)
k2/torch, the standalone libtorch C++ decoder layer, is not yet built on the ROCm path. It pulls in kaldifeat (a separate CUDA project). The Python_k2module and the C++ gtest suite -- the FSA core that icefall and sherpa consume -- are built and validated.Documentation for the ROCm build is added to the Sphinx install guide alongside the existing CUDA instructions.
Test Plan
Validated on real GPUs on three AMD architectures: gfx90a (MI250X, CDNA2, wave64), gfx1100 (Radeon Pro W7800, RDNA3, wave32), and gfx1201 (Radeon RX 9070 XT, RDNA4, wave32).
Build (Linux, gfx90a, against a ROCm PyTorch, C++20):
C++ gtests (each compares CPU vs GPU internally), one GPU, serial:
Result: 30/30
cu_*_testexecutables pass (298 individual tests), including the replacement coverage --SegmentedExclusiveSum(CUDPP replacement), Prune+SortSublists (segmented_sort_indices),transform_lbs/mergesort/transform_scan, the cooperative-groups sub-warp-tile kernels on wave64, and hipRAND.Python integration slice on GPU:
Result: all pass except two device-independent pre-existing artifacts that reproduce on a CUDA build with the same torch -- the pickle/setstate tests (torch 2.6+
weights_only=Truedefault refuses the custom RaggedTensor global) and one float32-onlynormalize_scorestolerance case (catastrophic cancellation in10 - log(exp(2)+exp(10)); GPU float64 is exact).Followers (gfx1100/gfx1201) build with only
-DCMAKE_HIP_ARCHITECTURES=<arch>changed and pass at identical counts.