Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 55 additions & 14 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,37 +1,78 @@
cmake_minimum_required(VERSION 3.24)
set(CMAKE_CUDA_COMPILER "/opt/cuda/bin/nvcc")
set(CUDACXX "/opt/cuda/bin/nvcc")
project(tiny-vllm LANGUAGES CXX CUDA)

option(USE_HIP "Build with HIP for AMD GPUs" OFF)

if(NOT USE_HIP)
set(CMAKE_CUDA_COMPILER "/opt/cuda/bin/nvcc")
set(CUDACXX "/opt/cuda/bin/nvcc")
endif()

if(USE_HIP)
project(tiny-vllm LANGUAGES CXX HIP)
else()
project(tiny-vllm LANGUAGES CXX CUDA)
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)

set(CMAKE_CUDA_ARCHITECTURES 120)
if(USE_HIP)
# HIP architecture: default to gfx90a if not specified
if(NOT DEFINED CMAKE_HIP_ARCHITECTURES OR CMAKE_HIP_ARCHITECTURES STREQUAL "")
set(CMAKE_HIP_ARCHITECTURES "gfx90a")
endif()
else()
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_ARCHITECTURES 120)
endif()

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
endif()

set(CMAKE_CXX_FLAGS_RELEASE "-O2")
set(CMAKE_CUDA_FLAGS_RELEASE "-O2")
if(USE_HIP)
set(CMAKE_HIP_FLAGS_RELEASE "-O2")
set(CMAKE_HIP_FLAGS_DEBUG "-g -DDEBUG")
else()
set(CMAKE_CUDA_FLAGS_RELEASE "-O2")
set(CMAKE_CUDA_FLAGS_DEBUG "-G -g -DDEBUG")
endif()

set(CMAKE_CUDA_FLAGS_DEBUG "-G -g -DDEBUG")
set(CMAKE_CXX_FLAGS_DEBUG "-g -DDEBUG")

find_package(CUDAToolkit REQUIRED)
if(USE_HIP)
find_package(hipblas REQUIRED)
find_package(hip REQUIRED)
else()
find_package(CUDAToolkit REQUIRED)
endif()

add_executable(tiny-vllm
src/main.cpp
src/kernels.cu
)

if(USE_HIP)
# Both main.cpp and kernels.cu need HIP compilation because they use bfloat16 types
# which require the HIP compiler (hip_bf16.h uses clang-specific builtins)
set_source_files_properties(src/main.cpp src/kernels.cu PROPERTIES LANGUAGE HIP)
target_compile_definitions(tiny-vllm PRIVATE USE_HIP)
set_target_properties(tiny-vllm PROPERTIES HIP_ARCHITECTURES "${CMAKE_HIP_ARCHITECTURES}")
endif()

target_include_directories(tiny-vllm PRIVATE src)
target_include_directories(tiny-vllm PRIVATE include)

target_link_libraries(tiny-vllm PRIVATE
CUDA::cublas
CUDA::cudart
)

if(USE_HIP)
target_link_libraries(tiny-vllm PRIVATE
hip::host
roc::hipblas
)
else()
target_link_libraries(tiny-vllm PRIVATE
CUDA::cublas
CUDA::cudart
)
endif()
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ The exact setup on which I develop and test it:

Install the dependencies and run the program with `./test.sh` - it will build it and immediately execute it

It also runs on AMD GPUs through ROCm/HIP. Pass `-DUSE_HIP=ON` to CMake and it builds with hipcc against hipBLAS instead of nvcc and cuBLAS; the CUDA sources are reused as-is through a thin `src/cuda_to_hip.h` compatibility header. Pick your GPU's architecture with `-DCMAKE_HIP_ARCHITECTURES` (for example `gfx90a` for MI200, `gfx1100` for RDNA3, `gfx1201` for RDNA4) - it is not hardcoded, so set it to match your card:

```bash
cmake -B build -DUSE_HIP=ON -DCMAKE_HIP_ARCHITECTURES=gfx1100 -DCMAKE_PREFIX_PATH=/opt/rocm -G Ninja
cmake --build build
```

The `-DCMAKE_PREFIX_PATH=/opt/rocm` lets CMake find the hip and hipBLAS packages; drop it if `/opt/rocm/bin` is already on your `PATH`, or change it if ROCm lives elsewhere. I tested the AMD path on gfx90a, gfx1100, and gfx1201. The default build (no `-DUSE_HIP`) is unchanged and still targets NVIDIA through CUDA.

If you fail to build or run it and your AI of choice won't be able to help, please open an Issue on GitHub - I will try to help. Make sure to provide all useful context

## Safetensors and your model
Expand Down
64 changes: 64 additions & 0 deletions src/cuda_to_hip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#pragma once

// Copyright (c) 2026 Advanced Micro Devices, Inc.
// Author: Jeff Daily <jeff.daily@amd.com>
//
// CUDA-to-HIP compatibility header for tiny-vllm
// Keeps CUDA spellings in source and aliases them to HIP on AMD GPUs

#if defined(USE_HIP) || defined(__HIP_PLATFORM_AMD__)

#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>
#include <hipblas/hipblas.h>

// bfloat16 type mappings
#define __nv_bfloat16 __hip_bfloat16
#define nv_bfloat16 __hip_bfloat16

// CUDA runtime -> HIP runtime
#define cudaMalloc hipMalloc
#define cudaFree hipFree
#define cudaMemcpy hipMemcpy
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaGetLastError hipGetLastError
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaMemGetInfo hipMemGetInfo
#define cudaDeviceProp hipDeviceProp_t
#define cudaError hipError_t
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess

// cuBLAS -> hipBLAS
#define cublasHandle_t hipblasHandle_t
#define cublasStatus_t hipblasStatus_t
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT

// Data types for GEMM
#define CUDA_R_16BF HIP_R_16BF

// Warp shuffle mask for HIP (64-bit required)
// HIP requires 64-bit masks for __shfl_* functions
#define WARP_FULL_MASK 0xffffffffffffffffULL

#else

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <cublas_v2.h>

// On CUDA, use the standard 32-bit mask
#define WARP_FULL_MASK 0xffffffff

#endif
11 changes: 6 additions & 5 deletions src/kernels.cu
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "cuda_to_hip.h"
#include "kernels.cuh"
#include <iostream>

Expand Down Expand Up @@ -407,11 +408,11 @@ __global__ void pagedAttentionKernel(int layer, int num_active_slots, __nv_bfloa
float qk = (float)q * (float)*k;
// tree reduction within current warp, thread 0 gets sum of all 32 elements within warp
// could be done with __syncthreads but accessing memory of other threads in warp is op
qk += __shfl_down_sync(0xffffffff, qk, 16);
qk += __shfl_down_sync(0xffffffff, qk, 8);
qk += __shfl_down_sync(0xffffffff, qk, 4);
qk += __shfl_down_sync(0xffffffff, qk, 2);
qk += __shfl_down_sync(0xffffffff, qk, 1);
qk += __shfl_down_sync(WARP_FULL_MASK, qk, 16);
qk += __shfl_down_sync(WARP_FULL_MASK, qk, 8);
qk += __shfl_down_sync(WARP_FULL_MASK, qk, 4);
qk += __shfl_down_sync(WARP_FULL_MASK, qk, 2);
qk += __shfl_down_sync(WARP_FULL_MASK, qk, 1);
if (thread_id == 0)
{
dot_products[0] = qk;
Expand Down
7 changes: 7 additions & 0 deletions src/kernels.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
#pragma once

#if defined(USE_HIP) || defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bf16.h>
#define __nv_bfloat16 __hip_bfloat16
#define nv_bfloat16 __hip_bfloat16
#else
#include <cuda_bf16.h>
#endif

// prefill
void embeddingGather(int *gpu_input_tokens, __nv_bfloat16 *gpu_input_embeds, __nv_bfloat16 *embed_tokens, int num_input_tokens);
Expand Down
3 changes: 1 addition & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
#include <iostream>
#include <numeric>
#include <fstream>
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include "cuda_to_hip.h"
#include <queue>
#define JSON_USE_IMPLICIT_CONVERSIONS 0
#include "json.hpp"
Expand Down