Skip to content

Build wheels for AMD GPU with ROCm#1356

Open
csukuangfj wants to merge 15 commits into
k2-fsa:masterfrom
csukuangfj:ci-rocm
Open

Build wheels for AMD GPU with ROCm#1356
csukuangfj wants to merge 15 commits into
k2-fsa:masterfrom
csukuangfj:ci-rocm

Conversation

@csukuangfj

Copy link
Copy Markdown
Collaborator

csukuangfj and others added 15 commits June 26, 2026 12:30
Co-Authored-By: Claude <noreply@anthropic.com>
- Cast cudaDeviceSynchronize()/hipDeviceSynchronize() to (void) in
  K2_CUDA_SAFE_CALL to suppress -Wunused-result warnings on HIP builds
- Add ROCm version, HIP version, and kWithHip to version.h.in
- Detect ROCM_VERSION from hip_VERSION cmake package or env var
- Detect TORCH_HIP_VERSION from torch.version.hip

Co-Authored-By: Claude <noreply@anthropic.com>
- version.cu: expose rocm_version, torch_hip_version, with_hip
- version.py: print ROCm version, HIP version, and with_hip flag

Co-Authored-By: Claude <noreply@anthropic.com>
The hipcub functions (e.g. DeviceScan::ExclusiveScan) are [[nodiscard]].
Cast the whole expression to (void) to suppress the warning.

Co-Authored-By: Claude <noreply@anthropic.com>
#pragma unroll(N) with parentheses triggers -Wcuda-compat warning.
Use #pragma unroll N instead.

Co-Authored-By: Claude <noreply@anthropic.com>
rocprim/hipcub requires iterator operators (operator+, operator[],
operator()) to be callable from __host__ context during template
instantiation. Change __device__ __forceinline__ to K2_CUDA_HOSTDEV
on RowSplitsDiff, HashInputIterator, PairInputIterator, and
HashCombineOp.

Co-Authored-By: Claude <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces ROCm/HIP support to the k2 library, updating build configurations, documentation, versioning scripts, and CI workflows to support AMD GPUs. It also refactors CUDA kernels and macros for compatibility with HIP. The review feedback highlights a potential AttributeError in get_version.py if the ROCm version cannot be resolved, and recommends pinning the libhipcxx repository to a specific tag in the build script to ensure reproducible builds.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread get_version.py
Comment on lines +98 to +100
rocm_version = get_rocm_version()
# Keep only major.minor (e.g., 7.1.52802 -> 7.1)
rocm_version = '.'.join(rocm_version.split('.')[:2])

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If get_rocm_version() returns None (which can happen if is_rocm() is True due to K2_WITH_HIP=ON in K2_CMAKE_ARGS but the ROCm version cannot be detected), calling rocm_version.split('.') will raise an AttributeError: 'NoneType' object has no attribute 'split'. We should handle the case where rocm_version is None to prevent build crashes.

Suggested change
rocm_version = get_rocm_version()
# Keep only major.minor (e.g., 7.1.52802 -> 7.1)
rocm_version = '.'.join(rocm_version.split('.')[:2])
rocm_version = get_rocm_version()
if rocm_version:
# Keep only major.minor (e.g., 7.1.52802 -> 7.1)
rocm_version = '.'.join(rocm_version.split('.')[:2])
else:
rocm_version = 'unknown'


# Install libhipcxx (provides <cuda/std/*> headers needed by k2's HIP build)
echo "Installing libhipcxx..."
git clone --depth 1 https://github.com/ROCm/libhipcxx.git /tmp/libhipcxx

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Cloning the libhipcxx repository without specifying a tag or commit hash can lead to non-reproducible builds if upstream changes break compatibility. It is highly recommended to pin the repository to a specific stable release tag or commit hash.

Suggested change
git clone --depth 1 https://github.com/ROCm/libhipcxx.git /tmp/libhipcxx
git clone --depth 1 --branch rocm-6.3.0 https://github.com/ROCm/libhipcxx.git /tmp/libhipcxx

@jeffdaily

Copy link
Copy Markdown
Contributor

Tested on an AMD Instinct MI250X (gfx90a), ROCm 7.2.1, Python 3.10, torch 2.12.1+rocm7.2. The wheel installs and runs correctly: arc_sort, get_forward_scores, and intersect_dense + shortest_path (CTC) all execute on-device with correct results. k2.with_cuda is True.

Two notes:

1. The documented pip command fails. This, from the installation page:

pip install k2==1.24.4.dev20260626+rocm7.2.torch2.12.1 -f https://k2-fsa.github.io/k2/rocm.html

errors with inconsistent version: expected '1.24.4.dev20260626+rocm7.2.torch2.12.1', but metadata has '1.24.4.dev20260626+rocm7.2.53211.torch2.12.1'. The wheel's internal metadata version embeds the full ROCm patch (rocm7.2.53211) while the filename truncates to rocm7.2, and pip rejects the mismatch. Downloading the wheel file and pip install ./<wheel> (the direct-download path in the docs) works. Worth aligning the filename and the metadata version so the documented one-liner works.

2. A small FYI on the arch list (not a problem, just an option). The arch selection in scripts/github_actions/build-ubuntu-rocm.sh is clearly deliberate -- enumerating the gfx targets from the ROCm support docs is a sound, well-reasoned way to do it, and the wheel runs correctly on every arch it ships. So this is more of a "did you know" than a suggested change.

If it's ever useful to have the list stay in lockstep with the torch build the wheel is paired against, torch.cuda.get_arch_list() returns exactly the gfx targets the installed ROCm torch wheel was compiled for:

HIP_ARCH=$(python -c "import torch; print(';'.join(a for a in torch.cuda.get_arch_list() if a.startswith('gfx')))")
cmake -DCMAKE_HIP_ARCHITECTURES="$HIP_ARCH" ...

For reference, on this build the two lists differ only by gfx900/gfx906 (old Vega) -- the torch wheel still includes them, while the k2 list (reasonably) tracks the current ROCm support matrix instead. Either policy is defensible; deriving from torch is just one way to avoid maintaining the list by hand as new arches land.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants